Update app.py
Browse files
app.py
CHANGED
@@ -17,7 +17,7 @@ from transformers import (
|
|
17 |
pipeline as hf_pipeline
|
18 |
)
|
19 |
|
20 |
-
# ββ 1) Model setup
|
21 |
|
22 |
MODEL = "facebook/hf-seamless-m4t-medium"
|
23 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -31,7 +31,6 @@ if device == "cuda":
|
|
31 |
m4t_model.eval()
|
32 |
|
33 |
def translate_m4t(text: str, src_iso3: str, tgt_iso3: str, auto_detect=False) -> str:
|
34 |
-
"""Single-string translation (used for initial autoβdetect β English)."""
|
35 |
src = None if auto_detect else src_iso3
|
36 |
inputs = processor(text=text, src_lang=src, return_tensors="pt").to(device)
|
37 |
tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
|
@@ -40,7 +39,6 @@ def translate_m4t(text: str, src_iso3: str, tgt_iso3: str, auto_detect=False) ->
|
|
40 |
def translate_m4t_batch(
|
41 |
texts: List[str], src_iso3: str, tgt_iso3: str, auto_detect=False
|
42 |
) -> List[str]:
|
43 |
-
"""Batchβmode translation: one generate() for many inputs."""
|
44 |
src = None if auto_detect else src_iso3
|
45 |
inputs = processor(
|
46 |
text=texts, src_lang=src, return_tensors="pt", padding=True
|
@@ -53,9 +51,15 @@ def translate_m4t_batch(
|
|
53 |
)
|
54 |
return processor.batch_decode(tokens, skip_special_tokens=True)
|
55 |
|
56 |
-
# ββ 2) NER pipeline βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
57 |
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
# ββ 3) CACHING helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
61 |
|
@@ -98,6 +102,7 @@ def wiki_summary_cache(name: str) -> str:
|
|
98 |
except:
|
99 |
return "No summary available."
|
100 |
|
|
|
101 |
# ββ 4) Per-entity worker ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
102 |
|
103 |
def process_entity(ent) -> dict:
|
@@ -135,11 +140,11 @@ def process_entity(ent) -> dict:
|
|
135 |
|
136 |
def get_context(
|
137 |
text: str,
|
138 |
-
source_lang: str,
|
139 |
-
output_lang: str,
|
140 |
auto_detect: bool
|
141 |
):
|
142 |
-
# a) Ensure
|
143 |
if auto_detect or source_lang != "eng":
|
144 |
en = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
|
145 |
else:
|
@@ -156,23 +161,22 @@ def get_context(
|
|
156 |
seen.add(w)
|
157 |
unique_ents.append(ent)
|
158 |
|
159 |
-
# c)
|
160 |
entities = []
|
161 |
with ThreadPoolExecutor(max_workers=8) as exe:
|
162 |
futures = [exe.submit(process_entity, ent) for ent in unique_ents]
|
163 |
for fut in futures:
|
164 |
entities.append(fut.result())
|
165 |
|
166 |
-
# d) Batch-translate
|
167 |
if output_lang != "eng":
|
168 |
to_translate = []
|
169 |
-
translations_info = []
|
170 |
|
171 |
for i, e in enumerate(entities):
|
172 |
if e["type"] == "wiki":
|
173 |
translations_info.append(("summary", i))
|
174 |
to_translate.append(e["summary"])
|
175 |
-
|
176 |
elif e["type"] == "location":
|
177 |
for j, r in enumerate(e["restaurants"]):
|
178 |
translations_info.append(("restaurant", i, j))
|
@@ -181,10 +185,8 @@ def get_context(
|
|
181 |
translations_info.append(("attraction", i, j))
|
182 |
to_translate.append(a["name"])
|
183 |
|
184 |
-
# single batched call
|
185 |
translated = translate_m4t_batch(to_translate, "eng", output_lang)
|
186 |
|
187 |
-
# redistribute
|
188 |
for txt, info in zip(translated, translations_info):
|
189 |
kind = info[0]
|
190 |
if kind == "summary":
|
@@ -200,7 +202,7 @@ def get_context(
|
|
200 |
return {"entities": entities}
|
201 |
|
202 |
|
203 |
-
# ββ 6) Gradio interface
|
204 |
|
205 |
iface = gr.Interface(
|
206 |
fn=get_context,
|
@@ -213,7 +215,7 @@ iface = gr.Interface(
|
|
213 |
outputs="json",
|
214 |
title="iVoice Context-Aware",
|
215 |
description="Returns only the detected entities and their related info."
|
216 |
-
).queue(
|
217 |
|
218 |
if __name__ == "__main__":
|
219 |
iface.launch(
|
|
|
17 |
pipeline as hf_pipeline
|
18 |
)
|
19 |
|
20 |
+
# ββ 1) Model setup ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
21 |
|
22 |
MODEL = "facebook/hf-seamless-m4t-medium"
|
23 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
31 |
m4t_model.eval()
|
32 |
|
33 |
def translate_m4t(text: str, src_iso3: str, tgt_iso3: str, auto_detect=False) -> str:
|
|
|
34 |
src = None if auto_detect else src_iso3
|
35 |
inputs = processor(text=text, src_lang=src, return_tensors="pt").to(device)
|
36 |
tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
|
|
|
39 |
def translate_m4t_batch(
|
40 |
texts: List[str], src_iso3: str, tgt_iso3: str, auto_detect=False
|
41 |
) -> List[str]:
|
|
|
42 |
src = None if auto_detect else src_iso3
|
43 |
inputs = processor(
|
44 |
text=texts, src_lang=src, return_tensors="pt", padding=True
|
|
|
51 |
)
|
52 |
return processor.batch_decode(tokens, skip_special_tokens=True)
|
53 |
|
|
|
54 |
|
55 |
+
# ββ 2) NER pipeline (updated for deprecation) ββββββββββββββββββββββββββββββββ
|
56 |
+
|
57 |
+
ner = hf_pipeline(
|
58 |
+
"ner",
|
59 |
+
model="dslim/bert-base-NER-uncased",
|
60 |
+
aggregation_strategy="simple"
|
61 |
+
)
|
62 |
+
|
63 |
|
64 |
# ββ 3) CACHING helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
65 |
|
|
|
102 |
except:
|
103 |
return "No summary available."
|
104 |
|
105 |
+
|
106 |
# ββ 4) Per-entity worker ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
107 |
|
108 |
def process_entity(ent) -> dict:
|
|
|
140 |
|
141 |
def get_context(
|
142 |
text: str,
|
143 |
+
source_lang: str,
|
144 |
+
output_lang: str,
|
145 |
auto_detect: bool
|
146 |
):
|
147 |
+
# a) Ensure English for NER
|
148 |
if auto_detect or source_lang != "eng":
|
149 |
en = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
|
150 |
else:
|
|
|
161 |
seen.add(w)
|
162 |
unique_ents.append(ent)
|
163 |
|
164 |
+
# c) Parallel I/O
|
165 |
entities = []
|
166 |
with ThreadPoolExecutor(max_workers=8) as exe:
|
167 |
futures = [exe.submit(process_entity, ent) for ent in unique_ents]
|
168 |
for fut in futures:
|
169 |
entities.append(fut.result())
|
170 |
|
171 |
+
# d) Batch-translate non-English fields
|
172 |
if output_lang != "eng":
|
173 |
to_translate = []
|
174 |
+
translations_info = []
|
175 |
|
176 |
for i, e in enumerate(entities):
|
177 |
if e["type"] == "wiki":
|
178 |
translations_info.append(("summary", i))
|
179 |
to_translate.append(e["summary"])
|
|
|
180 |
elif e["type"] == "location":
|
181 |
for j, r in enumerate(e["restaurants"]):
|
182 |
translations_info.append(("restaurant", i, j))
|
|
|
185 |
translations_info.append(("attraction", i, j))
|
186 |
to_translate.append(a["name"])
|
187 |
|
|
|
188 |
translated = translate_m4t_batch(to_translate, "eng", output_lang)
|
189 |
|
|
|
190 |
for txt, info in zip(translated, translations_info):
|
191 |
kind = info[0]
|
192 |
if kind == "summary":
|
|
|
202 |
return {"entities": entities}
|
203 |
|
204 |
|
205 |
+
# ββ 6) Gradio interface βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
206 |
|
207 |
iface = gr.Interface(
|
208 |
fn=get_context,
|
|
|
215 |
outputs="json",
|
216 |
title="iVoice Context-Aware",
|
217 |
description="Returns only the detected entities and their related info."
|
218 |
+
).queue() # β removed unsupported kwargs
|
219 |
|
220 |
if __name__ == "__main__":
|
221 |
iface.launch(
|