Shenuki committed
Commit 6bebb04 · verified · 1 Parent(s): 948cffb

Update app.py

Files changed (1)
app.py  +157 -63
app.py CHANGED
@@ -6,6 +6,10 @@ import wikipedia
 import gradio as gr
 import torch
 
+from functools import lru_cache
+from concurrent.futures import ThreadPoolExecutor
+from typing import List
+
 from transformers import (
     SeamlessM4TTokenizer,
     SeamlessM4TProcessor,
@@ -13,103 +17,191 @@ from transformers import (
     pipeline as hf_pipeline
 )
 
-# 1) Load SeamlessM4T (slow tokenizer)
+# ── 1) Model setup (unchanged) ───────────────────────────────────────────────
+
 MODEL = "facebook/hf-seamless-m4t-medium"
 device = "cuda" if torch.cuda.is_available() else "cpu"
+
 tokenizer = SeamlessM4TTokenizer.from_pretrained(MODEL, use_fast=False)
 processor = SeamlessM4TProcessor.from_pretrained(MODEL, tokenizer=tokenizer)
-m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL).to(device).eval()
 
-def translate_m4t(text, src_iso3, tgt_iso3, auto_detect=False):
+m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL).to(device)
+if device == "cuda":
+    m4t_model = m4t_model.half()  # FP16 for faster inference on GPU
+m4t_model.eval()
+
+def translate_m4t(text: str, src_iso3: str, tgt_iso3: str, auto_detect=False) -> str:
+    """Single-string translation (used for the initial auto-detect → English pass)."""
     src = None if auto_detect else src_iso3
     inputs = processor(text=text, src_lang=src, return_tensors="pt").to(device)
     tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
     return processor.decode(tokens[0].tolist(), skip_special_tokens=True)
 
-# 2) NER pipeline
+def translate_m4t_batch(
+    texts: List[str], src_iso3: str, tgt_iso3: str, auto_detect=False
+) -> List[str]:
+    """Batch-mode translation: one generate() call for many inputs."""
+    src = None if auto_detect else src_iso3
+    inputs = processor(
+        text=texts, src_lang=src, return_tensors="pt", padding=True
+    ).to(device)
+    tokens = m4t_model.generate(
+        **inputs,
+        tgt_lang=tgt_iso3,
+        max_new_tokens=60,
+        num_beams=1
+    )
+    return processor.batch_decode(tokens, skip_special_tokens=True)
+
+# ── 2) NER pipeline ─────────────────────────────────────────────────────────
+
 ner = hf_pipeline("ner", model="dslim/bert-base-NER-uncased", grouped_entities=True)
 
-# 3) Geocode & POIs
-def geocode(place):
+# ── 3) Caching helpers ──────────────────────────────────────────────────────
+
+@lru_cache(maxsize=256)
+def geocode_cache(place: str):
     r = requests.get(
         "https://nominatim.openstreetmap.org/search",
         params={"q": place, "format": "json", "limit": 1},
-        headers={"User-Agent":"iVoiceContext/1.0"}
+        headers={"User-Agent": "iVoiceContext/1.0"}
     ).json()
-    if not r: return None
+    if not r:
+        return None
     return {"lat": float(r[0]["lat"]), "lon": float(r[0]["lon"])}
 
-def fetch_osm(lat, lon, osm_filter, limit=5):
+@lru_cache(maxsize=256)
+def fetch_osm_cache(lat: float, lon: float, osm_filter: str, limit: int = 5):
     payload = f"""
     [out:json][timeout:25];
-    ( node{osm_filter}(around:1000,{lat},{lon});
-      way{osm_filter}(around:1000,{lat},{lon}); );
+    (
+      node{osm_filter}(around:1000,{lat},{lon});
+      way{osm_filter}(around:1000,{lat},{lon});
+    );
    out center {limit};
    """
-    resp = requests.post("https://overpass-api.de/api/interpreter", data={"data": payload})
+    resp = requests.post(
+        "https://overpass-api.de/api/interpreter",
+        data={"data": payload}
+    )
     elems = resp.json().get("elements", [])
-    return [{"name": e["tags"]["name"]} for e in elems if e.get("tags",{}).get("name")]
-
-# 4) Main function
-def get_context(text: str,
-                source_lang: str,   # ISO-639-3 e.g. "eng"
-                output_lang: str,   # ISO-639-3 e.g. "fra"
-                auto_detect: bool):
-
-    # a) Ensure English for NER
+    return [
+        {"name": e["tags"]["name"]}
+        for e in elems
+        if e.get("tags", {}).get("name")
+    ]
+
+@lru_cache(maxsize=256)
+def wiki_summary_cache(name: str) -> str:
+    try:
+        return wikipedia.summary(name, sentences=2)
+    except:
+        return "No summary available."
+
+# ── 4) Per-entity worker ────────────────────────────────────────────────────
+
+def process_entity(ent) -> dict:
+    w = ent["word"]
+    lbl = ent["entity_group"]
+
+    if lbl == "LOC":
+        geo = geocode_cache(w)
+        if not geo:
+            return {
+                "text": w,
+                "label": lbl,
+                "type": "location",
+                "error": "could not geocode"
+            }
+
+        restaurants = fetch_osm_cache(geo["lat"], geo["lon"], '["amenity"="restaurant"]')
+        attractions = fetch_osm_cache(geo["lat"], geo["lon"], '["tourism"="attraction"]')
+
+        return {
+            "text": w,
+            "label": lbl,
+            "type": "location",
+            "geo": geo,
+            "restaurants": restaurants,
+            "attractions": attractions
+        }
+
+    # PERSON / ORG / MISC → Wikipedia
+    summary = wiki_summary_cache(w)
+    return {"text": w, "label": lbl, "type": "wiki", "summary": summary}
+
+
+# ── 5) Main function ────────────────────────────────────────────────────────
+
+def get_context(
+    text: str,
+    source_lang: str,   # ISO-639-3 e.g. "eng"
+    output_lang: str,   # ISO-639-3 e.g. "fra"
+    auto_detect: bool
+):
+    # a) Ensure we have English for NER
     if auto_detect or source_lang != "eng":
         en = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
     else:
         en = text
 
-    # b) Extract unique entities
+    # b) Run NER + dedupe
     ner_out = ner(en)
-    seen, entities = set(), []
+    seen = set()
+    unique_ents = []
     for ent in ner_out:
-        w, lbl = ent["word"], ent["entity_group"]
-        if w in seen: continue
+        w = ent["word"]
+        if w in seen:
+            continue
         seen.add(w)
+        unique_ents.append(ent)
 
-        if lbl == "LOC":
-            geo = geocode(w)
-            if not geo:
-                obj = {"text": w, "label": lbl, "type": "location", "error": "could not geocode"}
-            else:
-                obj = {
-                    "text": w,
-                    "label": lbl,
-                    "type": "location",
-                    "geo": geo,
-                    "restaurants": fetch_osm(geo["lat"], geo["lon"], '["amenity"="restaurant"]'),
-                    "attractions": fetch_osm(geo["lat"], geo["lon"], '["tourism"="attraction"]')
-                }
-
-        else:
-            # PERSON/ORG/MISC → Wikipedia
-            try:
-                summ = wikipedia.summary(w, sentences=2)
-            except:
-                summ = "No summary available."
-            obj = {"text": w, "label": lbl, "type": "wiki", "summary": summ}
-
-        entities.append(obj)
-
-    # c) Translate all fields → output_lang
+    # c) Process each entity in parallel
+    entities = []
+    with ThreadPoolExecutor(max_workers=8) as exe:
+        futures = [exe.submit(process_entity, ent) for ent in unique_ents]
+        for fut in futures:
+            entities.append(fut.result())
+
+    # d) Batch-translate any non-English fields
     if output_lang != "eng":
-        for e in entities:
+        to_translate = []
+        translations_info = []  # (kind, entity_i, [sub_i])
+
+        for i, e in enumerate(entities):
             if e["type"] == "wiki":
-                e["summary"] = translate_m4t(e["summary"], "eng", output_lang)
+                translations_info.append(("summary", i))
+                to_translate.append(e["summary"])
+
             elif e["type"] == "location":
-                for field in ("restaurants","attractions"):
-                    e[field] = [
-                        {"name": translate_m4t(item["name"], "eng", output_lang)}
-                        for item in e[field]
-                    ]
+                for j, r in enumerate(e["restaurants"]):
+                    translations_info.append(("restaurant", i, j))
+                    to_translate.append(r["name"])
+                for j, a in enumerate(e["attractions"]):
+                    translations_info.append(("attraction", i, j))
+                    to_translate.append(a["name"])
+
+        # single batched call
+        translated = translate_m4t_batch(to_translate, "eng", output_lang)
+
+        # redistribute
+        for txt, info in zip(translated, translations_info):
+            kind = info[0]
+            if kind == "summary":
+                _, ei = info
+                entities[ei]["summary"] = txt
+            elif kind == "restaurant":
+                _, ei, ri = info
+                entities[ei]["restaurants"][ri]["name"] = txt
+            elif kind == "attraction":
+                _, ei, ai = info
+                entities[ei]["attractions"][ai]["name"] = txt
 
-    # d) Return only entities
     return {"entities": entities}
 
-# 5) Gradio interface
+
+# ── 6) Gradio interface with concurrency tuning ─────────────────────────────
+
 iface = gr.Interface(
     fn=get_context,
     inputs=[
@@ -121,9 +213,11 @@ iface = gr.Interface(
     outputs="json",
     title="iVoice Context-Aware",
     description="Returns only the detected entities and their related info."
-).queue()
+).queue(concurrency_count=4, max_size=8)
 
 if __name__ == "__main__":
-    iface.launch(server_name="0.0.0.0",
-                 server_port=int(os.environ.get("PORT", 7860)),
-                 share=True)
+    iface.launch(
+        server_name="0.0.0.0",
+        server_port=int(os.environ.get("PORT", 7860)),
+        share=True
+    )
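
As a quick sanity check, the refactored get_context() can also be exercised directly from a Python shell, outside the Gradio UI. The snippet below is an illustrative sketch only, not part of the commit: the sample sentence and language codes are hypothetical, and it assumes the file is saved as app.py in an environment with the Space's dependencies and network access to Nominatim, Overpass, and Wikipedia.

# Illustrative sketch, not part of the commit.
from app import get_context  # importing loads the models but does not launch the UI

result = get_context(
    "Marie Curie spent several years in Paris.",  # hypothetical input sentence
    source_lang="eng",    # ISO-639-3 code of the input text
    output_lang="fra",    # ISO-639-3 code for translated summaries and POI names
    auto_detect=False,
)
for entity in result["entities"]:
    # "wiki" entities carry a two-sentence summary; "location" entities carry
    # geo coordinates plus nearby restaurants and attractions from Overpass.
    print(entity["text"], entity["type"])

Repeated runs in the same process should be noticeably faster for recurring entities, since geocode_cache, fetch_osm_cache, and wiki_summary_cache are memoized with lru_cache, and all translations for one request go through a single translate_m4t_batch call.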