Nuclio_test / app.py
borgo9's picture
Update app.py
4ff412d verified
raw
history blame
5.17 kB
# import faiss
# import json
# import gradio as gr
# from sentence_transformers import SentenceTransformer
# # ----------------------------------------------------
# # 1. Load FAISS Index (Prebuilt)
# # ----------------------------------------------------
# # This index already contains all document embeddings.
# INDEX_PATH = "assets/faiss_ip_768.index"
# index = faiss.read_index(INDEX_PATH)
# # ----------------------------------------------------
# # 2. Load Text Corpus (ID → Content Mapping)
# # ----------------------------------------------------
# # corpus.json should be a list where each position aligns
# # with the embedding index used in FAISS.
# # Example: ["Text 1", "Text 2", "Text 3", ...]
# CORPUS_PATH = "assets/corpus.json"
# with open(CORPUS_PATH, "r", encoding="utf-8") as f:
# CORPUS = json.load(f)
# # ----------------------------------------------------
# # 3. Load Sentence Transformer Model for Query Encoding
# # ----------------------------------------------------
# # ⚠️ IMPORTANT: Must be the SAME model that was used to build the FAISS index above.
# HF_TOKEN = os.getenv("HF_TOKEN")
# EMBEDDING_MODEL_NAME = "google/embeddinggemma-300m" # ← Change if you used another model
# model = SentenceTransformer(EMBEDDING_MODEL_NAME, HF_TOKEN)
# # ----------------------------------------------------
# # 4. Search Function: Query → Top K Results
# # ----------------------------------------------------
# def search_faiss(query, top_k=5):
# """
# Takes a text query, embeds it, searches FAISS index,
# and returns top-k most similar corpus entries.
# """
# # Encode query text into a vector
# query_embedding = model.encode([query], convert_to_numpy=True)
# # Perform similarity search (returns distances and indices)
# distances, indices = index.search(query_embedding, top_k)
# # Collect matching corpus entries
# results = []
# for rank, idx in enumerate(indices[0]):
# results.append(
# f"#{rank+1} | Score: {distances[0][rank]:.4f}\n{CORPUS[idx]}"
# )
# return "\n\n".join(results)
# # ----------------------------------------------------
# # 5. Build Gradio Interface
# # ----------------------------------------------------
# def gradio_search(query, top_k):
# if not query.strip():
# return "Please enter a search query."
# return search_faiss(query, top_k)
# demo = gr.Interface(
# fn=gradio_search,
# inputs=[
# gr.Textbox(label="Enter your search query"),
# gr.Slider(1, 20, value=5, step=1, label="Number of results"),
# ],
# outputs=gr.Textbox(label="Search Results"),
# title="FAISS Semantic Search",
# description="A simple search engine powered by FAISS + Sentence Transformers."
# )
# # ----------------------------------------------------
# # 6. Run App (for local testing or Spaces deployment)
# # ----------------------------------------------------
# if __name__ == "__main__":
# demo.launch()
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
# ----------------------------------------------------
# 1. HF_TOKEN (optional, for private models)
# ----------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")  # Retrieved from Hugging Face Secrets; None if unset
# ----------------------------------------------------
# 2. Load EmbeddingGemma-300M model
# ----------------------------------------------------
MODEL_NAME = "google/embeddinggemma-300m"
# Load tokenizer and model.
# `token=` is the current name of the authentication keyword for
# from_pretrained; the older `use_auth_token=` is deprecated.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
model = AutoModel.from_pretrained(MODEL_NAME, token=HF_TOKEN)
# Function to encode text into embeddings
def encode(texts):
    """
    Encode a list of texts into vector embeddings using EmbeddingGemma-300M.

    Token vectors from the model's last hidden layer are mean-pooled over
    real tokens only (positions where the tokenizer's ``attention_mask`` is
    1). Padding added to equalize lengths inside a batch is excluded, so a
    text's embedding does not depend on which other texts share its batch.

    Args:
        texts: list of strings to embed.

    Returns:
        numpy.ndarray of dtype float32, one row per input text.

    NOTE(review): the reference EmbeddingGemma pipeline (sentence-transformers)
    may apply extra layers after pooling (e.g. projection/normalization);
    vectors from this function may not be interchangeable with ones produced
    that way — confirm before comparing against a prebuilt index.
    """
    # Keep the input tensors on whatever device the model lives on.
    inputs = tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True
    ).to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Pool in float32: keeps token counts exact even with half-precision
    # weights, and NumPy cannot convert bfloat16 tensors directly.
    hidden = outputs.last_hidden_state.float()
    # Masked mean pooling: zero out padding positions, then divide by the
    # number of real tokens rather than the padded sequence length.
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # avoid 0/0 on an all-padding row
    embeddings = summed / counts
    # Convert to numpy float32
    return embeddings.cpu().numpy().astype(np.float32)
# ----------------------------------------------------
# 3. Gradio test function
# ----------------------------------------------------
def test_encode(text):
    """Gradio callback: embed one string and report the result's shape.

    Serves as a smoke test that the tokenizer and model loaded and can
    produce an embedding. Returns a human-readable message containing the
    shape of the array returned by ``encode``.
    """
    vectors = encode([text])
    shape = vectors.shape
    return f"Embedding shape: {shape}"
# ----------------------------------------------------
# 4. Build Gradio Interface
# ----------------------------------------------------
# Single text-in / text-out UI wired to the embedding smoke test.
demo = gr.Interface(
    test_encode,
    gr.Textbox(label="Type some text"),
    gr.Textbox(label="Embedding info"),
    title="Test EmbeddingGemma-300M",
    description="This Space tests whether the EmbeddingGemma-300M model can generate embeddings.",
)
# ----------------------------------------------------
# 5. Launch App
# ----------------------------------------------------
# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()