Spaces:
Runtime error
Runtime error
Update runpod handler for Seamless
Browse files
- server.py +46 -55
- test_input.json +7 -0
server.py
CHANGED
@@ -21,65 +21,56 @@ from lang_list import (
|
|
21 |
TEXT_SOURCE_LANGUAGE_NAMES,
|
22 |
)
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
|
68 |
import runpod
|
69 |
|
70 |
-
def
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
if not isinstance(the_number, int):
|
76 |
-
return {"error": "Silly human, you need to pass an integer."}
|
77 |
-
|
78 |
-
if the_number % 2 == 0:
|
79 |
-
return True
|
80 |
-
|
81 |
-
return False
|
82 |
-
|
83 |
-
# output_text = run_t2tt(input_text, source_language, target_language)
|
84 |
|
85 |
-
runpod.serverless.start({"handler":
|
|
|
21 |
TEXT_SOURCE_LANGUAGE_NAMES,
|
22 |
)
|
23 |
|
24 |
+
# Directory holding the model files; the CHECKPOINTS_PATH environment
# variable overrides the default location.
CHECKPOINTS_PATH = pathlib.Path(os.getenv("CHECKPOINTS_PATH", "/home/user/app/models"))

# Fetch the model repository only when the directory does not exist yet.
if not CHECKPOINTS_PATH.exists():
    snapshot_download(
        repo_id="facebook/seamless-m4t-v2-large",
        repo_type="model",
        local_dir=CHECKPOINTS_PATH,
    )

# Drop every registered environment resolver and install one that always
# answers "demo".
asset_store.env_resolvers.clear()
asset_store.env_resolvers.append(lambda: "demo")

# Metadata entries pointing at the local checkpoint files.
# NOTE(review): presumably the "@demo" suffix binds these entries to the
# "demo" environment resolved above -- confirm against the asset store docs.
demo_metadata = [
    {
        "name": "seamlessM4T_v2_large@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/seamlessM4T_v2_large.pt",
        "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
    },
    {
        "name": "vocoder_v2@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/vocoder_v2.pt",
    },
]
asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))

# Use the first CUDA device with float16 when CUDA is available; otherwise
# run on the CPU with float32.
device, dtype = (
    (torch.device("cuda:0"), torch.float16)
    if torch.cuda.is_available()
    else (torch.device("cpu"), torch.float32)
)

# Single module-level translator instance used by run_t2tt.
translator = Translator(
    model_name_or_card="seamlessM4T_v2_large",
    vocoder_name_or_card="vocoder_v2",
    device=device,
    dtype=dtype,
    apply_mintox=True,
)
|
56 |
|
57 |
+
def run_t2tt(input_text: str, source_language: str, target_language: str) -> str:
    """Translate ``input_text`` using the module-level ``translator`` (task ``"T2TT"``).

    Args:
        input_text: Text to translate.
        source_language: Language name looked up in ``LANGUAGE_NAME_TO_CODE``.
        target_language: Language name looked up in ``LANGUAGE_NAME_TO_CODE``.

    Returns:
        The first predicted text, converted to ``str``.

    Raises:
        KeyError: Presumably, when a language name is missing from
            ``LANGUAGE_NAME_TO_CODE`` (assuming it is a plain dict).
    """
    src_code = LANGUAGE_NAME_TO_CODE[source_language]
    tgt_code = LANGUAGE_NAME_TO_CODE[target_language]
    # predict() yields a pair; only the text half is used here.
    predicted_texts, _ = translator.predict(
        input=input_text,
        task_str="T2TT",
        src_lang=src_code,
        tgt_lang=tgt_code,
    )
    first_text = predicted_texts[0]
    return str(first_text)
|
67 |
|
68 |
import runpod
|
69 |
|
70 |
+
def runpod_handler(job):
    """Translate the text carried by a RunPod job.

    The request payload is read from the job's ``"input"`` key, the same
    layout as ``test_input.json``. A dict that carries the fields at its
    top level is still accepted.

    Args:
        job: Job dict whose payload provides ``input_text``,
            ``source_language`` and ``target_language``.

    Returns:
        The value returned by ``run_t2tt`` on success, or a dict of the
        form ``{"error": message}`` when the payload is not a dict or a
        required field is missing.
    """
    # Fall back to the job itself so callers passing a flat dict keep working.
    payload = job.get("input", job) if isinstance(job, dict) else None
    if not isinstance(payload, dict):
        return {
            "error": "Job input must be an object with input_text, "
            "source_language and target_language."
        }

    required = ("input_text", "source_language", "target_language")
    missing = [field for field in required if field not in payload]
    if missing:
        return {"error": f"Missing required input field(s): {', '.join(missing)}"}

    return run_t2tt(
        payload["input_text"],
        payload["source_language"],
        payload["target_language"],
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
+
# Hand runpod_handler to the RunPod serverless runtime as the job handler.
runpod.serverless.start({"handler": runpod_handler})
|
test_input.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"input": {
|
3 |
+
"input_text": "How are you doing today?",
|
4 |
+
"source_language": "English",
|
5 |
+
"target_language": "Mandarin Chinese"
|
6 |
+
}
|
7 |
+
}
|