Created
December 27, 2023 03:48
-
-
Save gaspardpetit/e2af3728d922239e0a6ec80e53fb5f58 to your computer and use it in GitHub Desktop.
Revisions
-
gaspardpetit created this gist
Dec 27, 2023. There are no files selected for viewing
"""Speaker diarization of an audio file using the pyannote.audio pipeline.

Running this module as a script diarizes ``test.wav`` and writes the result
next to the current working directory as an RTTM file.
"""

import logging
import os
import time
from typing import Tuple
from urllib.parse import quote

import torch
import torchaudio
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization  # noqa: F401  (kept: existing module-level name)
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pyannote.core import Annotation

# Environment variable to use for setting the HuggingFace Token
ENV_HUGGINGFACE_TOKEN: str = "HUGGINGFACE_TOKEN"

LOG: logging.Logger = logging.getLogger(__name__)


class DiarizationPipeline(object):
    """
    DiarizationPipeline is a singleton class that represents the diarization pipeline.

    The first successful construction loads the pretrained pipeline; every
    later ``DiarizationPipeline()`` call returns that same instance.

    Attributes:
    - device (torch.device): The device (CPU or CUDA) to use for diarization.
    - pipeline (Pipeline): The diarization pipeline.

    Methods:
    - _get_huggingface_token(): Retrieves the Hugging Face token from the environment variables.
    - _get_device(): Retrieves the device to use for diarization.
    - _load_pipeline(device): Loads the diarization pipeline.
    - _init_once(): Initializes the DiarizationPipeline singleton instance.
    """

    __instance = None

    def __new__(cls):
        """Return the shared instance, creating and initializing it on first use.

        Raises:
        - EnvironmentError: If the Hugging Face token variable is not set.
        - RuntimeError: If the pretrained pipeline could not be loaded.
        """
        if cls.__instance is None:
            instance = super(DiarizationPipeline, cls).__new__(cls)
            instance._init_once()
            # Publish the instance only after initialization succeeded, so a
            # failed first attempt is retried on the next call instead of
            # handing out an object that has no ``pipeline`` attribute.
            cls.__instance = instance
        return cls.__instance

    @staticmethod
    def _get_huggingface_token() -> str:
        """Retrieves the Hugging Face token from the environment variables.

        Raises:
        - EnvironmentError: If the variable is unset or empty.
        """
        token = os.getenv(ENV_HUGGINGFACE_TOKEN)
        if not token:
            raise EnvironmentError(f"{ENV_HUGGINGFACE_TOKEN} environment variable is not set")
        return token

    @staticmethod
    def _get_device() -> torch.device:
        """Retrieves the device (CPU or CUDA) to use for diarization."""
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        LOG.info("Using device: %s", device.type)
        return device

    @staticmethod
    def _load_pipeline(device: torch.device) -> Pipeline:
        """Loads the diarization pipeline and moves it to ``device``.

        Raises:
        - EnvironmentError: If the Hugging Face token variable is not set.
        - RuntimeError: If ``Pipeline.from_pretrained`` returned no pipeline.
        """
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=DiarizationPipeline._get_huggingface_token(),
        )
        # from_pretrained signals a failed download (e.g. model terms not
        # accepted, or an invalid token) by returning None rather than
        # raising; fail here with a clear message instead of an
        # AttributeError on ``None.to``.
        if pipeline is None:
            raise RuntimeError(
                "Could not load 'pyannote/speaker-diarization-3.1'; check that "
                f"{ENV_HUGGINGFACE_TOKEN} is valid and that the model's user "
                "conditions have been accepted"
            )
        return pipeline.to(device)

    def _init_once(self):
        """Initializes the DiarizationPipeline singleton instance."""
        self.device: torch.device = DiarizationPipeline._get_device()
        self.pipeline: Pipeline = DiarizationPipeline._load_pipeline(self.device)


class DiarizationService:
    """
    Diarization is a class that performs speaker diarization on audio files.

    Attributes:
    - diarization (DiarizationPipeline): The diarization pipeline instance.

    Methods:
    - _load_audio(audio_path): Loads the audio waveform and sample rate from an audio file.
    - diarize(audio_path): Performs diarization on the specified audio file.
    """

    def __init__(self):
        self.diarization: DiarizationPipeline = DiarizationPipeline()

    @staticmethod
    def _load_audio(audio_path: str) -> Tuple[torch.Tensor, int]:
        """Loads the audio waveform and sample rate from an audio file."""
        waveform, sample_rate = torchaudio.load(audio_path)
        return waveform, sample_rate

    def diarize(self, audio_path: str) -> Annotation:
        """
        Performs diarization on the specified audio file.

        Parameters:
        - audio_path: Path of the audio file to diarize.

        Returns:
        - diarization: The diarization results. Its ``uri`` is set to the
          URL-quoted file name of ``audio_path`` without its extension.
        """
        waveform, sample_rate = DiarizationService._load_audio(audio_path)

        with ProgressHook() as hook:
            diarization: Annotation = self.diarization.pipeline(
                {"waveform": waveform, "sample_rate": sample_rate},
                hook=hook,
            )

        # Set the diarization file id
        audio_name_without_extension = os.path.splitext(os.path.basename(audio_path))[0]
        diarization.uri = quote(audio_name_without_extension)

        return diarization


def main(audio_file: str = "test.wav"):
    """
    The main function that performs diarization on an audio file.

    Logs every speaker turn and writes the result to ``<uri>.rttm`` in the
    current working directory.

    Parameters:
    - audio_file: Path of the audio file to diarize.
    """
    # The date format ends in a literal "Z" (UTC), so timestamps must be
    # rendered in UTC; logging formats local time unless told otherwise.
    formatter = logging.Formatter(
        fmt='%(asctime)s [%(levelname)s][%(filename)s:%(lineno)d][%(funcName)s] %(message)s',
        datefmt='%Y-%m-%dT%H:%M:%SZ',
    )
    formatter.converter = time.gmtime
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logging.basicConfig(level=logging.DEBUG, handlers=[handler])

    # Placeholder used only when the variable is absent: replace it with a
    # real token, or export HUGGINGFACE_TOKEN before running.
    os.environ.setdefault(ENV_HUGGINGFACE_TOKEN, '<huggingface_token>')

    diarization = DiarizationService().diarize(audio_file)

    turn_str = "".join(
        f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}\n"
        for turn, _, speaker in diarization.itertracks(yield_label=True)
    )
    LOG.info(turn_str)

    rttm_name = diarization.uri + ".rttm"
    LOG.info("saving to %s", rttm_name)
    with open(rttm_name, "w", encoding="utf-8") as rttm:
        diarization.write_rttm(rttm)


if __name__ == "__main__":
    main()