Spaces:
Sleeping
Sleeping
| import os, json | |
| from typing import List | |
| import numpy as np | |
| import gradio as gr | |
| import faiss | |
| from sentence_transformers import SentenceTransformer | |
# Access token read from the environment (None when unset); it is handed to
# SentenceTransformer below.
HF_TOKEN = os.environ.get("HF_TOKEN")

# ---------- Paths ----------
# Every asset is resolved relative to the directory containing this file.
APP_DIR = os.path.dirname(__file__)
ASSETS_DIR = os.path.join(APP_DIR, "assets")
CORPUS_JSON = os.path.join(ASSETS_DIR, "corpus.json")
FAISS_MAIN = os.path.join(ASSETS_DIR, "faiss_ip_768.index")

# ---------- Load corpus ----------
# Expected layout: [{"title", "text"}, ...]; do_search indexes into this list
# with the ids returned by FAISS.
with open(CORPUS_JSON, mode="r", encoding="utf-8") as f:
    corpus = json.loads(f.read())

# ---------- Load FAISS index ----------
if os.path.exists(FAISS_MAIN):
    index = faiss.read_index(FAISS_MAIN)
else:
    raise FileNotFoundError(f"Missing FAISS index at {FAISS_MAIN}")

# Embedding dimension, taken from the loaded index rather than hard-coded.
EMB_DIM = index.d

# ---------- Model ----------
model = SentenceTransformer(
    "google/embeddinggemma-300m",
    token=HF_TOKEN,
)
| # ---------- Search ---------- | |
def do_search(query: str, top_k: int = 5) -> List[List[str]]:
    """Return up to ``top_k`` corpus entries that best match ``query``.

    The query is embedded with the module-level ``model`` and looked up in
    the module-level FAISS ``index``; every hit is resolved against
    ``corpus``.

    Args:
        query: Free-text search string. ``None``, empty, or whitespace-only
            input yields no results.
        top_k: Maximum number of hits to return. It is coerced to ``int``;
            values below 1 yield no results instead of being sent to FAISS.

    Returns:
        One ``[score, title, snippet]`` row per hit, in the order FAISS
        returned them. ``score`` is formatted with four decimals and
        ``snippet`` is the first 380 characters of the entry text, followed
        by an ellipsis when the text was cut.
    """
    max_snippet_chars = 380

    # Reject unusable input before the model or the index is touched.
    if not query or not query.strip():
        return []
    top_k = int(top_k)
    if top_k < 1:
        return []

    q_emb = model.encode_query(
        query, normalize_embeddings=True, convert_to_numpy=True
    ).astype("float32")
    # index.search takes a 2-D batch of queries, so add a leading axis.
    scores, idxs = index.search(q_emb[None, :], top_k)

    rows = []
    for score, i in zip(scores[0], idxs[0]):
        # Negative ids (-1) mark result slots FAISS could not fill.
        if i < 0:
            continue
        item = corpus[int(i)]
        text = item["text"]
        ellipsis = "…" if len(text) > max_snippet_chars else ""
        rows.append([f"{score:.4f}", item["title"], text[:max_snippet_chars] + ellipsis])
    return rows
| # ---------- Similarity ---------- | |
def do_similarity(text_a: str, text_b: str) -> float:
    """Score how similar two texts are using the module-level ``model``.

    Each text is embedded separately with ``encode_document`` and
    ``normalize_embeddings=True``; the result is the dot product of the two
    embeddings (shown in the UI as cosine similarity).

    Args:
        text_a: First text.
        text_b: Second text.

    Returns:
        The dot product of the two embeddings as a plain Python ``float``.
    """
    # One encode call per text, first text_a and then text_b.
    vec_a, vec_b = (
        model.encode_document(
            [text], normalize_embeddings=True, convert_to_numpy=True
        )[0]
        for text in (text_a, text_b)
    )
    return float(np.dot(vec_a, vec_b))
| # ---------- UI ---------- | |
with gr.Blocks(title="EmbeddingGemma FAISS Search") as demo:
    gr.Markdown("# Simple FAISS Search (No Matryoshka)")

    with gr.Tabs():
        # Tab 1: run a free-text query through do_search.
        with gr.TabItem("Search"):
            query_box = gr.Textbox(label="Query")
            top_k_slider = gr.Slider(
                minimum=1, maximum=20, value=5, step=1, label="Top-K"
            )
            search_btn = gr.Button("Search")
            results_table = gr.Dataframe(
                headers=["score", "title", "snippet"], wrap=True
            )
            # The slider value is cast to int before it reaches do_search.
            search_btn.click(
                fn=lambda query, k: do_search(query, int(k)),
                inputs=[query_box, top_k_slider],
                outputs=results_table,
            )

        # Tab 2: compare two free-form texts through do_similarity.
        with gr.TabItem("Similarity"):
            text_a_box = gr.Textbox(lines=4, label="Text A")
            text_b_box = gr.Textbox(lines=4, label="Text B")
            sim_btn = gr.Button("Compute")
            sim_out = gr.Number(label="Cosine similarity")
            sim_btn.click(
                fn=lambda x, y: do_similarity(x, y),
                inputs=[text_a_box, text_b_box],
                outputs=sim_out,
            )

# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)