|
|
|
|
|
import os |
|
import requests |
|
import wikipedia |
|
import gradio as gr |
|
import torch |
|
|
|
from functools import lru_cache |
|
from concurrent.futures import ThreadPoolExecutor |
|
from typing import List |
|
|
|
from transformers import ( |
|
SeamlessM4TTokenizer, |
|
SeamlessM4TProcessor, |
|
SeamlessM4TForTextToText, |
|
pipeline as hf_pipeline |
|
) |
|
|
|
|
|
|
|
# SeamlessM4T checkpoint shared by the tokenizer, processor and model below.
MODEL = "facebook/hf-seamless-m4t-medium"

# Use the GPU when torch can see one, otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Slow (non-"fast") tokenizer, passed into the processor so both use the same instance.
tokenizer = SeamlessM4TTokenizer.from_pretrained(MODEL, use_fast=False)
processor = SeamlessM4TProcessor.from_pretrained(MODEL, tokenizer=tokenizer)

# Text-to-text model used by translate_m4t / translate_m4t_batch.
# On CUDA the weights are cast to float16 (presumably to save GPU memory);
# eval() switches off training-only behaviour such as dropout.
m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL).to(device)
m4t_model = m4t_model.half() if device == "cuda" else m4t_model
m4t_model.eval()
|
|
|
def translate_m4t(text: str, src_iso3: str, tgt_iso3: str, auto_detect=False) -> str:
    """Translate a single string with the SeamlessM4T text-to-text model.

    Args:
        text: input string.
        src_iso3: ISO 639-3 source language code; ignored when *auto_detect*.
        tgt_iso3: ISO 639-3 target language code.
        auto_detect: when true, no source language is passed to the processor.

    Returns:
        The decoded translation of the first (only) generated sequence.
    """
    encoded = processor(
        text=text,
        src_lang=(src_iso3 if not auto_detect else None),
        return_tensors="pt",
    ).to(device)
    generated = m4t_model.generate(**encoded, tgt_lang=tgt_iso3)
    first_sequence = generated[0].tolist()
    return processor.decode(first_sequence, skip_special_tokens=True)
|
|
|
def translate_m4t_batch(
    texts: List[str],
    src_iso3: str,
    tgt_iso3: str,
    auto_detect=False,
    *,
    max_new_tokens: int = 60,
    num_beams: int = 1,
) -> List[str]:
    """Translate several strings in one padded ``generate`` call.

    Args:
        texts: strings to translate. An empty list returns ``[]`` immediately
            instead of handing an empty batch to the processor and model.
        src_iso3: ISO 639-3 source language code; ignored when *auto_detect*.
        tgt_iso3: ISO 639-3 target language code.
        auto_detect: when true, no source language is passed to the processor.
        max_new_tokens: cap on generated tokens per item. NOTE(review): the
            default 60 may truncate long inputs such as multi-sentence
            summaries; raise it if that shows up in practice.
        num_beams: beam width; 1 means greedy decoding.

    Returns:
        One translation per input, in the same order as *texts*.
    """
    if not texts:
        return []

    src = None if auto_detect else src_iso3
    # padding=True lets strings of different lengths share one tensor batch.
    inputs = processor(
        text=texts, src_lang=src, return_tensors="pt", padding=True
    ).to(device)
    tokens = m4t_model.generate(
        **inputs,
        tgt_lang=tgt_iso3,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
    )
    return processor.batch_decode(tokens, skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
# Named-entity-recognition pipeline. With "simple" aggregation it yields one
# dict per entity; process_entity / get_context read its "word" and
# "entity_group" keys.
ner = hf_pipeline(
    task="ner",
    model="dslim/bert-base-NER-uncased",
    aggregation_strategy="simple",
)
|
|
|
|
|
|
|
|
|
@lru_cache(maxsize=256)
def geocode_cache(place: str):
    """Geocode *place* with the OpenStreetMap Nominatim search endpoint.

    Returns:
        ``{"lat": float, "lon": float}`` for the top hit, or ``None`` when
        Nominatim returns an empty result list. Return values (including
        ``None``) are memoized per place string. Exceptions are not memoized,
        so a failed lookup is retried on the next call.

    Raises:
        requests.RequestException: connection error, timeout, or non-2xx
            HTTP status.
        ValueError: (via requests' JSON decode error) if the body is not JSON.
    """
    resp = requests.get(
        "https://nominatim.openstreetmap.org/search",
        params={"q": place, "format": "json", "limit": 1},
        headers={"User-Agent": "iVoiceContext/1.0"},
        # Without a timeout a stalled connection blocks the calling worker forever.
        timeout=10,
    )
    # Report HTTP errors (e.g. rate limiting) as such, rather than as a JSON
    # decode failure on an error page.
    resp.raise_for_status()
    results = resp.json()
    if not results:
        return None
    top = results[0]
    return {"lat": float(top["lat"]), "lon": float(top["lon"])}
|
|
|
@lru_cache(maxsize=256)
def fetch_osm_cache(
    lat: float,
    lon: float,
    osm_filter: str,
    limit: int = 5,
    radius: int = 1000,
):
    """Return named OpenStreetMap features matching *osm_filter* near a point.

    Queries the Overpass API for nodes and ways matching *osm_filter* (an
    Overpass QL tag filter such as ``'["amenity"="restaurant"]'``) within
    *radius* metres of (*lat*, *lon*), requesting at most *limit* results.
    Elements without a ``name`` tag are dropped, so fewer than *limit* items
    may be returned.

    *osm_filter* is interpolated verbatim into the query, so only pass
    trusted, hard-coded filters.

    Returns:
        A list of ``{"name": str}`` dicts. The list is memoized by
        ``lru_cache`` and the same object is handed to every caller, so do not
        mutate it or its dicts in place.

    Raises:
        requests.RequestException: connection error, timeout, or non-2xx
            HTTP status.
        ValueError: (via requests' JSON decode error) if the body is not JSON.
    """
    payload = f"""
[out:json][timeout:25];
(
  node{osm_filter}(around:{radius},{lat},{lon});
  way{osm_filter}(around:{radius},{lat},{lon});
);
out center {limit};
"""
    resp = requests.post(
        "https://overpass-api.de/api/interpreter",
        data={"data": payload},
        # A little above the query's own [timeout:25] so the server can answer first.
        timeout=30,
    )
    resp.raise_for_status()
    elems = resp.json().get("elements", [])
    return [
        {"name": e["tags"]["name"]}
        for e in elems
        if e.get("tags", {}).get("name")
    ]
|
|
|
@lru_cache(maxsize=256)
def wiki_summary_cache(name: str) -> str:
    """Return a two-sentence Wikipedia summary for *name*.

    Best-effort: if the lookup raises any ``Exception``, the placeholder
    ``"No summary available."`` is returned instead.
    NOTE(review): the placeholder is memoized like any other result, so a
    transient network failure sticks until the cache entry is evicted.
    """
    try:
        return wikipedia.summary(name, sentences=2)
    except Exception:
        # Catch Exception, not everything: KeyboardInterrupt and SystemExit
        # must still be able to stop the process.
        return "No summary available."
|
|
|
|
|
|
|
|
|
def process_entity(ent) -> dict:
    """Attach external context to one aggregated NER entity.

    Entities labelled ``LOC`` are geocoded and get nearby restaurants and
    attractions from OpenStreetMap; if geocoding finds nothing, an ``error``
    field is returned instead. Every other label gets a Wikipedia summary.
    The result always carries ``text``, ``label`` and ``type``.
    """
    word, label = ent["word"], ent["entity_group"]
    base = {"text": word, "label": label}

    if label != "LOC":
        return {**base, "type": "wiki", "summary": wiki_summary_cache(word)}

    geo = geocode_cache(word)
    if not geo:
        return {**base, "type": "location", "error": "could not geocode"}

    lat, lon = geo["lat"], geo["lon"]
    nearby_food = fetch_osm_cache(lat, lon, '["amenity"="restaurant"]')
    nearby_sights = fetch_osm_cache(lat, lon, '["tourism"="attraction"]')
    return {
        **base,
        "type": "location",
        "geo": geo,
        "restaurants": nearby_food,
        "attractions": nearby_sights,
    }
|
|
|
|
|
|
|
|
|
def _normalize_lang_code(code: str) -> str:
    """Return *code* stripped and lower-cased, or "eng" when it is blank/None."""
    return (code or "").strip().lower() or "eng"


def _dedupe_entities_by_word(ents) -> list:
    """Keep the first entity for each distinct ``word``, preserving NER order."""
    seen = set()
    unique = []
    for ent in ents:
        word = ent["word"]
        if word in seen:
            continue
        seen.add(word)
        unique.append(ent)
    return unique


def _translate_entity_fields(entities: list, output_lang: str) -> None:
    """Translate wiki summaries and OSM place names in *entities* in place, from English into *output_lang*."""
    texts = []
    targets = []  # (dict, key) that receives the translation of texts[i]
    for ent in entities:
        if ent["type"] == "wiki":
            texts.append(ent["summary"])
            targets.append((ent, "summary"))
        elif ent["type"] == "location":
            for field in ("restaurants", "attractions"):
                # A location that failed to geocode carries an "error" key
                # and no place lists at all.
                if field not in ent:
                    continue
                # These dicts come from the lru_cached fetch_osm_cache. Copy
                # them so translating this response does not rewrite the
                # cached entries seen by every later request.
                places = [dict(p) for p in ent[field]]
                ent[field] = places
                for place in places:
                    texts.append(place["name"])
                    targets.append((place, "name"))

    if not texts:
        return

    translated = translate_m4t_batch(texts, "eng", output_lang)
    for (holder, key), txt in zip(targets, translated):
        holder[key] = txt


def get_context(
    text: str,
    source_lang: str,
    output_lang: str,
    auto_detect: bool
):
    """Detect named entities in *text* and attach related context to each.

    Steps:
    1. Translate the input into English, unless it is already English and
       auto-detect is off.
    2. Run NER and keep one entity per distinct word.
    3. Enrich each entity via process_entity, in a thread pool.
    4. If *output_lang* is not English, translate summaries and place names.

    Args:
        text: free-form input. Blank input yields no entities.
        source_lang: ISO 639-3 code of *text*. Surrounding whitespace and case
            are normalized, and blank falls back to "eng". Not passed to the
            translator when *auto_detect* is true.
        output_lang: ISO 639-3 code for the returned text, normalized the
            same way.
        auto_detect: translate without giving the processor a source language.

    Returns:
        ``{"entities": [...]}``, with one dict per distinct entity word.
    """
    if not text or not text.strip():
        return {"entities": []}

    source_lang = _normalize_lang_code(source_lang)
    output_lang = _normalize_lang_code(output_lang)

    # The NER pipeline is run on English text.
    if auto_detect or source_lang != "eng":
        en = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
    else:
        en = text

    unique_ents = _dedupe_entities_by_word(ner(en))

    # process_entity blocks on HTTP calls, so threads let the lookups overlap.
    # NOTE(review): Nominatim's usage policy limits request rates; 8
    # concurrent geocodes may exceed it. Confirm this is acceptable.
    with ThreadPoolExecutor(max_workers=8) as exe:
        entities = list(exe.map(process_entity, unique_ents))

    if output_lang != "eng":
        _translate_entity_fields(entities, output_lang)

    return {"entities": entities}
|
|
|
|
|
|
|
|
|
# Gradio UI wiring get_context's four inputs to a JSON output. The language
# boxes default to "eng" because an empty code is not a valid SeamlessM4T
# language. .queue() enables Gradio's request queue.
iface = gr.Interface(
    fn=get_context,
    inputs=[
        gr.Textbox(lines=3, placeholder="Enter text…"),
        gr.Textbox(label="Source Language (ISO 639-3)", value="eng"),
        gr.Textbox(label="Target Language (ISO 639-3)", value="eng"),
        gr.Checkbox(label="Auto-detect source language"),
    ],
    outputs="json",
    title="iVoice Context-Aware",
    description="Returns only the detected entities and their related info."
).queue()
|
|
|
if __name__ == "__main__":
    # Listen on all interfaces. The PORT environment variable overrides the
    # default 7860 so a hosting platform can choose the port.
    # NOTE(review): share=True also opens a public Gradio share link;
    # confirm that exposure is intended.
    port = int(os.environ.get("PORT", 7860))
    iface.launch(server_name="0.0.0.0", server_port=port, share=True)
|
|