# Commit 0913403 by jxtan: "Bug fix server.py invalid job inputs"
from __future__ import annotations
import os
import pathlib
import gradio as gr
import numpy as np
import torch
import torchaudio
from fairseq2.assets import InProcAssetMetadataProvider, asset_store
from huggingface_hub import snapshot_download
from seamless_communication.inference import Translator
from lang_list import (
ASR_TARGET_LANGUAGE_NAMES,
LANGUAGE_NAME_TO_CODE,
S2ST_TARGET_LANGUAGE_NAMES,
S2TT_TARGET_LANGUAGE_NAMES,
T2ST_TARGET_LANGUAGE_NAMES,
T2TT_TARGET_LANGUAGE_NAMES,
TEXT_SOURCE_LANGUAGE_NAMES,
)
# Directory holding the SeamlessM4T v2 checkpoints; can be overridden with the
# CHECKPOINTS_PATH environment variable.
CHECKPOINTS_PATH = pathlib.Path(os.getenv("CHECKPOINTS_PATH", "/home/user/app/models"))
# Download when any checkpoint file the asset cards point at is missing, not
# only when the directory itself is absent. An existing but empty or partially
# populated directory (e.g. a pre-created volume mount, or an interrupted
# download) would otherwise skip the download and fail later at model load.
if not all(
    (CHECKPOINTS_PATH / filename).is_file()
    for filename in (
        "seamlessM4T_v2_large.pt",
        "spm_char_lang38_tc.model",
        "vocoder_v2.pt",
    )
):
    snapshot_download(repo_id="facebook/seamless-m4t-v2-large", repo_type="model", local_dir=CHECKPOINTS_PATH)
# In-process asset cards that point fairseq2 at the local checkpoint files
# under CHECKPOINTS_PATH (as file:// URIs).
demo_metadata = [
    dict(
        name="seamlessM4T_v2_large@demo",
        checkpoint=f"file://{CHECKPOINTS_PATH}/seamlessM4T_v2_large.pt",
        char_tokenizer=f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
    ),
    dict(
        name="vocoder_v2@demo",
        checkpoint=f"file://{CHECKPOINTS_PATH}/vocoder_v2.pt",
    ),
]
# Replace any existing environment resolvers with one that always reports the
# "demo" environment; presumably this makes fairseq2 prefer the "<name>@demo"
# cards above when resolving "seamlessM4T_v2_large" / "vocoder_v2" — confirm
# against fairseq2's asset-store semantics.
asset_store.env_resolvers.clear()
asset_store.env_resolvers.append(lambda: "demo")
asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))
# Use the first CUDA GPU with float16 when one is available; otherwise run on
# the CPU in float32.
device, dtype = (
    (torch.device("cuda:0"), torch.float16)
    if torch.cuda.is_available()
    else (torch.device("cpu"), torch.float32)
)
# Single shared SeamlessM4T v2 translator, built once at import time on the
# device/dtype selected above. The model and vocoder names are resolved through
# the fairseq2 asset store configured earlier in this file.
translator = Translator(
    device=device,
    dtype=dtype,
    model_name_or_card="seamlessM4T_v2_large",
    vocoder_name_or_card="vocoder_v2",
    # NOTE(review): presumably enables seamless_communication's MinTox
    # added-toxicity mitigation — confirm against the library docs.
    apply_mintox=True,
)
def run_t2tt(input_text: str, source_language: str, target_language: str) -> str:
    """Translate ``input_text`` with the shared translator's T2TT task.

    Args:
        input_text: Text to translate.
        source_language: Display name of the source language; mapped to a
            language code via ``LANGUAGE_NAME_TO_CODE``.
        target_language: Display name of the target language; mapped the
            same way.

    Returns:
        The first translated text, converted with ``str()``.

    Raises:
        KeyError: If either name is not a key of ``LANGUAGE_NAME_TO_CODE``
            (assuming it is a plain dict).
    """
    src_code, tgt_code = (
        LANGUAGE_NAME_TO_CODE[source_language],
        LANGUAGE_NAME_TO_CODE[target_language],
    )
    # predict() returns a pair; only the text part is used here (the second
    # element is presumably the speech output, irrelevant for T2TT).
    translations, _unused = translator.predict(
        input=input_text,
        task_str="T2TT",
        src_lang=src_code,
        tgt_lang=tgt_code,
    )
    first_translation = translations[0]
    return str(first_translation)
import runpod
def runpod_handler(job):
    """Handle one RunPod serverless job by running text-to-text translation.

    Expected payload::

        {"input": {"input_text": "...",
                   "source_language": "<language name>",
                   "target_language": "<language name>"}}

    Language values are display names (the form ``run_t2tt`` maps through
    ``LANGUAGE_NAME_TO_CODE``), not language codes.

    Args:
        job: The job dict delivered by the RunPod worker.

    Returns:
        The translated text (``str``) on success. For a malformed payload,
        returns ``{"error": message}`` instead. RunPod reports a top-level
        ``"error"`` key as a failed job, so callers get a readable message
        rather than a bare ``KeyError`` traceback.
    """
    job_input = job.get("input") if isinstance(job, dict) else None
    if not isinstance(job_input, dict):
        return {"error": "Job must contain an 'input' object."}

    required = ("input_text", "source_language", "target_language")
    missing = [field for field in required if field not in job_input]
    if missing:
        return {"error": f"Missing required input field(s): {', '.join(missing)}."}

    input_text = job_input["input_text"]
    source_language = job_input["source_language"]
    target_language = job_input["target_language"]

    non_strings = [field for field in required if not isinstance(job_input[field], str)]
    if non_strings:
        return {"error": f"Input field(s) must be strings: {', '.join(non_strings)}."}

    # Reject blank text up front instead of sending it to the model.
    if not input_text.strip():
        return {"error": "'input_text' must be a non-empty string."}

    # Validate against the task-specific language lists so that unsupported
    # languages fail with a clear message instead of a KeyError (or a model
    # error) inside run_t2tt.
    if source_language not in TEXT_SOURCE_LANGUAGE_NAMES:
        return {"error": f"Unsupported source_language: {source_language!r}."}
    if target_language not in T2TT_TARGET_LANGUAGE_NAMES:
        return {"error": f"Unsupported target_language: {target_language!r}."}

    return run_t2tt(input_text, source_language, target_language)
# Register runpod_handler with the RunPod serverless runtime and start the
# worker. This executes at import time (there is no __main__ guard).
runpod.serverless.start(dict(handler=runpod_handler))