"""RunPod serverless worker exposing SeamlessM4T v2 text-to-text translation.

At import time this module:

1. downloads the ``facebook/seamless-m4t-v2-large`` snapshot into
   ``CHECKPOINTS_PATH`` when that path does not exist yet,
2. registers in-process asset metadata pointing at the local checkpoint files,
3. builds a module-level ``Translator``, and
4. starts the RunPod serverless loop with ``runpod_handler``.
"""

from __future__ import annotations

import os
import pathlib

# NOTE: gradio, numpy and torchaudio are not referenced in this file as shown;
# they are kept because other code may rely on them being imported here.
import gradio as gr
import numpy as np
import runpod
import torch
import torchaudio
from fairseq2.assets import InProcAssetMetadataProvider, asset_store
from huggingface_hub import snapshot_download
from seamless_communication.inference import Translator

from lang_list import (
    ASR_TARGET_LANGUAGE_NAMES,
    LANGUAGE_NAME_TO_CODE,
    S2ST_TARGET_LANGUAGE_NAMES,
    S2TT_TARGET_LANGUAGE_NAMES,
    T2ST_TARGET_LANGUAGE_NAMES,
    T2TT_TARGET_LANGUAGE_NAMES,
    TEXT_SOURCE_LANGUAGE_NAMES,
)

# Directory holding the model files; overridable through the environment.
CHECKPOINTS_PATH = pathlib.Path(os.getenv("CHECKPOINTS_PATH", "/home/user/app/models"))
if not CHECKPOINTS_PATH.exists():
    # Only an absent directory triggers a download; an existing directory is
    # used as-is, whatever it contains.
    snapshot_download(
        repo_id="facebook/seamless-m4t-v2-large",
        repo_type="model",
        local_dir=CHECKPOINTS_PATH,
    )

# Replace the registered environment resolvers with one that always answers
# "demo" -- presumably so the "@demo" cards below are the ones selected
# (confirm against the fairseq2 asset store documentation).
asset_store.env_resolvers.clear()
asset_store.env_resolvers.append(lambda: "demo")

# Asset cards that point the model and vocoder at the local checkpoint files.
demo_metadata = [
    {
        "name": "seamlessM4T_v2_large@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/seamlessM4T_v2_large.pt",
        "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
    },
    {
        "name": "vocoder_v2@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/vocoder_v2.pt",
    },
]
asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))

# Half precision on the first CUDA device when one is available, otherwise
# full precision on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    dtype = torch.float16
else:
    device = torch.device("cpu")
    dtype = torch.float32

# Built once at import time and shared by every request.
translator = Translator(
    model_name_or_card="seamlessM4T_v2_large",
    vocoder_name_or_card="vocoder_v2",
    device=device,
    dtype=dtype,
    apply_mintox=True,
)

# Fields every job's "input" object must provide.
_REQUIRED_FIELDS = ("input_text", "source_language", "target_language")


def run_t2tt(input_text: str, source_language: str, target_language: str) -> str:
    """Translate ``input_text`` between two languages with the T2TT task.

    Args:
        input_text: Text to translate.
        source_language: Language name used as a key of
            ``LANGUAGE_NAME_TO_CODE`` for the input text.
        target_language: Language name used as a key of
            ``LANGUAGE_NAME_TO_CODE`` for the output text.

    Returns:
        The first text output returned by ``translator.predict``, as ``str``.

    Raises:
        KeyError: If either language name is not in ``LANGUAGE_NAME_TO_CODE``.
    """
    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    out_texts, _ = translator.predict(
        input=input_text,
        task_str="T2TT",
        src_lang=source_language_code,
        tgt_lang=target_language_code,
    )
    return str(out_texts[0])


def runpod_handler(job: dict) -> str | dict[str, str]:
    """Handle one RunPod job by translating the text it carries.

    The job is expected to hold an ``"input"`` object with the string fields
    ``"input_text"``, ``"source_language"`` and ``"target_language"``; both
    languages must be keys of ``LANGUAGE_NAME_TO_CODE``.

    Args:
        job: The job payload passed in by the RunPod serverless runtime.

    Returns:
        The translated text on success. For a malformed request, a dict with
        a single ``"error"`` key describing the problem, which is RunPod's
        documented way for a handler to report a failed job.
    """
    job_input = job.get("input") if isinstance(job, dict) else None
    if not isinstance(job_input, dict):
        return {"error": "Job payload must contain an 'input' object."}

    missing = [field for field in _REQUIRED_FIELDS if field not in job_input]
    if missing:
        return {"error": f"Missing required input field(s): {', '.join(missing)}"}

    input_text = job_input["input_text"]
    if not isinstance(input_text, str):
        return {"error": "'input_text' must be a string."}

    source_language = job_input["source_language"]
    target_language = job_input["target_language"]
    # Reject unknown names here so the caller gets a readable message rather
    # than a bare KeyError from the lookup inside run_t2tt.
    for field, language in (
        ("source_language", source_language),
        ("target_language", target_language),
    ):
        if not isinstance(language, str) or language not in LANGUAGE_NAME_TO_CODE:
            return {"error": f"Unsupported {field}: {language!r}"}

    return run_t2tt(input_text, source_language, target_language)


# Started unconditionally when the module is loaded, as in the original script.
runpod.serverless.start({"handler": runpod_handler})