Shenuki committed on
Commit
ce5002a
·
verified ·
1 Parent(s): 6bebb04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -16
app.py CHANGED
@@ -17,7 +17,7 @@ from transformers import (
17
  pipeline as hf_pipeline
18
  )
19
 
20
- # ── 1) Model setup (unchanged) ───────────────────────────────────────────────
21
 
22
  MODEL = "facebook/hf-seamless-m4t-medium"
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -31,7 +31,6 @@ if device == "cuda":
31
  m4t_model.eval()
32
 
33
  def translate_m4t(text: str, src_iso3: str, tgt_iso3: str, auto_detect=False) -> str:
34
- """Single-string translation (used for initial auto-detect → English)."""
35
  src = None if auto_detect else src_iso3
36
  inputs = processor(text=text, src_lang=src, return_tensors="pt").to(device)
37
  tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
@@ -40,7 +39,6 @@ def translate_m4t(text: str, src_iso3: str, tgt_iso3: str, auto_detect=False) ->
40
  def translate_m4t_batch(
41
  texts: List[str], src_iso3: str, tgt_iso3: str, auto_detect=False
42
  ) -> List[str]:
43
- """Batch‐mode translation: one generate() for many inputs."""
44
  src = None if auto_detect else src_iso3
45
  inputs = processor(
46
  text=texts, src_lang=src, return_tensors="pt", padding=True
@@ -53,9 +51,15 @@ def translate_m4t_batch(
53
  )
54
  return processor.batch_decode(tokens, skip_special_tokens=True)
55
 
56
- # ── 2) NER pipeline ─────────────────────────────────────────────────────────
57
 
58
- ner = hf_pipeline("ner", model="dslim/bert-base-NER-uncased", grouped_entities=True)
 
 
 
 
 
 
 
59
 
60
  # ── 3) CACHING helpers ──────────────────────────────────────────────────────
61
 
@@ -98,6 +102,7 @@ def wiki_summary_cache(name: str) -> str:
98
  except:
99
  return "No summary available."
100
 
 
101
  # ── 4) Per-entity worker ────────────────────────────────────────────────────
102
 
103
  def process_entity(ent) -> dict:
@@ -135,11 +140,11 @@ def process_entity(ent) -> dict:
135
 
136
  def get_context(
137
  text: str,
138
- source_lang: str, # ISO-639-3 e.g. "eng"
139
- output_lang: str, # ISO-639-3 e.g. "fra"
140
  auto_detect: bool
141
  ):
142
- # a) Ensure we have English for NER
143
  if auto_detect or source_lang != "eng":
144
  en = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
145
  else:
@@ -156,23 +161,22 @@ def get_context(
156
  seen.add(w)
157
  unique_ents.append(ent)
158
 
159
- # c) Process each entity in parallel
160
  entities = []
161
  with ThreadPoolExecutor(max_workers=8) as exe:
162
  futures = [exe.submit(process_entity, ent) for ent in unique_ents]
163
  for fut in futures:
164
  entities.append(fut.result())
165
 
166
- # d) Batch-translate any non-English fields
167
  if output_lang != "eng":
168
  to_translate = []
169
- translations_info = [] # (kind, entity_i, [sub_i])
170
 
171
  for i, e in enumerate(entities):
172
  if e["type"] == "wiki":
173
  translations_info.append(("summary", i))
174
  to_translate.append(e["summary"])
175
-
176
  elif e["type"] == "location":
177
  for j, r in enumerate(e["restaurants"]):
178
  translations_info.append(("restaurant", i, j))
@@ -181,10 +185,8 @@ def get_context(
181
  translations_info.append(("attraction", i, j))
182
  to_translate.append(a["name"])
183
 
184
- # single batched call
185
  translated = translate_m4t_batch(to_translate, "eng", output_lang)
186
 
187
- # redistribute
188
  for txt, info in zip(translated, translations_info):
189
  kind = info[0]
190
  if kind == "summary":
@@ -200,7 +202,7 @@ def get_context(
200
  return {"entities": entities}
201
 
202
 
203
- # ── 6) Gradio interface with concurrency tuning ─────────────────────────────
204
 
205
  iface = gr.Interface(
206
  fn=get_context,
@@ -213,7 +215,7 @@ iface = gr.Interface(
213
  outputs="json",
214
  title="iVoice Context-Aware",
215
  description="Returns only the detected entities and their related info."
216
- ).queue(concurrency_count=4, max_size=8)
217
 
218
  if __name__ == "__main__":
219
  iface.launch(
 
17
  pipeline as hf_pipeline
18
  )
19
 
20
+ # ── 1) Model setup ────────────────────────────────────────────────────────────
21
 
22
  MODEL = "facebook/hf-seamless-m4t-medium"
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
31
  m4t_model.eval()
32
 
33
  def translate_m4t(text: str, src_iso3: str, tgt_iso3: str, auto_detect=False) -> str:
 
34
  src = None if auto_detect else src_iso3
35
  inputs = processor(text=text, src_lang=src, return_tensors="pt").to(device)
36
  tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
 
39
  def translate_m4t_batch(
40
  texts: List[str], src_iso3: str, tgt_iso3: str, auto_detect=False
41
  ) -> List[str]:
 
42
  src = None if auto_detect else src_iso3
43
  inputs = processor(
44
  text=texts, src_lang=src, return_tensors="pt", padding=True
 
51
  )
52
  return processor.batch_decode(tokens, skip_special_tokens=True)
53
 
 
54
 
55
+ # ── 2) NER pipeline (updated for deprecation) ────────────────────────────────
56
+
57
+ ner = hf_pipeline(
58
+ "ner",
59
+ model="dslim/bert-base-NER-uncased",
60
+ aggregation_strategy="simple"
61
+ )
62
+
63
 
64
  # ── 3) CACHING helpers ──────────────────────────────────────────────────────
65
 
 
102
  except:
103
  return "No summary available."
104
 
105
+
106
  # ── 4) Per-entity worker ────────────────────────────────────────────────────
107
 
108
  def process_entity(ent) -> dict:
 
140
 
141
  def get_context(
142
  text: str,
143
+ source_lang: str,
144
+ output_lang: str,
145
  auto_detect: bool
146
  ):
147
+ # a) Ensure English for NER
148
  if auto_detect or source_lang != "eng":
149
  en = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
150
  else:
 
161
  seen.add(w)
162
  unique_ents.append(ent)
163
 
164
+ # c) Parallel I/O
165
  entities = []
166
  with ThreadPoolExecutor(max_workers=8) as exe:
167
  futures = [exe.submit(process_entity, ent) for ent in unique_ents]
168
  for fut in futures:
169
  entities.append(fut.result())
170
 
171
+ # d) Batch-translate non-English fields
172
  if output_lang != "eng":
173
  to_translate = []
174
+ translations_info = []
175
 
176
  for i, e in enumerate(entities):
177
  if e["type"] == "wiki":
178
  translations_info.append(("summary", i))
179
  to_translate.append(e["summary"])
 
180
  elif e["type"] == "location":
181
  for j, r in enumerate(e["restaurants"]):
182
  translations_info.append(("restaurant", i, j))
 
185
  translations_info.append(("attraction", i, j))
186
  to_translate.append(a["name"])
187
 
 
188
  translated = translate_m4t_batch(to_translate, "eng", output_lang)
189
 
 
190
  for txt, info in zip(translated, translations_info):
191
  kind = info[0]
192
  if kind == "summary":
 
202
  return {"entities": entities}
203
 
204
 
205
+ # ── 6) Gradio interface ─────────────────────────────────────────────────────
206
 
207
  iface = gr.Interface(
208
  fn=get_context,
 
215
  outputs="json",
216
  title="iVoice Context-Aware",
217
  description="Returns only the detected entities and their related info."
218
+ ).queue() # ← removed unsupported kwargs
219
 
220
  if __name__ == "__main__":
221
  iface.launch(