# Spaces: Runtime error  (hosting-page status banner captured with this file; not part of the program)
from __future__ import annotations

import os
import pathlib

import gradio as gr
import numpy as np
import runpod
import torch
import torchaudio
from fairseq2.assets import InProcAssetMetadataProvider, asset_store
from huggingface_hub import snapshot_download
from seamless_communication.inference import Translator

from lang_list import (
    ASR_TARGET_LANGUAGE_NAMES,
    LANGUAGE_NAME_TO_CODE,
    S2ST_TARGET_LANGUAGE_NAMES,
    S2TT_TARGET_LANGUAGE_NAMES,
    T2ST_TARGET_LANGUAGE_NAMES,
    T2TT_TARGET_LANGUAGE_NAMES,
    TEXT_SOURCE_LANGUAGE_NAMES,
)
# Directory holding the model checkpoints. Overridable through the
# CHECKPOINTS_PATH environment variable.
CHECKPOINTS_PATH = pathlib.Path(os.getenv("CHECKPOINTS_PATH", "/home/user/app/models"))

# Download the model snapshot only when the directory does not exist yet, so a
# pre-populated directory is reused as-is.
if not CHECKPOINTS_PATH.exists():
    snapshot_download(
        repo_id="facebook/seamless-m4t-v2-large",
        repo_type="model",
        local_dir=CHECKPOINTS_PATH,
    )
# Replace every registered environment resolver with one that always answers
# "demo". NOTE(review): presumably this makes fairseq2 pick the "@demo" asset
# entries registered below -- confirm against the fairseq2 asset-store docs.
asset_store.env_resolvers.clear()
asset_store.env_resolvers.append(lambda: "demo")

# Asset metadata pointing the model and vocoder at the local checkpoint files
# under CHECKPOINTS_PATH.
demo_metadata = [
    {
        "name": "seamlessM4T_v2_large@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/seamlessM4T_v2_large.pt",
        "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
    },
    {
        "name": "vocoder_v2@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/vocoder_v2.pt",
    },
]
asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))
# Run on the first CUDA device in half precision when CUDA is available;
# otherwise fall back to the CPU in single precision.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    dtype = torch.float16
else:
    device = torch.device("cpu")
    dtype = torch.float32
# Module-level translator, built once at import time and reused by every
# request, using the device/dtype chosen above.
translator = Translator(
    model_name_or_card="seamlessM4T_v2_large",
    vocoder_name_or_card="vocoder_v2",
    device=device,
    dtype=dtype,
    # NOTE(review): presumably enables the library's MinTox toxicity
    # mitigation -- confirm against the seamless_communication docs.
    apply_mintox=True,
)
def run_t2tt(input_text: str, source_language: str, target_language: str) -> str:
    """Translate text from one language to another (task ``T2TT``).

    Args:
        input_text: Text to translate.
        source_language: Language name of ``input_text``; must be a key of
            ``LANGUAGE_NAME_TO_CODE``.
        target_language: Language name to translate into; must be a key of
            ``LANGUAGE_NAME_TO_CODE``.

    Returns:
        The first text output of ``translator.predict``, converted to ``str``.

    Raises:
        KeyError: If either language name is not in ``LANGUAGE_NAME_TO_CODE``.
    """
    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    # predict() returns a pair; the second element is not used for T2TT here.
    out_texts, _ = translator.predict(
        input=input_text,
        task_str="T2TT",
        src_lang=source_language_code,
        tgt_lang=target_language_code,
    )
    return str(out_texts[0])
def runpod_handler(job):
    """Translate the text described by a RunPod serverless job.

    Args:
        job: Job payload handed over by the RunPod worker. ``job["input"]``
            must be a dict with the string fields ``input_text``,
            ``source_language`` and ``target_language``; both language names
            must be keys of ``LANGUAGE_NAME_TO_CODE``.

    Returns:
        The translated text from ``run_t2tt`` on success. For a malformed
        payload, a ``{"error": message}`` dict (the result shape RunPod
        documents for reporting a failed job) instead of letting a bare
        ``KeyError``/``TypeError`` escape.
    """
    job_input = job.get("input") if isinstance(job, dict) else None
    if not isinstance(job_input, dict):
        return {"error": "Job is missing an 'input' object."}

    required = ("input_text", "source_language", "target_language")

    missing = [field for field in required if field not in job_input]
    if missing:
        return {"error": f"Missing required input field(s): {', '.join(missing)}"}

    # Checked before the language lookup so unhashable values never reach it.
    not_text = [field for field in required if not isinstance(job_input[field], str)]
    if not_text:
        return {"error": f"Input field(s) must be strings: {', '.join(not_text)}"}

    for field in ("source_language", "target_language"):
        if job_input[field] not in LANGUAGE_NAME_TO_CODE:
            return {"error": f"Unsupported {field}: {job_input[field]!r}"}

    return run_t2tt(
        job_input["input_text"],
        job_input["source_language"],
        job_input["target_language"],
    )
# Start the RunPod serverless worker, registering runpod_handler as the job handler.
runpod.serverless.start({"handler": runpod_handler})