Shenuki committed · verified
Commit 948cffb · 1 Parent(s): 9832e9b

Update app.py

Files changed (1):
  app.py  +76 -101
app.py CHANGED
@@ -7,27 +7,18 @@ import gradio as gr
 import torch
 
 from transformers import (
+    SeamlessM4TTokenizer,
     SeamlessM4TProcessor,
     SeamlessM4TForTextToText,
-    SeamlessM4TTokenizer,  # <<< import the tokenizer class
     pipeline as hf_pipeline
 )
 
-# ————————————————————
-# 1) Load SeamlessM4T tokenizer (slow) and processor
-MODEL_NAME = "facebook/hf-seamless-m4t-medium"
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# load the slow tokenizer (no conversion attempted)
-tokenizer = SeamlessM4TTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
-
-# pass it into the processor so it won't try to convert
-processor = SeamlessM4TProcessor.from_pretrained(
-    MODEL_NAME,
-    tokenizer=tokenizer
-)
-
-m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL_NAME).to(device).eval()
+# 1) Load SeamlessM4T (slow tokenizer)
+MODEL = "facebook/hf-seamless-m4t-medium"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+tokenizer = SeamlessM4TTokenizer.from_pretrained(MODEL, use_fast=False)
+processor = SeamlessM4TProcessor.from_pretrained(MODEL, tokenizer=tokenizer)
+m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL).to(device).eval()
 
 def translate_m4t(text, src_iso3, tgt_iso3, auto_detect=False):
     src = None if auto_detect else src_iso3
@@ -35,120 +26,104 @@ def translate_m4t(text, src_iso3, tgt_iso3, auto_detect=False):
     tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
     return processor.decode(tokens[0].tolist(), skip_special_tokens=True)
 
-# ————————————————————
-# 2) BERT-based NER
-ner = hf_pipeline(
-    "ner",
-    model="dslim/bert-base-NER-uncased",
-    grouped_entities=True
-)
+# 2) NER pipeline
+ner = hf_pipeline("ner", model="dslim/bert-base-NER-uncased", grouped_entities=True)
 
-# ————————————————————
-# 3) Geocoding & POIs via OpenStreetMap
-def geocode(place: str):
-    resp = requests.get(
+# 3) Geocode & POIs
+def geocode(place):
+    r = requests.get(
         "https://nominatim.openstreetmap.org/search",
         params={"q": place, "format": "json", "limit": 1},
        headers={"User-Agent":"iVoiceContext/1.0"}
     ).json()
-    if not resp:
-        return None
-    return float(resp[0]["lat"]), float(resp[0]["lon"])
+    if not r: return None
+    return {"lat": float(r[0]["lat"]), "lon": float(r[0]["lon"])}
 
 def fetch_osm(lat, lon, osm_filter, limit=5):
-    query = f"""
-    [out:json][timeout:25];
-    (
-      node{osm_filter}(around:1000,{lat},{lon});
-      way{osm_filter}(around:1000,{lat},{lon});
-    );
-    out center {limit};
+    payload = f"""
+    [out:json][timeout:25];
+    ( node{osm_filter}(around:1000,{lat},{lon});
+      way{osm_filter}(around:1000,{lat},{lon}); );
+    out center {limit};
     """
-    r = requests.post("https://overpass-api.de/api/interpreter", data={"data": query})
-    elems = r.json().get("elements", [])
-    return [
-        {"name": e["tags"].get("name", "")}
-        for e in elems
-        if e.get("tags", {}).get("name")
-    ]
-
-# ————————————————————
+    resp = requests.post("https://overpass-api.de/api/interpreter", data={"data": payload})
+    elems = resp.json().get("elements", [])
+    return [{"name": e["tags"]["name"]} for e in elems if e.get("tags",{}).get("name")]
+
+# 4) Main function
 def get_context(text: str,
-                source_lang: str,   # always ISO639-3, e.g. "eng"
-                output_lang: str,   # always ISO639-3, e.g. "fra"
+                source_lang: str,   # ISO-639-3 e.g. "eng"
+                output_lang: str,   # ISO-639-3 e.g. "fra"
                 auto_detect: bool):
-    # 1) Ensure English text for NER
+
+    # a) Ensure English for NER
     if auto_detect or source_lang != "eng":
-        en_text = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
+        en = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
     else:
-        en_text = text
-
-    # 2) Run NER
-    ner_out = ner(en_text)
-    ents = { ent["word"]: ent["entity_group"] for ent in ner_out }
-
-    results = {}
-    for ent_text, label in ents.items():
-        if label == "LOC":
-            geo = geocode(ent_text)
+        en = text
+
+    # b) Extract unique entities
+    ner_out = ner(en)
+    seen, entities = set(), []
+    for ent in ner_out:
+        w, lbl = ent["word"], ent["entity_group"]
+        if w in seen: continue
+        seen.add(w)
+
+        if lbl == "LOC":
+            geo = geocode(w)
             if not geo:
-                results[ent_text] = {"type":"location","error":"could not geocode"}
+                obj = {"text": w, "label": lbl, "type": "location", "error": "could not geocode"}
             else:
-                lat, lon = geo
-                results[ent_text] = {
-                    "type": "location",
-                    "restaurants": fetch_osm(lat, lon, '["amenity"="restaurant"]'),
-                    "attractions": fetch_osm(lat, lon, '["tourism"="attraction"]'),
+                obj = {
+                    "text": w,
+                    "label": lbl,
+                    "type": "location",
+                    "geo": geo,
+                    "restaurants": fetch_osm(geo["lat"], geo["lon"], '["amenity"="restaurant"]'),
+                    "attractions": fetch_osm(geo["lat"], geo["lon"], '["tourism"="attraction"]')
                 }
+
        else:
+            # PERSON/ORG/MISC → Wikipedia
            try:
-                summ = wikipedia.summary(ent_text, sentences=2)
-            except Exception:
+                summ = wikipedia.summary(w, sentences=2)
+            except:
                summ = "No summary available."
-            results[ent_text] = {"type":"wiki","summary": summ}
+            obj = {"text": w, "label": lbl, "type": "wiki", "summary": summ}
 
-    if not results:
-        return {"error":"no entities found"}
+        entities.append(obj)
 
-    # 3) Translate all text fields → output_lang
+    # c) Translate all fields → output_lang
     if output_lang != "eng":
-        for info in results.values():
-            if info["type"] == "wiki":
-                info["summary"] = translate_m4t(
-                    info["summary"], "eng", output_lang, auto_detect=False
-                )
-            elif info["type"] == "location":
-                for key in ("restaurants","attractions"):
-                    info[key] = [
-                        {"name": translate_m4t(item["name"], "eng", output_lang)}
-                        for item in info[key]
+        for e in entities:
+            if e["type"] == "wiki":
+                e["summary"] = translate_m4t(e["summary"], "eng", output_lang)
+            elif e["type"] == "location":
+                for field in ("restaurants","attractions"):
+                    e[field] = [
+                        {"name": translate_m4t(item["name"], "eng", output_lang)}
+                        for item in e[field]
                     ]
 
-    return results
+    # d) Return only entities
+    return {"entities": entities}
 
-# ————————————————————
+# 5) Gradio interface
 iface = gr.Interface(
     fn=get_context,
     inputs=[
-        gr.Textbox(lines=3, placeholder="Enter text…"),
-        gr.Textbox(label="Source Language (ISO 639-3)"),
-        gr.Textbox(label="Target Language (ISO 639-3)"),
-        gr.Checkbox(label="Auto-detect source language")
+        gr.Textbox(lines=3, placeholder="Enter text…"),
+        gr.Textbox(label="Source Language (ISO 639-3)"),
+        gr.Textbox(label="Target Language (ISO 639-3)"),
+        gr.Checkbox(label="Auto-detect source language")
     ],
     outputs="json",
-    title="iVoice Translate + Context-Aware",
-    description=(
-        "1) Translate your text → English (if needed)\n"
-        "2) Extract LOC/PERSON/ORG via BERT-NER\n"
-        "3) Geocode LOC → fetch nearby restaurants & attractions\n"
-        "4) Fetch Wikipedia summaries\n"
-        "5) Translate **all** results → your target language"
-    )
+    title="iVoice Context-Aware",
+    description="Returns only the detected entities and their related info."
 ).queue()
 
 if __name__ == "__main__":
-    iface.launch(
-        server_name="0.0.0.0",
-        server_port=int(os.environ.get("PORT", 7860)),
-        share=True
-    )
+    iface.launch(server_name="0.0.0.0",
+                 server_port=int(os.environ.get("PORT", 7860)),
+                 share=True)
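For reference, a minimal smoke test of the new return contract: get_context now yields {"entities": [...]} rather than a dict keyed by entity text, and geocode now returns a {"lat", "lon"} dict instead of a tuple. The sketch below is not part of the commit; it assumes app.py is importable as app (which downloads the model weights on first import) and that Nominatim, Overpass, and Wikipedia are reachable. The input sentence and assertions are illustrative.

# Sketch: smoke test for the revised get_context contract (not in the commit).
# Assumes app.py imports cleanly and the external APIs are reachable.
import app

result = app.get_context("I visited Paris last summer.",
                         source_lang="eng",
                         output_lang="eng",
                         auto_detect=False)

assert "entities" in result                            # new top-level key
for ent in result["entities"]:
    assert {"text", "label", "type"} <= ent.keys()     # fields set per entity
    if ent["type"] == "location" and "error" not in ent:
        assert {"lat", "lon"} <= ent["geo"].keys()     # geocode now returns a dict
print(result)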
 
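Likewise, the Overpass payload that fetch_osm builds is just an f-string; this standalone sketch prints the exact query body the function would POST to the Overpass API. The coordinates are hypothetical (roughly central Paris), and the filter and limit mirror the function's defaults.

# Illustrative expansion of the Overpass payload built in fetch_osm above.
# Coordinates are hypothetical; filter and limit mirror the defaults.
lat, lon, limit = 48.8566, 2.3522, 5
osm_filter = '["amenity"="restaurant"]'
payload = f"""
[out:json][timeout:25];
( node{osm_filter}(around:1000,{lat},{lon});
  way{osm_filter}(around:1000,{lat},{lon}); );
out center {limit};
"""
print(payload)  # body sent as requests.post(..., data={"data": payload})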