borgo9 commited on
Commit
686b6b9
Β·
verified Β·
1 Parent(s): f1b4f12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -112
app.py CHANGED
@@ -1,125 +1,77 @@
1
- import os
2
  import json
3
- import numpy as np
4
  import gradio as gr
5
- import faiss
6
  from sentence_transformers import SentenceTransformer
7
 
8
- # === CONFIGURATION ===
9
- # Define paths for assets (corpus + embeddings + optional FAISS index)
10
- BASE_DIR = os.path.dirname(__file__)
11
- DATA_DIR = os.path.join(BASE_DIR, "assets")
12
-
13
- CORPUS_PATH = os.path.join(DATA_DIR, "corpus.json") # Text data (title + article)
14
- EMB_PATH_32 = os.path.join(DATA_DIR, "embeddings_fp32.npy") # Embeddings in float32 (preferred)
15
- EMB_PATH_16 = os.path.join(DATA_DIR, "embeddings_fp16.npy") # Embeddings in float16 (smaller, needs casting)
16
- FAISS_INDEX_PATH = os.path.join(DATA_DIR, "faiss_main.index")# Precomputed FAISS index (optional, saves time)
17
-
18
- # Allow searching with different embedding sizes (for speed/accuracy tradeoff)
19
- DIM_OPTIONS = [768, 512, 256, 128]
20
- DEFAULT_DIM = 768
21
-
22
- # === LOAD CORPUS ===
23
- # Expecting corpus.json to be a list of dicts: [{"title": "...", "text": "..."}, ...]
24
  with open(CORPUS_PATH, "r", encoding="utf-8") as f:
25
- ARTICLES = json.load(f)
26
-
27
- # === LOAD EMBEDDINGS ===
28
- # Choose FP32 if available; otherwise load FP16 and convert to FP32 (required by FAISS)
29
- if os.path.exists(EMB_PATH_32):
30
- VECTORS = np.load(EMB_PATH_32).astype("float32", copy=False)
31
- elif os.path.exists(EMB_PATH_16):
32
- VECTORS = np.load(EMB_PATH_16).astype("float32")
33
- else:
34
- raise RuntimeError("No embeddings found in /assets")
35
-
36
- # Sanity check: Ensure number of embeddings matches number of articles
37
- assert VECTORS.shape[0] == len(ARTICLES), "Embedding and corpus length mismatch"
38
-
39
- EMB_SIZE = VECTORS.shape[1] # e.g. 768 for MiniLM or Gemma
40
-
41
- # === LOAD OR INITIALIZE FAISS INDEX ===
42
- # FAISS allows ultra-fast vector search using approximate or exact nearest neighbor methods.
43
- if os.path.exists(FAISS_INDEX_PATH):
44
- base_index = faiss.read_index(FAISS_INDEX_PATH) # Load prebuilt index for speed
45
- else:
46
- # Normalizing vectors turns inner product (dot product) into cosine similarity
47
- faiss.normalize_L2(VECTORS)
48
- base_index = faiss.IndexFlatIP(EMB_SIZE) # Flat index = brute force but optimized C++
49
- base_index.add(VECTORS)
50
-
51
- # === BUILD VARIANT INDEXES WITH REDUCED DIMENSIONS (optional speed boost) ===
52
- class DimensionalIndexBank:
53
  """
54
- Manages multiple FAISS indexes at different embedding dimensions.
55
- Useful for experimenting with search speed vs accuracy.
56
  """
57
- def __init__(self, full_matrix):
58
- self.indices = {}
59
- for dim in DIM_OPTIONS:
60
- # If we're using full dimension and prebuilt index exists, reuse it
61
- if dim == EMB_SIZE and os.path.exists(FAISS_INDEX_PATH):
62
- self.indices[dim] = base_index
63
- else:
64
- # Slice to first 'dim' components (Matryoshka-style compression)
65
- cut = full_matrix[:, :dim].astype("float32", copy=False)
66
- faiss.normalize_L2(cut)
67
- idx = faiss.IndexFlatIP(dim)
68
- idx.add(cut)
69
- self.indices[dim] = idx
70
 
71
- def query(self, vector, top_k, dim):
72
- """
73
- Search the appropriate FAISS index based on embedding dimension.
74
- """
75
- v = vector[:dim].reshape(1, -1).astype("float32", copy=False)
76
- faiss.normalize_L2(v)
77
- return self.indices[dim].search(v, top_k)
78
 
79
- searcher = DimensionalIndexBank(VECTORS)
80
-
81
- # === LOAD SENTENCE TRANSFORMER ===
82
- # This is used *only for encoding user queries*. Document embeddings are precomputed.
83
- encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
84
-
85
- # === HELPERS ===
86
- def truncate(text, limit=300):
87
- """Shorten text to a preview/snippet form."""
88
- return text[:limit] + "…" if len(text) > limit else text
89
-
90
- def run_search(query, k, dim):
91
- """Core search function used by Gradio."""
92
- if not query.strip():
93
- return []
94
- q_vec = encoder.encode(query, normalize_embeddings=True)
95
- scores, ids = searcher.query(q_vec, k, dim)
96
  results = []
97
- for score, idx in zip(scores[0], ids[0]):
98
- article = ARTICLES[idx]
99
- results.append([f"{score:.4f}", article["title"], truncate(article["text"])])
100
- return results
101
-
102
- # === BUILD GRADIO APP ===
103
- def build_app():
104
- with gr.Blocks(title="Semantic Search (FAISS + Transformers)") as demo:
105
- gr.Markdown("## πŸ” Fast Semantic Search Over Wikipedia")
106
 
107
- query = gr.Textbox(label="Search Input", value="Who discovered penicillin?")
108
- cols = gr.Row()
109
- with cols:
110
- k_slider = gr.Slider(1, 20, step=1, value=5, label="Results to Show")
111
- dim_choice = gr.Dropdown([str(d) for d in DIM_OPTIONS], value=str(DEFAULT_DIM), label="Embedding Size")
112
 
113
- btn = gr.Button("Search")
114
- output = gr.Dataframe(headers=["Score", "Title", "Snippet"], wrap=True)
115
-
116
- btn.click(lambda q, k, d: run_search(q, int(k), int(d)), [query, k_slider, dim_choice], output)
117
-
118
- return demo
119
-
120
- app = build_app()
121
-
122
- # === MAIN ENTRY ===
 
 
 
 
 
 
 
 
 
 
 
 
123
  if __name__ == "__main__":
124
- # When hosted on Hugging Face, Gradio auto-manages this
125
- app.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ import faiss
2
  import json
 
3
  import gradio as gr
 
4
  from sentence_transformers import SentenceTransformer
5
 
6
+ # ----------------------------------------------------
7
+ # 1. Load FAISS Index (Prebuilt)
8
+ # ----------------------------------------------------
9
+ # This index already contains all document embeddings.
10
+ INDEX_PATH = "assets/faiss_main.index"
11
+ index = faiss.read_index(INDEX_PATH)
12
+
13
+ # ----------------------------------------------------
14
+ # 2. Load Text Corpus (ID β†’ Content Mapping)
15
+ # ----------------------------------------------------
16
+ # corpus.json should be a list where each position aligns
17
+ # with the embedding index used in FAISS.
18
+ # Example: ["Text 1", "Text 2", "Text 3", ...]
19
+ CORPUS_PATH = "assets/corpus.json"
 
 
20
  with open(CORPUS_PATH, "r", encoding="utf-8") as f:
21
+ CORPUS = json.load(f)
22
+
23
+ # ----------------------------------------------------
24
+ # 3. Load Sentence Transformer Model for Query Encoding
25
+ # ----------------------------------------------------
26
+ # ⚠️ IMPORTANT: Must be the SAME model that was used
27
+ # to generate the original corpus embeddings stored in FAISS!
28
+ EMBEDDING_MODEL_NAME = "google/embeddinggemma-300m" # ← Change if you used another model
29
+ model = SentenceTransformer(EMBEDDING_MODEL_NAME)
30
+
31
+ # ----------------------------------------------------
32
+ # 4. Search Function: Query β†’ Top K Results
33
+ # ----------------------------------------------------
34
+ def search_faiss(query, top_k=5):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  """
36
+ Takes a text query, embeds it, searches FAISS index,
37
+ and returns top-k most similar corpus entries.
38
  """
39
+ # Encode query text into a vector
40
+ query_embedding = model.encode([query], convert_to_numpy=True)
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ # Perform similarity search (returns distances and indices)
43
+ distances, indices = index.search(query_embedding, top_k)
 
 
 
 
 
44
 
45
+ # Collect matching corpus entries
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  results = []
47
+ for rank, idx in enumerate(indices[0]):
48
+ results.append(
49
+ f"#{rank+1} | Score: {distances[0][rank]:.4f}\n{CORPUS[idx]}"
50
+ )
 
 
 
 
 
51
 
52
+ return "\n\n".join(results)
 
 
 
 
53
 
54
+ # ----------------------------------------------------
55
+ # 5. Build Gradio Interface
56
+ # ----------------------------------------------------
57
+ def gradio_search(query, top_k):
58
+ if not query.strip():
59
+ return "Please enter a search query."
60
+ return search_faiss(query, top_k)
61
+
62
+ demo = gr.Interface(
63
+ fn=gradio_search,
64
+ inputs=[
65
+ gr.Textbox(label="Enter your search query"),
66
+ gr.Slider(1, 20, value=5, step=1, label="Number of results"),
67
+ ],
68
+ outputs=gr.Textbox(label="Search Results"),
69
+ title="FAISS Semantic Search",
70
+ description="A simple search engine powered by FAISS + Sentence Transformers."
71
+ )
72
+
73
+ # ----------------------------------------------------
74
+ # 6. Run App (for local testing or Spaces deployment)
75
+ # ----------------------------------------------------
76
  if __name__ == "__main__":
77
+ demo.launch()