Nuclio_test / app.py
borgo9's picture
Update app.py
89902c3 verified
raw
history blame
2.69 kB
import os, json
from typing import List
import numpy as np
import gradio as gr
import faiss
from sentence_transformers import SentenceTransformer
HF_TOKEN = os.getenv("HF_TOKEN")
# ---------- Paths ----------
APP_DIR = os.path.dirname(__file__)
ASSETS_DIR = os.path.join(APP_DIR, "assets")
CORPUS_JSON = os.path.join(ASSETS_DIR, "corpus.json")
FAISS_MAIN = os.path.join(ASSETS_DIR, "faiss_ip_768.index")
# ---------- Load corpus ----------
with open(CORPUS_JSON, "r", encoding="utf-8") as f:
corpus = json.load(f) # [{"title", "text"}, ...]
# ---------- Load FAISS index ----------
if not os.path.exists(FAISS_MAIN):
raise FileNotFoundError(f"Missing FAISS index at {FAISS_MAIN}")
index = faiss.read_index(FAISS_MAIN)
# Infer dimension from index
EMB_DIM = index.d
# ---------- Model ----------
model = SentenceTransformer("google/embeddinggemma-300m", token=HF_TOKEN)
# ---------- Search ----------
def do_search(query: str, top_k: int = 5) -> List[List[str]]:
if not query.strip():
return []
q_emb = model.encode_query(query, normalize_embeddings=True, convert_to_numpy=True).astype("float32")
scores, idxs = index.search(q_emb[None, :], top_k)
rows = []
for score, i in zip(scores[0], idxs[0]):
if i == -1: continue
item = corpus[i]
snippet = item["text"][:380] + ("…" if len(item["text"]) > 380 else "")
rows.append([f"{score:.4f}", item["title"], snippet])
return rows
# ---------- Similarity ----------
def do_similarity(text_a: str, text_b: str) -> float:
a = model.encode_document([text_a], normalize_embeddings=True, convert_to_numpy=True)[0]
b = model.encode_document([text_b], normalize_embeddings=True, convert_to_numpy=True)[0]
return float(np.dot(a, b))
# ---------- UI ----------
with gr.Blocks(title="EmbeddingGemma FAISS Search") as demo:
gr.Markdown("# Simple FAISS Search (No Matryoshka)")
with gr.Tabs():
with gr.TabItem("Search"):
q = gr.Textbox(label="Query")
topk = gr.Slider(1, 20, value=5, step=1, label="Top-K")
run = gr.Button("Search")
out = gr.Dataframe(headers=["score", "title", "snippet"], wrap=True)
run.click(lambda query, k: do_search(query, int(k)), [q, topk], out)
with gr.TabItem("Similarity"):
a = gr.Textbox(lines=4, label="Text A")
b = gr.Textbox(lines=4, label="Text B")
sim_btn = gr.Button("Compute")
sim_out = gr.Number(label="Cosine similarity")
sim_btn.click(lambda x, y: do_similarity(x, y), [a, b], sim_out)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)