""" | |
Diarisation Class | |
------------------ | |
This class serves as the heart of the speaker diarization system, responsible for identifying | |
and segmenting individual speakers from a given audio file. It leverages a pretrained model | |
from pyannote.audio, providing an accessible interface for audio processing tasks such as | |
speaker separation, and timestamping. | |
By encapsulating the complexities of the underlying model, it allows for straightforward | |
integration into various applications, ranging from transcription services to voice assistants. | |
Available Classes: | |
- Diariser: Main class for performing speaker diarization. | |
Includes methods for loading models, processing audio files, | |
and formatting the diarization output. | |
Constants: | |
- TOKEN_PATH (str): Path to the Pyannote token. | |
- PYANNOTE_DEFAULT_PATH (str): Default path to Pyannote models. | |
- PYANNOTE_DEFAULT_CONFIG (str): Default configuration for Pyannote models. | |
Usage: | |
from .diarisation import Diariser | |
model = Diariser.load_model(model="path/to/model/config.yaml") | |
diarisation_output = model.diarization("path/to/audiofile.wav") | |
""" | |
import warnings
import os
import yaml
from pathlib import Path
from typing import TypeVar, Union

from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor
from torch import device as torch_device

from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError

from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, SCRAIBE_TORCH_DEVICE

Annotation = TypeVar('Annotation')

TOKEN_PATH = os.path.join(os.path.dirname(
    os.path.realpath(__file__)), '.pyannotetoken')


class Diariser:
    """
    Handles the diarization process of an audio file using a pretrained model
    from pyannote.audio. Diarization is the task of determining "who spoke when."

    Args:
        model: The pretrained model to use for diarization.
    """

    def __init__(self, model) -> None:
        self.model = model

    def diarization(self, audiofile: Union[str, Tensor, dict],
                    *args, **kwargs) -> Annotation:
        """
        Perform speaker diarization on the provided audio file,
        effectively separating different speakers
        and providing a timestamp for each segment.

        Args:
            audiofile: The path to the audio file or a torch.Tensor
                containing the audio data.
            args: Additional arguments for the diarization model.
            kwargs: Additional keyword arguments for the diarization model.

        Returns:
            dict: A dictionary containing speaker names,
                segments, and other information related
                to the diarization process.
        """
        kwargs = self._get_diarisation_kwargs(**kwargs)
        diarization = self.model(audiofile, *args, **kwargs)
        out = self.format_diarization_output(diarization)
        return out
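
    # Illustrative call (not part of the original code; the file name and the
    # ``num_speakers`` hint below are hypothetical). Unsupported keyword
    # arguments are silently dropped by ``_get_diarisation_kwargs``:
    #
    #   diariser = Diariser.load_model()
    #   out = diariser.diarization("interview.wav", num_speakers=2)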

    # @staticmethod
    # def format_diarization_output(dia: Annotation) -> dict:
    #     """
    #     Formats the raw diarization output into a more usable structure for this project.
    #
    #     Args:
    #         dia: Raw diarization output.
    #
    #     Returns:
    #         dict: A structured representation of the diarization, with speaker names
    #             as keys and a list of tuples representing segments as values.
    #     """
    #     dia_list = list(dia.itertracks(yield_label=True))
    #     diarization_output = {"speakers": [], "segments": []}
    #     normalized_output = []
    #     index_start_speaker = 0
    #     index_end_speaker = 0
    #     current_speaker = str()
    #
    #     ###
    #     # Sometimes two consecutive speakers are the same.
    #     # This loop removes these duplicates.
    #     ###
    #
    #     if len(dia_list) == 1:
    #         normalized_output.append([0, 0, dia_list[0][2]])
    #     else:
    #         for i, (_, _, speaker) in enumerate(dia_list):
    #             if i == 0:
    #                 current_speaker = speaker
    #
    #             if speaker != current_speaker:
    #                 index_end_speaker = i - 1
    #                 normalized_output.append([index_start_speaker,
    #                                           index_end_speaker,
    #                                           current_speaker])
    #                 index_start_speaker = i
    #                 current_speaker = speaker
    #
    #             if i == len(dia_list) - 1:
    #                 index_end_speaker = i
    #                 normalized_output.append([index_start_speaker,
    #                                           index_end_speaker,
    #                                           current_speaker])
    #
    #     for outp in normalized_output:
    #         start = dia_list[outp[0]][0].start
    #         end = dia_list[outp[1]][0].end
    #         diarization_output["segments"].append([start, end])
    #         diarization_output["speakers"].append(outp[2])
    #
    #     return diarization_output

    @staticmethod
    def format_diarization_output(dia: Annotation) -> dict:
        """
        Formats the raw diarization output into a more usable structure for this project,
        without combining consecutive segments of the same speaker.

        Args:
            dia: Raw diarization output.

        Returns:
            dict: A dictionary with two parallel lists: "speakers" containing
                the speaker labels and "segments" containing the corresponding
                [start, end] timestamps.
        """
        dia_list = list(dia.itertracks(yield_label=True))
        diarization_output = {"speakers": [], "segments": []}

        for segment, _, speaker in dia_list:
            start = segment.start
            end = segment.end
            diarization_output["segments"].append([start, end])
            diarization_output["speakers"].append(speaker)

        return diarization_output
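
    # Sketch of the expected shape (hypothetical values): the two lists are
    # parallel, so diarization_output["speakers"][i] labels
    # diarization_output["segments"][i]:
    #
    #   {"speakers": ["SPEAKER_00", "SPEAKER_01"],
    #    "segments": [[0.5, 3.2], [3.4, 7.9]]}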

    @staticmethod
    def _get_token():
        """
        Retrieves the Huggingface token from a local file. This token is required
        for accessing certain online resources.

        Raises:
            ValueError: If the token is not found.

        Returns:
            str: The Huggingface token.
        """
        if os.path.exists(TOKEN_PATH):
            with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
                token = file.read()
        else:
            raise ValueError('No token found. '
                             'Please create a token at https://huggingface.co/settings/token '
                             f'and save it in a file called {TOKEN_PATH}')
        return token

    @staticmethod
    def _save_token(token):
        """
        Saves the provided Huggingface token to a local file. This facilitates future
        access to online resources without needing to repeatedly authenticate.

        Args:
            token: The Huggingface token to save.
        """
        with open(TOKEN_PATH, 'w', encoding="utf-8") as file:
            file.write(token)
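
    # Example round trip (the token value is a placeholder):
    #   Diariser._save_token("hf_xxx")   # writes TOKEN_PATH next to this module
    #   Diariser._get_token()            # -> "hf_xxx"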

    @classmethod
    def load_model(cls,
                   model: str = PYANNOTE_DEFAULT_CONFIG,
                   use_auth_token: str = None,
                   cache_token: bool = False,
                   cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
                   hparams_file: Union[str, Path] = None,
                   device: str = SCRAIBE_TORCH_DEVICE,
                   *args, **kwargs
                   ) -> Pipeline:
        """
        Loads a pretrained model from pyannote.audio,
        either from a local cache or from the Huggingface Hub.

        Args:
            model: Path or identifier for the pyannote model.
                default: '/home/[user]/.cache/torch/models/pyannote/config.yaml'
                or one of 'jaikinator/scraibe', 'pyannote/speaker-diarization-3.1'
            use_auth_token: Optional HUGGINGFACE_TOKEN for authenticated access.
            cache_token: Whether to cache the token locally for future use.
            cache_dir: Directory for caching models.
            hparams_file: Path to a YAML file containing hyperparameters.
            device: Device to load the model on.
            args: Additional arguments only to avoid errors.
            kwargs: Additional keyword arguments only to avoid errors.

        Returns:
            Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
        """
        if isinstance(model, str) and os.path.exists(model):
            # check if the model can be found locally nearby the config file
            with open(model, 'r') as file:
                config = yaml.safe_load(file)

            path_to_model = config['pipeline']['params']['segmentation']

            if not os.path.exists(path_to_model):
                warnings.warn(f"Model not found at {path_to_model}. "
                              "Trying to find it nearby the config file.")

                pwd = model.split("/")[:-1]
                pwd = "/".join(pwd)
                path_to_model = os.path.join(pwd, "pytorch_model.bin")

                if not os.path.exists(path_to_model):
                    warnings.warn(f"Model not found at {path_to_model}. "
                                  "Trying to find a nearby .bin file instead.")
                    warnings.warn(
                        'Searching for nearby files in a folder path is '
                        'deprecated and will be removed in future versions.',
                        category=DeprecationWarning)
                    # list elements with the ending .bin
                    bin_files = [f for f in os.listdir(
                        pwd) if f.endswith(".bin")]
                    if len(bin_files) == 1:
                        path_to_model = os.path.join(pwd, bin_files[0])
                    else:
                        warnings.warn("Found more than one .bin file "
                                      "or none. Please specify the path to the model "
                                      "or set up a huggingface token.")
                        raise FileNotFoundError

                warnings.warn(
                    f"Found model at {path_to_model}, overwriting config file.")

                config['pipeline']['params']['segmentation'] = path_to_model
                with open(model, 'w') as file:
                    yaml.dump(config, file)
        elif isinstance(model, tuple):
            try:
                _model = model[0]
                HfApi().model_info(_model)
                model = _model
                use_auth_token = None
            except RepositoryNotFoundError:
                print(f'{model[0]} not found on Huggingface, '
                      f'trying {model[1]}')
                _model = model[1]
                HfApi().model_info(_model)
                model = _model

                if cache_token and use_auth_token is not None:
                    cls._save_token(use_auth_token)

                if use_auth_token is None:
                    use_auth_token = cls._get_token()
        else:
            raise FileNotFoundError(
                f'No local model or directory found at {model}.')

        _model = Pipeline.from_pretrained(model,
                                          use_auth_token=use_auth_token,
                                          cache_dir=cache_dir,
                                          hparams_file=hparams_file)

        if _model is None:
            raise ValueError('Unable to load model either from local cache '
                             'or from huggingface.co models. Please check your token '
                             'or your local model path.')

        # torch_device is renamed from torch.device to avoid a name conflict
        _model = _model.to(torch_device(device))

        return cls(_model)
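
    # Usage sketch for load_model (the local path and token below are
    # placeholders, not values shipped with this module):
    #
    #   # from a local pyannote config, with the weights resolved nearby
    #   diariser = Diariser.load_model(model="models/pyannote/config.yaml")
    #
    #   # or via the Huggingface Hub defaults with an access token
    #   diariser = Diariser.load_model(use_auth_token="hf_xxx")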

    @staticmethod
    def _get_diarisation_kwargs(**kwargs) -> dict:
        """
        Validates and extracts the keyword arguments for the pyannote diarization model.

        Ensures that the provided keyword arguments match the expected parameters,
        filtering out any invalid or unnecessary arguments.

        Returns:
            dict: A dictionary containing the validated keyword arguments.
        """
        _possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames

        diarisation_kwargs = {k: v for k,
                              v in kwargs.items() if k in _possible_kwargs}

        return diarisation_kwargs
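
    # Example (hypothetical call): _get_diarisation_kwargs(num_speakers=2,
    # unsupported_arg=True) keeps only the keys accepted by
    # SpeakerDiarization.apply, i.e. it would return {"num_speakers": 2}.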

    def __repr__(self):
        return f"Diarisation(model={self.model})"