jxtan committed on
Commit
7c4637d
·
1 Parent(s): ae89d62

Update runpod handler for Seamless

Browse files
Files changed (2) hide show
  1. server.py +46 -55
  2. test_input.json +7 -0
server.py CHANGED
@@ -21,65 +21,56 @@ from lang_list import (
21
  TEXT_SOURCE_LANGUAGE_NAMES,
22
  )
23
 
24
- # CHECKPOINTS_PATH = pathlib.Path(os.getenv("CHECKPOINTS_PATH", "/home/user/app/models"))
25
- # if not CHECKPOINTS_PATH.exists():
26
- # snapshot_download(repo_id="facebook/seamless-m4t-v2-large", repo_type="model", local_dir=CHECKPOINTS_PATH)
27
- # asset_store.env_resolvers.clear()
28
- # asset_store.env_resolvers.append(lambda: "demo")
29
- # demo_metadata = [
30
- # {
31
- # "name": "seamlessM4T_v2_large@demo",
32
- # "checkpoint": f"file://{CHECKPOINTS_PATH}/seamlessM4T_v2_large.pt",
33
- # "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
34
- # },
35
- # {
36
- # "name": "vocoder_v2@demo",
37
- # "checkpoint": f"file://{CHECKPOINTS_PATH}/vocoder_v2.pt",
38
- # },
39
- # ]
40
- # asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))
41
 
42
- # if torch.cuda.is_available():
43
- # device = torch.device("cuda:0")
44
- # dtype = torch.float16
45
- # else:
46
- # device = torch.device("cpu")
47
- # dtype = torch.float32
48
 
49
- # translator = Translator(
50
- # model_name_or_card="seamlessM4T_v2_large",
51
- # vocoder_name_or_card="vocoder_v2",
52
- # device=device,
53
- # dtype=dtype,
54
- # apply_mintox=True,
55
- # )
56
 
57
- # def run_t2tt(input_text: str, source_language: str, target_language: str) -> str:
58
- # source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
59
- # target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
60
- # out_texts, _ = translator.predict(
61
- # input=input_text,
62
- # task_str="T2TT",
63
- # src_lang=source_language_code,
64
- # tgt_lang=target_language_code,
65
- # )
66
- # return str(out_texts[0])
67
 
68
  import runpod
69
 
70
- def is_even(job):
71
-
72
- job_input = job["input"]
73
- the_number = job_input["number"]
74
-
75
- if not isinstance(the_number, int):
76
- return {"error": "Silly human, you need to pass an integer."}
77
-
78
- if the_number % 2 == 0:
79
- return True
80
-
81
- return False
82
-
83
- # output_text = run_t2tt(input_text, source_language, target_language)
84
 
85
- runpod.serverless.start({"handler": is_even})
 
21
  TEXT_SOURCE_LANGUAGE_NAMES,
22
  )
23
 
24
# --- One-time model setup ----------------------------------------------------
# Fetch the SeamlessM4T v2 checkpoints on first boot, then register them with
# fairseq2's asset store under a "demo" environment so the model and vocoder
# cards resolve to the local files instead of remote URLs.
CHECKPOINTS_PATH = pathlib.Path(os.getenv("CHECKPOINTS_PATH", "/home/user/app/models"))
if not CHECKPOINTS_PATH.exists():
    snapshot_download(
        repo_id="facebook/seamless-m4t-v2-large",
        repo_type="model",
        local_dir=CHECKPOINTS_PATH,
    )

asset_store.env_resolvers.clear()
asset_store.env_resolvers.append(lambda: "demo")
demo_metadata = [
    {
        "name": "seamlessM4T_v2_large@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/seamlessM4T_v2_large.pt",
        "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
    },
    {
        "name": "vocoder_v2@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/vocoder_v2.pt",
    },
]
asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))

# Prefer GPU with half precision; fall back to CPU with full precision.
_on_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if _on_gpu else "cpu")
dtype = torch.float16 if _on_gpu else torch.float32

# Single shared translator instance reused across all requests.
translator = Translator(
    model_name_or_card="seamlessM4T_v2_large",
    vocoder_name_or_card="vocoder_v2",
    device=device,
    dtype=dtype,
    apply_mintox=True,
)
56
 
57
def run_t2tt(input_text: str, source_language: str, target_language: str) -> str:
    """Translate *input_text* with the Seamless text-to-text (T2TT) task.

    Both languages are given by display name and mapped to Seamless codes
    via LANGUAGE_NAME_TO_CODE; an unknown name raises KeyError. Returns the
    first (best) translated text as a plain string.
    """
    src_code = LANGUAGE_NAME_TO_CODE[source_language]
    tgt_code = LANGUAGE_NAME_TO_CODE[target_language]
    texts, _speech = translator.predict(
        input=input_text,
        task_str="T2TT",
        src_lang=src_code,
        tgt_lang=tgt_code,
    )
    return str(texts[0])
67
 
68
  import runpod
69
 
70
def runpod_handler(job):
    """RunPod serverless entry point for Seamless T2TT translation.

    RunPod delivers the request payload nested under job["input"] (see
    test_input.json), so the fields must be read from that inner dict —
    reading them off ``job`` directly raises KeyError on every real job.

    Expected payload keys: input_text, source_language, target_language.
    Returns the translated text, or an {"error": ...} dict when a required
    key is missing.
    """
    job_input = job["input"]
    try:
        input_text = job_input["input_text"]
        source_language = job_input["source_language"]
        target_language = job_input["target_language"]
    except KeyError as exc:
        # Surface a structured error instead of crashing the worker.
        return {"error": f"Missing required input field: {exc.args[0]}"}
    return run_t2tt(input_text, source_language, target_language)


runpod.serverless.start({"handler": runpod_handler})
test_input.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "input": {
3
+ "input_text": "How are you doing today?",
4
+ "source_language": "English",
5
+ "target_language": "Mandarin Chinese"
6
+ }
7
+ }