davidberenstein1957 HF staff commited on
Commit
b4d283f
1 Parent(s): 41b224c

fix: avoid global usage

Browse files
Files changed (1) hide show
  1. app.py +26 -23
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import duckdb
2
  import gradio as gr
3
  import polars as pl
@@ -5,7 +7,6 @@ from datasets import load_dataset
5
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
6
  from model2vec import StaticModel
7
 
8
- global ds
9
  global df
10
 
11
  # Load a model from the HuggingFace hub (in this case the potion-base-8M model)
@@ -28,14 +29,13 @@ def get_iframe(hub_repo_id):
28
  return iframe
29
 
30
 
31
- def load_dataset_from_hub(hub_repo_id):
32
- gr.Info("Loading dataset...")
33
- global ds
34
  ds = load_dataset(hub_repo_id)
35
 
36
 
37
- def get_columns(split: str):
38
- global ds
39
  ds_split = ds[split]
40
  return gr.Dropdown(
41
  choices=ds_split.column_names,
@@ -45,33 +45,35 @@ def get_columns(split: str):
45
  )
46
 
47
 
48
- def get_splits():
49
- global ds
50
  splits = list(ds.keys())
51
  return gr.Dropdown(
52
  choices=splits, value=splits[0], label="Select a split", visible=True
53
  )
54
 
55
 
56
- def vectorize_dataset(split: str, column: str):
 
57
  gr.Info("Vectorizing dataset...")
58
- global df
59
- global ds
60
  df = ds[split].to_polars()
61
  embeddings = model.encode(df[column].cast(str), max_length=512)
62
- df = df.with_columns(pl.Series(embeddings).alias(f"{column}_embeddings"))
63
 
64
 
65
- def run_query(query: str, column: str):
 
 
 
 
66
  try:
67
- global df
68
-
69
  vector = model.encode(query)
70
  df_results = duckdb.sql(
71
  query=f"""
72
  SELECT *
73
  FROM df
74
- ORDER BY array_cosine_distance({column}_embeddings, {vector.tolist()}::FLOAT[256])
75
  LIMIT 5
76
  """
77
  ).to_df()
@@ -134,6 +136,7 @@ with gr.Blocks() as demo:
134
  query_input = gr.Textbox(label="Query", visible=False)
135
 
136
  btn_run = gr.Button("Search", visible=False)
 
137
  results_output = gr.Dataframe(label="Results", visible=False)
138
 
139
  search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
@@ -143,23 +146,23 @@ with gr.Blocks() as demo:
143
  ).then(
144
  fn=hide_components,
145
  outputs=[split_dropdown, column_dropdown, query_input, btn_run, results_output],
146
- ).then(fn=get_splits, outputs=split_dropdown).then(
147
- fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
148
  )
149
 
150
  split_dropdown.change(
151
- fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
152
  )
153
 
154
  column_dropdown.change(
155
  fn=partial_hide_components,
156
  outputs=[query_input, btn_run, results_output],
157
- ).then(fn=vectorize_dataset, inputs=[split_dropdown, column_dropdown]).then(
158
- fn=show_components, outputs=[query_input, btn_run]
159
- )
160
 
161
  btn_run.click(
162
- fn=run_query, inputs=[query_input, column_dropdown], outputs=results_output
 
 
163
  )
164
 
165
  demo.launch()
 
1
+ from functools import lru_cache
2
+
3
  import duckdb
4
  import gradio as gr
5
  import polars as pl
 
7
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
8
  from model2vec import StaticModel
9
 
 
10
  global df
11
 
12
  # Load a model from the HuggingFace hub (in this case the potion-base-8M model)
 
29
  return iframe
30
 
31
 
32
+ def load_dataset_from_hub(hub_repo_id: str):
33
+ gr.Info(message="Loading dataset...")
 
34
  ds = load_dataset(hub_repo_id)
35
 
36
 
37
+ def get_columns(hub_repo_id: str, split: str):
38
+ ds = load_dataset(hub_repo_id)
39
  ds_split = ds[split]
40
  return gr.Dropdown(
41
  choices=ds_split.column_names,
 
45
  )
46
 
47
 
48
+ def get_splits(hub_repo_id: str):
49
+ ds = load_dataset(hub_repo_id)
50
  splits = list(ds.keys())
51
  return gr.Dropdown(
52
  choices=splits, value=splits[0], label="Select a split", visible=True
53
  )
54
 
55
 
56
+ @lru_cache
57
+ def vectorize_dataset(hub_repo_id: str, split: str, column: str):
58
  gr.Info("Vectorizing dataset...")
59
+ ds = load_dataset(hub_repo_id)
 
60
  df = ds[split].to_polars()
61
  embeddings = model.encode(df[column].cast(str), max_length=512)
62
+ return embeddings
63
 
64
 
65
+ def run_query(hub_repo_id: str, query: str, split: str, column: str):
66
+ embeddings = vectorize_dataset(hub_repo_id, split, column)
67
+ ds = load_dataset(hub_repo_id)
68
+ df = ds[split].to_polars()
69
+ df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
70
  try:
 
 
71
  vector = model.encode(query)
72
  df_results = duckdb.sql(
73
  query=f"""
74
  SELECT *
75
  FROM df
76
+ ORDER BY array_cosine_distance(embeddings, {vector.tolist()}::FLOAT[256])
77
  LIMIT 5
78
  """
79
  ).to_df()
 
136
  query_input = gr.Textbox(label="Query", visible=False)
137
 
138
  btn_run = gr.Button("Search", visible=False)
139
+
140
  results_output = gr.Dataframe(label="Results", visible=False)
141
 
142
  search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
 
146
  ).then(
147
  fn=hide_components,
148
  outputs=[split_dropdown, column_dropdown, query_input, btn_run, results_output],
149
+ ).then(fn=get_splits, inputs=search_in, outputs=split_dropdown).then(
150
+ fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
151
  )
152
 
153
  split_dropdown.change(
154
+ fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
155
  )
156
 
157
  column_dropdown.change(
158
  fn=partial_hide_components,
159
  outputs=[query_input, btn_run, results_output],
160
+ ).then(fn=show_components, outputs=[query_input, btn_run])
 
 
161
 
162
  btn_run.click(
163
+ fn=run_query,
164
+ inputs=[search_in, query_input, split_dropdown, column_dropdown],
165
+ outputs=results_output,
166
  )
167
 
168
  demo.launch()