jxtan commited on
Commit
a1ec192
·
1 Parent(s): 7c4637d

Cache Model into Docker Container

Browse files
Files changed (2) hide show
  1. Dockerfile +2 -0
  2. cache.py +55 -0
Dockerfile CHANGED
@@ -58,5 +58,7 @@ ENV PYTHONPATH=${HOME}/app \
58
  TQDM_POSITION=-1 \
59
  TQDM_MININTERVAL=1 \
60
  SYSTEM=spaces
 
 
61
  # CMD ["python", "app.py"]
62
  CMD ["python", "server.py"]
 
58
  TQDM_POSITION=-1 \
59
  TQDM_MININTERVAL=1 \
60
  SYSTEM=spaces
61
+
62
+ RUN python -u cache.py
63
  # CMD ["python", "app.py"]
64
  CMD ["python", "server.py"]
cache.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import pathlib
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import torch
9
+ import torchaudio
10
+ from fairseq2.assets import InProcAssetMetadataProvider, asset_store
11
+ from huggingface_hub import snapshot_download
12
+ from seamless_communication.inference import Translator
13
+
14
+ from lang_list import (
15
+ ASR_TARGET_LANGUAGE_NAMES,
16
+ LANGUAGE_NAME_TO_CODE,
17
+ S2ST_TARGET_LANGUAGE_NAMES,
18
+ S2TT_TARGET_LANGUAGE_NAMES,
19
+ T2ST_TARGET_LANGUAGE_NAMES,
20
+ T2TT_TARGET_LANGUAGE_NAMES,
21
+ TEXT_SOURCE_LANGUAGE_NAMES,
22
+ )
23
+
24
+ CHECKPOINTS_PATH = pathlib.Path(os.getenv("CHECKPOINTS_PATH", "/home/user/app/models"))
25
+ if not CHECKPOINTS_PATH.exists():
26
+ snapshot_download(repo_id="facebook/seamless-m4t-v2-large", repo_type="model", local_dir=CHECKPOINTS_PATH)
27
+ asset_store.env_resolvers.clear()
28
+ asset_store.env_resolvers.append(lambda: "demo")
29
+ demo_metadata = [
30
+ {
31
+ "name": "seamlessM4T_v2_large@demo",
32
+ "checkpoint": f"file://{CHECKPOINTS_PATH}/seamlessM4T_v2_large.pt",
33
+ "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
34
+ },
35
+ {
36
+ "name": "vocoder_v2@demo",
37
+ "checkpoint": f"file://{CHECKPOINTS_PATH}/vocoder_v2.pt",
38
+ },
39
+ ]
40
+ asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))
41
+
42
+ if torch.cuda.is_available():
43
+ device = torch.device("cuda:0")
44
+ dtype = torch.float16
45
+ else:
46
+ device = torch.device("cpu")
47
+ dtype = torch.float32
48
+
49
+ translator = Translator(
50
+ model_name_or_card="seamlessM4T_v2_large",
51
+ vocoder_name_or_card="vocoder_v2",
52
+ device=device,
53
+ dtype=dtype,
54
+ apply_mintox=True,
55
+ )