Spaces:
Sleeping
Sleeping
import os | |
import yaml | |
from argparse import Action | |
from ast import literal_eval | |
from torch.cuda import is_available | |
from torch import get_num_threads, set_num_threads | |
# Root directory for cached model files. Override with the AUTOT_CACHE
# environment variable.
CACHE_DIR = os.getenv(
    "AUTOT_CACHE",
    os.path.expanduser("~/.cache/torch/models"),
)

# Export PYANNOTE_CACHE into the process environment, keeping any value that
# is already set and otherwise defaulting to a "pyannote" folder in CACHE_DIR.
os.environ["PYANNOTE_CACHE"] = os.getenv(
    "PYANNOTE_CACHE",
    os.path.join(CACHE_DIR, "pyannote"),
)

WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")

# Use the local config.yaml when it exists in PYANNOTE_DEFAULT_PATH; otherwise
# fall back to a tuple of model identifiers.
# NOTE(review): the identifiers look like Hugging Face Hub repo ids tried in
# order -- confirm against the code that consumes this constant.
PYANNOTE_DEFAULT_CONFIG = (
    os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")
    if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml"))
    else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1')
)

SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu")

# os.getenv returns a *string* when the variable is set, while the fallback is
# an int. Cast so the constant is always an int; a non-numeric value now fails
# here at import with a ValueError rather than later at the point of use.
SCRAIBE_NUM_THREADS = int(os.getenv("SCRAIBE_NUM_THREADS", min(8, get_num_threads())))
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
    """Point a diarization pipeline YAML file at a local segmentation model.

    The YAML at ``file_path`` is loaded, its ``pipeline -> params ->
    segmentation`` entry is set to the segmentation model path, and the
    result is written back to the same file. The file is only rewritten
    when the segmentation model path exists.

    Args:
        file_path (str): Path to the YAML file to update in place.
        path_to_segmentation (str, optional): Path to the segmentation model.
            When omitted (or empty), ``pytorch_model.bin`` inside
            ``PYANNOTE_DEFAULT_PATH`` is used.

    Raises:
        FileNotFoundError: If the segmentation model file is not found.
    """
    with open(file_path, "r") as handle:
        config = yaml.safe_load(handle)

    if path_to_segmentation:
        model_path = path_to_segmentation
    else:
        model_path = os.path.join(PYANNOTE_DEFAULT_PATH, "pytorch_model.bin")

    config["pipeline"]["params"]["segmentation"] = model_path

    # Checked before the write so a missing model leaves the file untouched.
    model_is_present = os.path.exists(model_path)
    if not model_is_present:
        raise FileNotFoundError(
            f"Segmentation model not found at {model_path}")

    with open(file_path, "w") as handle:
        yaml.dump(config, handle)
def set_threads(parse_threads=None,
                yaml_threads=None):
    """Set the torch thread count from a command-line or YAML setting.

    ``parse_threads`` takes precedence: ``yaml_threads`` is only consulted
    when ``parse_threads`` is ``None``. When both are ``None`` nothing is
    changed. On success the value is passed to ``torch.set_num_threads`` and
    stored in the module-level ``SCRAIBE_NUM_THREADS``.

    Args:
        parse_threads (int, optional): Thread count from the ``--num-threads``
            command-line option.
        yaml_threads (int, optional): Thread count from the ``num_threads``
            configuration entry.

    Raises:
        ValueError: If the selected value is not an ``int`` (booleans are
            rejected as well) or is smaller than 1.
    """
    global SCRAIBE_NUM_THREADS

    # Pick the winning source once so both share a single validation path.
    if parse_threads is not None:
        requested, label = parse_threads, "--num-threads"
    elif yaml_threads is not None:
        requested, label = yaml_threads, "num_threads"
    else:
        return

    # bool is a subclass of int, so a plain isinstance check would let
    # True/False through (e.g. `num_threads: true` in a YAML file).
    if isinstance(requested, bool) or not isinstance(requested, int):
        raise ValueError(f"Type of {label} must be int, but the type is {type(requested)}")
    if requested < 1:
        raise ValueError(f"Number of threads must be a positive integer, {requested} was given")

    set_num_threads(requested)
    SCRAIBE_NUM_THREADS = requested
class ParseKwargs(Action):
    """
    Custom argparse action to parse keyword arguments.

    Each ``KEY=VALUE`` token is collected into a dict that is stored on the
    namespace under ``dest``. The value part is passed through
    ``ast.literal_eval`` so Python literals (numbers, lists, ``True``, ...)
    become real objects; anything that is not a valid literal is kept as
    the raw string.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """Parse ``values`` into a dict and store it on ``namespace``.

        Raises:
            argparse.ArgumentError: If a token contains no ``=``. argparse
                reports this as a normal usage error.
        """
        # Local import: only `Action` is imported from argparse at file level.
        from argparse import ArgumentError

        # argparse passes a bare string (or None) when the option is declared
        # without a list-producing nargs; normalise to a list of tokens.
        if values is None:
            values = []
        elif isinstance(values, str):
            values = [values]

        parsed = {}
        for token in values:
            # Split on the first '=' only, so values may contain '='.
            key, separator, raw = token.partition('=')
            if not separator:
                raise ArgumentError(self, f"expected KEY=VALUE, got {token!r}")
            try:
                parsed[key] = literal_eval(raw)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                # Not a Python literal: keep the text as given.
                parsed[key] = raw

        setattr(namespace, self.dest, parsed)