Skip to content

Instantly share code, notes, and snippets.

@gaspardpetit
Created December 27, 2023 03:48
Show Gist options
  • Select an option

  • Save gaspardpetit/e2af3728d922239e0a6ec80e53fb5f58 to your computer and use it in GitHub Desktop.

Select an option

Save gaspardpetit/e2af3728d922239e0a6ec80e53fb5f58 to your computer and use it in GitHub Desktop.

Revisions

  1. gaspardpetit created this gist Dec 27, 2023.
    128 changes: 128 additions & 0 deletions diarization_pyannote.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,128 @@
    import os
    import torch
    import torchaudio
    import logging
    from pyannote.audio import Pipeline
    from pyannote.core import Annotation
    from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
    from pyannote.audio.pipelines.utils.hook import ProgressHook

    # Environment variable to use for setting the HuggingFace Token
    ENV_HUGGINGFACE_TOKEN : str = "HUGGINGFACE_TOKEN"
    LOG: logging.Logger = logging.getLogger(__name__)

    class DiarizationPipeline(object):
    """
    DiarizationPipeline is a singleton class that represents the diarization pipeline.
    Attributes:
    - device (torch.device): The device (CPU or CUDA) to use for diarization.
    - pipeline (Pipeline): The diarization pipeline.
    Methods:
    - _get_huggingface_token(): Retrieves the Hugging Face token from the environment variables.
    - _get_device(): Retrieves the device to use for diarization.
    - _load_pipeline(device): Loads the diarization pipeline.
    - _init_once(): Initializes the DiarizationPipeline singleton instance.
    """
    __instance = None

    def __new__(cls):
    if cls.__instance is None:
    cls.__instance = super(DiarizationPipeline, cls).__new__(cls)
    cls.__instance._init_once()
    return cls.__instance

    @staticmethod
    def _get_huggingface_token() -> str:
    """Retrieves the Hugging Face token from the environment variables."""
    token = os.getenv(ENV_HUGGINGFACE_TOKEN)
    if not token:
    raise EnvironmentError(f"{ENV_HUGGINGFACE_TOKEN} environment variable is not set")
    return token

    @staticmethod
    def _get_device() -> torch.device:
    """Retrieves the device (CPU or CUDA) to use for diarization."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOG.info(f"Using device: {device.type}")
    return device

    @staticmethod
    def _load_pipeline(device: torch.device) -> Pipeline:
    """Loads the diarization pipeline."""
    pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=DiarizationPipeline._get_huggingface_token()
    ).to(device)
    return pipeline

    def _init_once(self):
    """Initializes the DiarizationPipeline singleton instance."""
    self.device : torch.device = DiarizationPipeline._get_device()
    self.pipeline : Pipeline = DiarizationPipeline._load_pipeline(self.device)

    class DiarizationService:
    """
    Diarization is a class that performs speaker diarization on audio files.
    Attributes:
    - diarization (DiarizationPipeline): The diarization pipeline instance.
    Methods:
    - _load_audio(audio_path): Loads the audio waveform and sample rate from an audio file.
    - diarize(audio_path): Performs diarization on the specified audio file.
    """
    def __init__(self):
    self.diarization: DiarizationPipeline = DiarizationPipeline()

    @staticmethod
    def _load_audio(audio_path: str) -> tuple:
    """Loads the audio waveform and sample rate from an audio file."""
    waveform, sample_rate = torchaudio.load(audio_path)
    return waveform, sample_rate

    def diarize(self, audio_path: str) -> Annotation:
    """
    Performs diarization on the specified audio file.
    Returns:
    - diarization: The diarization results.
    """
    waveform, sample_rate = DiarizationService._load_audio(audio_path)
    with ProgressHook() as hook:
    diarization : Annotation = self.diarization.pipeline({
    "waveform": waveform,
    "sample_rate": sample_rate
    },
    hook=hook)

    # Set the diarization file id
    from urllib.parse import quote
    audio_name_without_extension = os.path.splitext(os.path.basename(audio_path))[0]
    diarization.uri = quote(audio_name_without_extension)

    return diarization

    def main():
    """
    The main function that performs diarization on an audio file.
    """
    logging.basicConfig(level=logging.DEBUG, format='%(asctime)s [%(levelname)s][%(filename)s:%(lineno)d][%(funcName)s] %(message)s', datefmt='%Y-%m-%dT%H:%M:%SZ')
    audio_file = "test.wav"
    if (ENV_HUGGINGFACE_TOKEN not in os.environ):
    os.environ[ENV_HUGGINGFACE_TOKEN] = '<huggingface_token>'

    diarization = DiarizationService().diarize(audio_file)
    turn_str = ""
    for turn, _, speaker in diarization.itertracks(yield_label=True):
    turn_str += f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}\n"
    LOG.info(turn_str)

    rttm_name = diarization.uri + ".rttm"
    LOG.info(f"saving to {rttm_name}")
    with open(rttm_name, "w") as rttm:
    diarization.write_rttm(rttm)

    if __name__ == "__main__":
    main()