# app.py
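#
# Context-aware translation demo: input text is translated with SeamlessM4T,
# named entities are extracted from the English version, and each entity is
# enriched with OpenStreetMap / Wikipedia lookups before the result is
# returned as JSON through a Gradio interface.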

import os
import requests
import wikipedia
import gradio as gr
import torch

from functools import lru_cache
from concurrent.futures import ThreadPoolExecutor
from typing import List

from transformers import (
    SeamlessM4TTokenizer,
    SeamlessM4TProcessor,
    SeamlessM4TForTextToText,
    pipeline as hf_pipeline
)

# ── 1) Model setup ────────────────────────────────────────────────────────────

MODEL = "facebook/hf-seamless-m4t-medium"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = SeamlessM4TTokenizer.from_pretrained(MODEL, use_fast=False)
processor = SeamlessM4TProcessor.from_pretrained(MODEL, tokenizer=tokenizer)

m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL).to(device)
if device == "cuda":
    m4t_model = m4t_model.half()   # FP16 for faster inference on GPU
m4t_model.eval()

def translate_m4t(text: str, src_iso3: str, tgt_iso3: str, auto_detect=False) -> str:
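    """Translate one string between ISO 639-3 language codes with SeamlessM4T.

    Note: when auto_detect is True, src_lang is simply left unset and the
    processor falls back to its default source language; SeamlessM4T itself
    does not perform language identification.
    """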
    src = None if auto_detect else src_iso3
    inputs = processor(text=text, src_lang=src, return_tensors="pt").to(device)
    tokens = m4t_model.generate(**inputs, tgt_lang=tgt_iso3)
    return processor.decode(tokens[0].tolist(), skip_special_tokens=True)

def translate_m4t_batch(
    texts: List[str], src_iso3: str, tgt_iso3: str, auto_detect=False
) -> List[str]:
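    """Translate a list of strings in a single padded generate() call.

    Generation is capped (max_new_tokens=60, greedy num_beams=1) to keep
    latency low for short strings such as place names and summaries.
    """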
    src = None if auto_detect else src_iso3
    inputs = processor(
        text=texts, src_lang=src, return_tensors="pt", padding=True
    ).to(device)
    tokens = m4t_model.generate(
        **inputs,
        tgt_lang=tgt_iso3,
        max_new_tokens=60,
        num_beams=1
    )
    return processor.batch_decode(tokens, skip_special_tokens=True)


# ── 2) NER pipeline (updated for deprecation) ────────────────────────────────

ner = hf_pipeline(
    "ner",
    model="dslim/bert-base-NER-uncased",
    aggregation_strategy="simple"
)
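# aggregation_strategy="simple" merges word-piece tokens into whole entity
# spans, replacing the deprecated grouped_entities=True flag.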


# ── 3) CACHING helpers ──────────────────────────────────────────────────────

@lru_cache(maxsize=256)
def geocode_cache(place: str):
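    """Resolve a place name to {"lat", "lon"} via Nominatim; results are cached."""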
    r = requests.get(
        "https://nominatim.openstreetmap.org/search",
        params={"q": place, "format": "json", "limit": 1},
        headers={"User-Agent": "iVoiceContext/1.0"}
    ).json()
    if not r:
        return None
    return {"lat": float(r[0]["lat"]), "lon": float(r[0]["lon"])}

@lru_cache(maxsize=256)
def fetch_osm_cache(lat: float, lon: float, osm_filter: str, limit: int = 5):
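    """Query the Overpass API for named nodes/ways matching osm_filter within
    roughly 1 km of (lat, lon), returning at most `limit` names.
    """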
    payload = f"""
      [out:json][timeout:25];
      (
        node{osm_filter}(around:1000,{lat},{lon});
        way{osm_filter}(around:1000,{lat},{lon});
      );
      out center {limit};
    """
    resp = requests.post(
        "https://overpass-api.de/api/interpreter",
        data={"data": payload}
    )
    elems = resp.json().get("elements", [])
    return [
        {"name": e["tags"]["name"]}
        for e in elems
        if e.get("tags", {}).get("name")
    ]

@lru_cache(maxsize=256)
def wiki_summary_cache(name: str) -> str:
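    """Return a two-sentence Wikipedia summary, or a fallback message on error."""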
    try:
        return wikipedia.summary(name, sentences=2)
    except Exception:
        return "No summary available."


# ── 4) Per-entity worker ────────────────────────────────────────────────────

def process_entity(ent) -> dict:
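    """Enrich one NER entity.

    Locations (LOC) are geocoded and paired with nearby restaurants and
    attractions from OSM; all other entity types get a Wikipedia summary.
    """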
    w = ent["word"]
    lbl = ent["entity_group"]

    if lbl == "LOC":
        geo = geocode_cache(w)
        if not geo:
            return {
                "text": w,
                "label": lbl,
                "type": "location",
                "error": "could not geocode"
            }

        restaurants = fetch_osm_cache(geo["lat"], geo["lon"], '["amenity"="restaurant"]')
        attractions = fetch_osm_cache(geo["lat"], geo["lon"], '["tourism"="attraction"]')

        return {
            "text": w,
            "label": lbl,
            "type": "location",
            "geo": geo,
            "restaurants": restaurants,
            "attractions": attractions
        }

    # PERSON / ORG / MISC → Wikipedia
    summary = wiki_summary_cache(w)
    return {"text": w, "label": lbl, "type": "wiki", "summary": summary}


# ── 5) Main function ────────────────────────────────────────────────────────

def get_context(
    text: str,
    source_lang: str,
    output_lang: str,
    auto_detect: bool
):
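    """End-to-end pipeline: translate the input to English if needed, run NER,
    enrich each unique entity in parallel, and batch-translate the textual
    fields into output_lang when it differs from English.
    """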
    # a) Ensure English for NER
    if auto_detect or source_lang != "eng":
        en = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
    else:
        en = text

    # b) Run NER + dedupe
    ner_out = ner(en)
    seen = set()
    unique_ents = []
    for ent in ner_out:
        w = ent["word"]
        if w in seen:
            continue
        seen.add(w)
        unique_ents.append(ent)

    # c) Parallel I/O
    entities = []
    with ThreadPoolExecutor(max_workers=8) as exe:
        futures = [exe.submit(process_entity, ent) for ent in unique_ents]
        for fut in futures:
            entities.append(fut.result())

    # d) Batch-translate non-English fields
    if output_lang != "eng":
        to_translate = []
        translations_info = []

        for i, e in enumerate(entities):
            if e["type"] == "wiki":
                translations_info.append(("summary", i))
                to_translate.append(e["summary"])
            elif e["type"] == "location":
                for j, r in enumerate(e["restaurants"]):
                    translations_info.append(("restaurant", i, j))
                    to_translate.append(r["name"])
                for j, a in enumerate(e["attractions"]):
                    translations_info.append(("attraction", i, j))
                    to_translate.append(a["name"])

        # Guard against an empty batch (no translatable fields found).
        translated = translate_m4t_batch(to_translate, "eng", output_lang) if to_translate else []

        for txt, info in zip(translated, translations_info):
            kind = info[0]
            if kind == "summary":
                _, ei = info
                entities[ei]["summary"] = txt
            elif kind == "restaurant":
                _, ei, ri = info
                entities[ei]["restaurants"][ri]["name"] = txt
            elif kind == "attraction":
                _, ei, ai = info
                entities[ei]["attractions"][ai]["name"] = txt

    return {"entities": entities}


# ── 6) Gradio interface ─────────────────────────────────────────────────────

iface = gr.Interface(
    fn=get_context,
    inputs=[
        gr.Textbox(lines=3, placeholder="Enter text…"),
        gr.Textbox(label="Source Language (ISO 639-3)"),
        gr.Textbox(label="Target Language (ISO 639-3)"),
        gr.Checkbox(label="Auto-detect source language")
    ],
    outputs="json",
    title="iVoice Context-Aware",
    description="Returns only the detected entities and their related info."
).queue()

if __name__ == "__main__":
    iface.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
        share=True
    )