bourdoiscatie commited on
Commit
28467cd
·
verified ·
1 Parent(s): ed9e1bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +302 -65
app.py CHANGED
@@ -2,89 +2,284 @@ import time
2
  import gradio as gr
3
  from datasets import load_dataset
4
  import pandas as pd
5
- from sentence_transformers import SentenceTransformer
6
  from sentence_transformers.quantization import quantize_embeddings
7
  import faiss
8
  from usearch.index import Index
 
 
 
 
9
 
10
- # Load titles and texts
11
- wikipedia_dataset = load_dataset("bourdoiscatie/wikipedia_fr_2022_250K", split="train", num_proc=4).select_columns(["title", "text", "wiki_id"])
12
 
 
 
 
 
 
 
13
  def add_link(example):
14
  example["title"] = '['+example["title"]+']('+'https://fr.wikipedia.org/wiki?curid='+str(example["wiki_id"])+')'
15
  return example
16
  wikipedia_dataset = wikipedia_dataset.map(add_link)
17
-
18
- # Load the int8 and binary indices. Int8 is loaded as a view to save memory, as we never actually perform search with it.
 
19
  int8_view = Index.restore("wikipedia_fr_2022_250K_int8_usearch.index", view=True)
20
  binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary("wikipedia_fr_2022_250K_ubinary_faiss.index")
21
  binary_ivf: faiss.IndexBinaryIVF = faiss.read_index_binary("wikipedia_fr_2022_250K_ubinary_ivf_faiss.index")
22
 
23
- # Load the SentenceTransformer model for embedding the queries
24
- model = SentenceTransformer("OrdalieTech/Solon-embeddings-large-0.1")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
 
27
- def search(query, top_k: int = 20, rescore_multiplier: int = 1, use_approx: bool = False):
28
- # 1. Embed the query as float32
 
 
 
 
29
  start_time = time.time()
30
- query_embedding = model.encode(query, prompt="query: ")
31
  embed_time = time.time() - start_time
32
 
33
- # 2. Quantize the query to ubinary
34
  start_time = time.time()
35
  query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary")
36
  quantize_time = time.time() - start_time
37
 
38
- # 3. Search the binary index (either exact or approximate)
39
  index = binary_ivf if use_approx else binary_index
40
  start_time = time.time()
41
- _scores, binary_ids = index.search(query_embedding_ubinary, top_k * rescore_multiplier)
42
  binary_ids = binary_ids[0]
43
- search_time = time.time() - start_time
44
 
45
- # 4. Load the corresponding int8 embeddings
46
  start_time = time.time()
47
  int8_embeddings = int8_view[binary_ids].astype(int)
48
  load_time = time.time() - start_time
49
 
50
- # 5. Rescore the top_k * rescore_multiplier using the float32 query embedding and the int8 document embeddings
51
  start_time = time.time()
52
  scores = query_embedding @ int8_embeddings.T
53
  rescore_time = time.time() - start_time
54
 
55
- # 6. Sort the scores and return the top_k
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  start_time = time.time()
57
- indices = scores.argsort()[::-1][:top_k]
58
- top_k_indices = binary_ids[indices]
59
- top_k_scores = scores[indices]
60
- top_k_titles, top_k_texts = zip(*[(wikipedia_dataset[idx]["title"], wikipedia_dataset[idx]["text"]) for idx in top_k_indices.tolist()])
61
- df = pd.DataFrame({"Score_paragraphe": [round(value, 2) for value in top_k_scores], "Titre": top_k_titles, "Texte": top_k_texts})
62
- score_sum = df.groupby('Titre')['Score_paragraphe'].sum().reset_index()
63
- df = pd.merge(df, score_sum, on='Titre', how='left')
64
- df.rename(columns={'Score_paragraphe_y': 'Score_article'}, inplace=True)
65
- df.rename(columns={'Score_paragraphe_x': 'Score_paragraphe'}, inplace=True)
66
- df = df[["Score_article", "Score_paragraphe", "Titre", "Texte"]]
67
- df = df.sort_values('Score_article', ascending=False)
68
- # df = df.groupby('Titre')[['Score', 'Texte']].agg({'Score': 'sum', 'Texte': '\n\n'.join}).reset_index().sort_values('Score', ascending=False)
 
 
 
 
 
 
 
 
 
 
69
  sort_time = time.time() - start_time
 
 
 
 
 
 
70
 
71
- return df, {
72
- "Temps pour enchâsser la requête ": f"{embed_time:.4f} s",
73
- "Temps pour la quantisation ": f"{quantize_time:.4f} s",
74
- "Temps pour effectuer la recherche ": f"{search_time:.4f} s",
75
- "Temps de chargement ": f"{load_time:.4f} s",
76
- "Temps de rescorage ": f"{rescore_time:.4f} s",
77
- "Temps pour trier les résustats ": f"{sort_time:.4f} s",
78
- "Temps total pour la recherche ": f"{quantize_time + search_time + load_time + rescore_time + sort_time:.4f} s",
79
- }
80
 
81
 
82
- with gr.Blocks(title="Requêter Wikipedia en temps réel 🔍") as demo:
83
 
84
  gr.Markdown(
85
  """
86
  ## Requêter Wikipedia en temps réel 🔍
87
-
88
  Ce démonstrateur permet de requêter un corpus composé des 250K paragraphes les plus consultés du Wikipédia francophone.
89
  Les résultats sont renvoyés en temps réel via un pipeline tournant sur un CPU 🚀
90
  Nous nous sommes grandement inspirés du Space [quantized-retrieval](https://huggingface.co/spaces/sentence-transformers/quantized-retrieval) conçu par [Tom Aarsen](https://huggingface.co/tomaarsen) 🤗
@@ -102,23 +297,11 @@ Il n'est pas exclus d'ensuite utiliser une version plus récente de Wikipedia (o
102
  </details>
103
 
104
  <details><summary>2. Détails le pipeline</summary>
105
- 1. La requête est enchâssée en float32 à l'aide du modèle <a href="https://hf.co/OrdalieTech/Solon-embeddings-large-0.1">Solon-embeddings-large-0.1</a> d'Ordalie.
106
- 2. La requête est quantizée en binaire à l'aide de la fonction `quantize_embeddings` de la bibliothèque <a href="https://sbert.net/">SentenceTransformers</a>.
107
- 3. Un index binaire (250K <i>embeddings</i> binaires pesant 32MB de mémoire/espace disque) est requêté (en binaire si l'option approximative est sélectionnée, en int8 si l'option exacte est sélectionnée).
108
- 4. Les <i>n</i> textes demandés par l'utilisateur jugés les plus pertinents sont chargés à la volée à partir d'un index int8 sur disque (250K <i>embeddings</i> int8 ; 0 bytes de mémoire, 293MB d'espace disque).
109
- 5. Les <i>n</i> textes sont rescorés en utilisant la requête en float32 et les enchâssements en int8.
110
- 6. Les <i>n</i> premiers textes sont triés par score et affichés. Le "Score_paragraphe" correspond au score individuel de chaque paragraphe d'être pertinant vis-à-vis de la requête. Le "Score_article" correspond à la somme de tous les scores individuels des paragraphes issus d'un même article Wikipedia. L'objectif est alors de mettre en avant l'article source plutôt qu'un bout de texte le composant.
111
-
112
- Ce processus est conçu pour être rapide et efficace en termes de mémoire : l'index binaire étant suffisamment petit pour tenir dans la mémoire et l'index int8 étant chargé en tant que vue pour économiser de la mémoire.
113
- Au total, ce processus nécessite de conserver 1) le modèle en mémoire, 2) l'index binaire en mémoire et 3) l'index int8 sur le disque.
114
- Avec une dimension de 1024, nous avons besoin de `1024 / 8 * num_docs` octets pour l'index binaire et de `1024 * num_docs` octets pour l'index int8.
115
-
116
- C'est nettement moins cher que de faire le même processus avec des enchâssements en float32 qui nécessiterait `4 * 1024 * num_docs` octets de mémoire/espace disque pour l'index float32, soit 32x plus de mémoire et 4x plus d'espace disque.
117
- De plus, l'index binaire est beaucoup plus rapide (jusqu'à 32x) à rechercher que l'index float32, tandis que le rescorage est également extrêmement efficace.
118
- En conclusion, ce processus permet une recherche rapide, évolutive, peu coûteuse et efficace en termes de mémoire.
119
  </details>
120
  """
121
  )
 
122
  with gr.Row():
123
  with gr.Column(scale=75):
124
  query = gr.Textbox(
@@ -129,7 +312,7 @@ En conclusion, ce processus permet une recherche rapide, évolutive, peu coûteu
129
  use_approx = gr.Radio(
130
  choices=[("Exacte", False), ("Approximative", True)],
131
  value=True,
132
- label="Type de recherche",
133
  )
134
 
135
  with gr.Row():
@@ -139,8 +322,7 @@ En conclusion, ce processus permet une recherche rapide, évolutive, peu coûteu
139
  maximum=40,
140
  step=1,
141
  value=15,
142
- label="Nombre de documents à rechercher",
143
- info="Recherche effectué via un bi-encodeur binaire",
144
  )
145
  with gr.Column(scale=2):
146
  rescore_multiplier = gr.Slider(
@@ -149,17 +331,72 @@ En conclusion, ce processus permet une recherche rapide, évolutive, peu coûteu
149
  step=1,
150
  value=1,
151
  label="Coefficient de rescorage",
152
- info="Reranking via le coefficient",
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
- search_button = gr.Button(value="Search")
156
 
157
  output = gr.Dataframe(headers=["Score_article", "Score_paragraphe", "Titre", "Texte"], datatype="markdown")
158
- json = gr.JSON()
159
-
160
- query.submit(search, inputs=[query, top_k, rescore_multiplier, use_approx], outputs=[output, json])
161
- search_button.click(search, inputs=[query, top_k, rescore_multiplier, use_approx], outputs=[output, json])
162
- gr.Image("/file=catie(2).png", height=250,width=80, show_download_button=False)
163
 
 
 
 
164
  demo.queue()
165
- demo.launch(allowed_paths=["catie(2).png"])
 
2
  import gradio as gr
3
  from datasets import load_dataset
4
  import pandas as pd
5
+ from sentence_transformers import SentenceTransformer, SparseEncoder
6
  from sentence_transformers.quantization import quantize_embeddings
7
  import faiss
8
  from usearch.index import Index
9
+ import torch
10
+ import numpy as np
11
+ from collections import defaultdict
12
+ from scipy import stats
13
 
 
 
14
 
15
+
16
+ # Load titles, texts and pre-computed sparse embeddings
17
+ wikipedia_dataset = load_dataset("CATIE-AQ/wikipedia_fr_2022_250K", split="train", num_proc=4).select_columns(["title", "text", "wiki_id", "sparse_emb"])
18
+
19
+
20
+ # A function to make the titles of Wikipedia articles clickable in the final dataframe, so that you can consult the article on the website
21
  def add_link(example):
22
  example["title"] = '['+example["title"]+']('+'https://fr.wikipedia.org/wiki?curid='+str(example["wiki_id"])+')'
23
  return example
24
  wikipedia_dataset = wikipedia_dataset.map(add_link)
25
+
26
+
27
+ # Load the int8 and binary indices for dense retrieval
28
  int8_view = Index.restore("wikipedia_fr_2022_250K_int8_usearch.index", view=True)
29
  binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary("wikipedia_fr_2022_250K_ubinary_faiss.index")
30
  binary_ivf: faiss.IndexBinaryIVF = faiss.read_index_binary("wikipedia_fr_2022_250K_ubinary_ivf_faiss.index")
31
 
32
+
33
+ # Load models
34
+ dense_model = SentenceTransformer("OrdalieTech/Solon-embeddings-large-0.1")
35
+ sparse_model = SparseEncoder("CATIE-AQ/SPLADE_camembert-base_STS")
36
+
37
+
38
+ # Fusion methods
39
+ def reciprocal_rank_fusion(dense_results, sparse_results, k=20):
40
+ """
41
+ Perform Reciprocal Rank Fusion to combine dense and sparse retrieval results
42
+
43
+ Args:
44
+ dense_results: List of (doc_id, score) from dense retrieval
45
+ sparse_results: List of (doc_id, score) from sparse retrieval
46
+ k: RRF parameter (default 20)
47
+
48
+ Returns:
49
+ List of (doc_id, rrf_score) sorted by RRF score
50
+ """
51
+ rrf_scores = defaultdict(float)
52
+
53
+ # Add scores from dense retrieval
54
+ for rank, (doc_id, _) in enumerate(dense_results, 1):
55
+ rrf_scores[doc_id] += 1 / (k + rank)
56
+
57
+ # Add scores from sparse retrieval
58
+ for rank, (doc_id, _) in enumerate(sparse_results, 1):
59
+ rrf_scores[doc_id] += 1 / (k + rank)
60
+
61
+ # Sort by RRF score
62
+ sorted_results = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)
63
+ return sorted_results
64
+
65
+
66
+ def normalized_score_fusion(dense_results, sparse_results, dense_weight=0.5, sparse_weight=0.5):
67
+ """
68
+ Perform Normalized Score Fusion (NSF) with z-score normalization
69
+
70
+ Args:
71
+ dense_results: List of (doc_id, score) from dense retrieval
72
+ sparse_results: List of (doc_id, score) from sparse retrieval
73
+ dense_weight: Weight for dense scores (default 0.5)
74
+ sparse_weight: Weight for sparse scores (default 0.5)
75
+
76
+ Returns:
77
+ List of (doc_id, normalized_score) sorted by normalized score
78
+ """
79
+ # Extract scores for normalization
80
+ dense_scores = np.array([score for _, score in dense_results])
81
+ sparse_scores = np.array([score for _, score in sparse_results])
82
+
83
+ # Z-score normalization (handle edge cases)
84
+ if len(dense_scores) > 1 and np.std(dense_scores) > 1e-10:
85
+ dense_scores_norm = stats.zscore(dense_scores)
86
+ else:
87
+ dense_scores_norm = np.zeros_like(dense_scores)
88
+
89
+ if len(sparse_scores) > 1 and np.std(sparse_scores) > 1e-10:
90
+ sparse_scores_norm = stats.zscore(sparse_scores)
91
+ else:
92
+ sparse_scores_norm = np.zeros_like(sparse_scores)
93
+
94
+ # Create dictionaries for normalized scores
95
+ dense_norm_dict = {doc_id: score for (doc_id, _), score in zip(dense_results, dense_scores_norm)}
96
+ sparse_norm_dict = {doc_id: score for (doc_id, _), score in zip(sparse_results, sparse_scores_norm)}
97
+
98
+ # Combine all unique document IDs
99
+ all_doc_ids = set()
100
+ all_doc_ids.update(doc_id for doc_id, _ in dense_results)
101
+ all_doc_ids.update(doc_id for doc_id, _ in sparse_results)
102
+
103
+ # Calculate weighted normalized scores
104
+ nsf_scores = {}
105
+ for doc_id in all_doc_ids:
106
+ dense_norm_score = dense_norm_dict.get(doc_id, 0.0)
107
+ sparse_norm_score = sparse_norm_dict.get(doc_id, 0.0)
108
+
109
+ # Weighted combination of normalized scores
110
+ nsf_scores[doc_id] = (dense_weight * dense_norm_score +
111
+ sparse_weight * sparse_norm_score)
112
+
113
+ # Sort by NSF score
114
+ sorted_results = sorted(nsf_scores.items(), key=lambda x: x[1], reverse=True)
115
+ return sorted_results
116
+
117
+
118
+ # Search part
119
+ ## Currentl to try to speed up things, I keep only non-zero values for the sparse search
120
+ ## Need to optimise the sparse research part (about 25-30s currently...)
121
+ ## Tom must surely have some tips on this point
122
+ sparse_index = {}
123
+ for i, sparse_emd in enumerate(wikipedia_dataset["sparse_emb"]):
124
+ # Convert sparse_emd to a NumPy array
125
+ sparse_emd = np.array(sparse_emd)
126
+
127
+ # Only store non-zero values and their indices
128
+ non_zero_indices = np.nonzero(sparse_emd)[0]
129
+ non_zero_values = sparse_emd[non_zero_indices]
130
+ sparse_index[i] = (non_zero_indices, non_zero_values)
131
+
132
+
133
+ def sparse_search(query, top_k=20):
134
+ """
135
+ Perform sparse retrieval using SPLADE representations with a dictionary-based sparse index.
136
+ """
137
+ # Encode the query, the output is a sparse torch.Tensor
138
+ query_sparse_vector = sparse_model.encode(query, convert_to_numpy=False)
139
+
140
+ # Convert the sparse tensor to a dense format before using np.nonzero
141
+ query_sparse_vector = query_sparse_vector.to_dense().numpy()
142
+
143
+ # Get non-zero values and indices for the query
144
+ query_non_zero_indices = np.nonzero(query_sparse_vector)[0]
145
+ query_non_zero_values = query_sparse_vector[query_non_zero_indices]
146
+
147
+ # Compute dot product similarity efficiently
148
+ scores = defaultdict(float)
149
+ for doc_id, (doc_indices, doc_values) in sparse_index.items():
150
+ # Find the intersection of non-zero indices
151
+ common_indices = np.intersect1d(query_non_zero_indices, doc_indices)
152
+ # Compute dot product for common indices only
153
+ for idx in common_indices:
154
+ query_val = query_non_zero_values[np.where(query_non_zero_indices == idx)[0][0]]
155
+ doc_val = doc_values[np.where(doc_indices == idx)[0][0]]
156
+ scores[doc_id] += query_val * doc_val
157
+
158
+ # Sort and get top_k
159
+ sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
160
+ return sorted_scores[:top_k]
161
 
162
 
163
+ # Final search combining dense and sparse results
164
+ def search(query, top_k: int = 20, rescore_multiplier: int = 1, use_approx: bool = False,
165
+ fusion_method: str = "rrf", rrf_k: int = 20, dense_weight: float = 0.5, sparse_weight: float = 0.5):
166
+ total_start_time = time.time()
167
+
168
+ # 1. Dense retrieval pipeline (existing code)
169
  start_time = time.time()
170
+ query_embedding = dense_model.encode(query, prompt="query: ")
171
  embed_time = time.time() - start_time
172
 
 
173
  start_time = time.time()
174
  query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary")
175
  quantize_time = time.time() - start_time
176
 
 
177
  index = binary_ivf if use_approx else binary_index
178
  start_time = time.time()
179
+ _scores, binary_ids = index.search(query_embedding_ubinary, top_k * rescore_multiplier * 2) # Get more for fusion
180
  binary_ids = binary_ids[0]
181
+ dense_search_time = time.time() - start_time
182
 
 
183
  start_time = time.time()
184
  int8_embeddings = int8_view[binary_ids].astype(int)
185
  load_time = time.time() - start_time
186
 
 
187
  start_time = time.time()
188
  scores = query_embedding @ int8_embeddings.T
189
  rescore_time = time.time() - start_time
190
 
191
+ # Prepare dense results
192
+ dense_results = [(binary_ids[i], scores[i]) for i in range(len(binary_ids))]
193
+ dense_results.sort(key=lambda x: x[1], reverse=True)
194
+
195
+ timing_info = {
196
+ "Temps pour créer l'embedding de la requête (dense)": f"{embed_time:.4f} s",
197
+ "Temps pour la quantification": f"{quantize_time:.4f} s",
198
+ "Temps pour effectuer la recherche dense": f"{dense_search_time:.4f} s",
199
+ "Temps de chargement": f"{load_time:.4f} s",
200
+ "Temps de rescorage": f"{rescore_time:.4f} s",
201
+ }
202
+
203
+
204
+ if fusion_method != "dense_only":
205
+ # 2. Sparse retrieval pipeline
206
+ start_time = time.time()
207
+ sparse_results = sparse_search(query, top_k * rescore_multiplier * 2)
208
+ sparse_search_time = time.time() - start_time
209
+
210
+ # 3. Apply selected fusion method
211
+ start_time = time.time()
212
+ if fusion_method == "rrf":
213
+ fusion_results = reciprocal_rank_fusion(dense_results, sparse_results, k=rrf_k)
214
+ fusion_method_name = f"RRF (k={rrf_k})"
215
+ elif fusion_method == "nsf":
216
+ fusion_results = normalized_score_fusion(dense_results, sparse_results,
217
+ dense_weight=dense_weight, sparse_weight=sparse_weight)
218
+ fusion_method_name = f"NSF Z-score (dense={dense_weight:.1f}, sparse={sparse_weight:.1f})"
219
+ else:
220
+ fusion_results = dense_results # fallback
221
+ fusion_method_name = "Dense uniquement (fallback)"
222
+
223
+ fusion_time = time.time() - start_time
224
+
225
+ # Use fusion results
226
+ final_results = fusion_results[:top_k * rescore_multiplier]
227
+ final_doc_ids = [doc_id for doc_id, _ in final_results]
228
+ final_scores = [score for _, score in final_results]
229
+
230
+ timing_info.update({
231
+ # Add time to create query embedding with sparse model
232
+ # Do quantification for sparse model as dense one?
233
+ "Temps pour la recherche sparse": f"{sparse_search_time:.4f} s",
234
+ "Temps pour la fusion": f"{fusion_time:.4f} s",
235
+ })
236
+ timing_info["Méthode de fusion utilisée"] = fusion_method_name
237
+ else:
238
+ # Use only dense results
239
+ final_doc_ids = [doc_id for doc_id, _ in dense_results[:top_k * rescore_multiplier]]
240
+ final_scores = [score for _, score in dense_results[:top_k * rescore_multiplier]]
241
+ timing_info["Méthode de fusion utilisée"] = "Dense uniquement"
242
+
243
+ # 4. Prepare final results
244
  start_time = time.time()
245
+ try:
246
+ top_k_titles, top_k_texts = zip(*[(wikipedia_dataset[int(idx)]["title"], wikipedia_dataset[int(idx)]["text"])
247
+ for idx in final_doc_ids[:top_k]])
248
+
249
+ # Create DataFrame with results
250
+ df = pd.DataFrame({
251
+ "Score_paragraphe": [round(float(score), 4) for score in final_scores[:top_k]],
252
+ "Titre": top_k_titles,
253
+ "Texte": top_k_texts
254
+ })
255
+
256
+ # Calculate article scores (sum of paragraph scores)
257
+ score_sum = df.groupby('Titre')['Score_paragraphe'].sum().reset_index()
258
+ df = pd.merge(df, score_sum, on='Titre', how='left')
259
+ df.rename(columns={'Score_paragraphe_y': 'Score_article', 'Score_paragraphe_x': 'Score_paragraphe'}, inplace=True)
260
+ df = df[["Score_article", "Score_paragraphe", "Titre", "Texte"]]
261
+ df = df.sort_values('Score_article', ascending=False)
262
+
263
+ except Exception as e:
264
+ print(f"Error creating results DataFrame: {e}")
265
+ df = pd.DataFrame({"Error": [f"No results found or error processing results: {e}"]})
266
+
267
  sort_time = time.time() - start_time
268
+ total_time = time.time() - total_start_time
269
+
270
+ timing_info.update({
271
+ "Temps pour afficher les résultats": f"{sort_time:.4f} s",
272
+ "Temps total": f"{total_time:.4f} s",
273
+ })
274
 
275
+ return df, timing_info
 
 
 
 
 
 
 
 
276
 
277
 
278
+ with gr.Blocks(title="Requêter Wikipedia avec Fusion Hybride 🔍") as demo:
279
 
280
  gr.Markdown(
281
  """
282
  ## Requêter Wikipedia en temps réel 🔍
 
283
  Ce démonstrateur permet de requêter un corpus composé des 250K paragraphes les plus consultés du Wikipédia francophone.
284
  Les résultats sont renvoyés en temps réel via un pipeline tournant sur un CPU 🚀
285
  Nous nous sommes grandement inspirés du Space [quantized-retrieval](https://huggingface.co/spaces/sentence-transformers/quantized-retrieval) conçu par [Tom Aarsen](https://huggingface.co/tomaarsen) 🤗
 
297
  </details>
298
 
299
  <details><summary>2. Détails le pipeline</summary>
300
+ A écrire quand ça sera terminé.
 
 
 
 
 
 
 
 
 
 
 
 
 
301
  </details>
302
  """
303
  )
304
+
305
  with gr.Row():
306
  with gr.Column(scale=75):
307
  query = gr.Textbox(
 
312
  use_approx = gr.Radio(
313
  choices=[("Exacte", False), ("Approximative", True)],
314
  value=True,
315
+ label="Type de recherche dense",
316
  )
317
 
318
  with gr.Row():
 
322
  maximum=40,
323
  step=1,
324
  value=15,
325
+ label="Nombre de documents à retourner",
 
326
  )
327
  with gr.Column(scale=2):
328
  rescore_multiplier = gr.Slider(
 
331
  step=1,
332
  value=1,
333
  label="Coefficient de rescorage",
334
+ info="Augmente le nombre de candidats avant fusion",
335
+ )
336
+
337
+ with gr.Row():
338
+ with gr.Column(scale=3):
339
+ fusion_method = gr.Radio(
340
+ choices=[
341
+ ("Dense uniquement", "dense_only"),
342
+ ("Fusion RRF", "rrf"),
343
+ ("Fusion NSF Z-score", "nsf")
344
+ ],
345
+ value="rrf",
346
+ label="Méthode de fusion",
347
+ info="Choisissez comment combiner les résultats des modèles dense et sparse"
348
  )
349
+ with gr.Column(scale=2):
350
+ rrf_k = gr.Slider(
351
+ minimum=1,
352
+ maximum=200,
353
+ step=1,
354
+ value=60,
355
+ label="Paramètre k pour RRF",
356
+ info="Plus k est élevé, moins les rangs ont d'importance",
357
+ visible=True
358
+ )
359
+
360
+ with gr.Row():
361
+ with gr.Column(scale=2):
362
+ dense_weight = gr.Slider(
363
+ minimum=0.0,
364
+ maximum=1.0,
365
+ step=0.1,
366
+ value=0.5,
367
+ label="Poids Dense (NSF)",
368
+ info="Importance des résultats du modèle dense",
369
+ visible=False
370
+ )
371
+ with gr.Column(scale=2):
372
+ sparse_weight = gr.Slider(
373
+ minimum=0.0,
374
+ maximum=1.0,
375
+ step=0.1,
376
+ value=0.5,
377
+ label="Poids Sparse (NSF)",
378
+ info="Importance des résultats du modèle sparse",
379
+ visible=False
380
+ )
381
+
382
+ # JavaScript to show/hide parameters based on fusion method
383
+ fusion_method.change(
384
+ fn=lambda method: (
385
+ gr.update(visible=(method == "rrf")), # rrf_k
386
+ gr.update(visible=(method == "nsf")), # dense_weight
387
+ gr.update(visible=(method == "nsf")) # sparse_weight
388
+ ),
389
+ inputs=[fusion_method],
390
+ outputs=[rrf_k, dense_weight, sparse_weight]
391
+ )
392
 
393
+ search_button = gr.Button(value="Rechercher", variant="primary")
394
 
395
  output = gr.Dataframe(headers=["Score_article", "Score_paragraphe", "Titre", "Texte"], datatype="markdown")
396
+ json = gr.JSON(label="Informations de performance")
 
 
 
 
397
 
398
+ query.submit(search, inputs=[query, top_k, rescore_multiplier, use_approx, fusion_method, rrf_k, dense_weight, sparse_weight], outputs=[output, json])
399
+ search_button.click(search, inputs=[query, top_k, rescore_multiplier, use_approx, fusion_method, rrf_k, dense_weight, sparse_weight], outputs=[output, json])
400
+
401
  demo.queue()
402
+ demo.launch()