davidberenstein1957 HF staff commited on
Commit
5f76c1a
·
1 Parent(s): fbdb332

feat: use SOTA model

Browse files
Files changed (2) hide show
  1. app.py +7 -9
  2. demo.py +5 -5
app.py CHANGED
@@ -9,7 +9,7 @@ global ds
9
  global df
10
 
11
  # Load a model from the HuggingFace hub (in this case the potion-base-8M model)
12
- model_name = "minishlab/M2V_multilingual_output"
13
  model = StaticModel.from_pretrained(model_name)
14
 
15
 
@@ -53,7 +53,7 @@ def vectorize_dataset(split: str, column: str):
53
  global df
54
  global ds
55
  df = ds[split].to_polars()
56
- embeddings = model.encode(df[column])
57
  df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
58
 
59
 
@@ -64,7 +64,7 @@ def run_query(query: str):
64
  query=f"""
65
  SELECT *
66
  FROM df
67
- ORDER BY array_distance(embeddings, {vector.tolist()}::FLOAT[256])
68
  LIMIT 5
69
  """
70
  ).to_df()
@@ -91,18 +91,16 @@ with gr.Blocks() as demo:
91
  )
92
  with gr.Row():
93
  search_out = gr.HTML(label="Search Results")
94
- search_in.submit(get_iframe, inputs=search_in, outputs=search_out)
95
-
96
- btn_load_dataset = gr.Button("Load Dataset")
97
 
98
  with gr.Row(variant="panel"):
99
  split_dropdown = gr.Dropdown(label="Select a split")
100
  column_dropdown = gr.Dropdown(label="Select a column")
101
  with gr.Row(variant="panel"):
102
  query_input = gr.Textbox(label="Query")
103
-
104
- btn_load_dataset.click(
105
- load_dataset_from_hub, inputs=search_in, show_progress=True
 
106
  ).then(fn=get_splits, outputs=split_dropdown).then(
107
  fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
108
  )
 
9
  global df
10
 
11
  # Load a model from the HuggingFace hub (in this case the potion-base-8M model)
12
+ model_name = "minishlab/potion-base-8M"
13
  model = StaticModel.from_pretrained(model_name)
14
 
15
 
 
53
  global df
54
  global ds
55
  df = ds[split].to_polars()
56
+ embeddings = model.encode(df[column], max_length=512 * 4)
57
  df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
58
 
59
 
 
64
  query=f"""
65
  SELECT *
66
  FROM df
67
+ ORDER BY array_cosine_distance(embeddings, {vector.tolist()}::FLOAT[256])
68
  LIMIT 5
69
  """
70
  ).to_df()
 
91
  )
92
  with gr.Row():
93
  search_out = gr.HTML(label="Search Results")
 
 
 
94
 
95
  with gr.Row(variant="panel"):
96
  split_dropdown = gr.Dropdown(label="Select a split")
97
  column_dropdown = gr.Dropdown(label="Select a column")
98
  with gr.Row(variant="panel"):
99
  query_input = gr.Textbox(label="Query")
100
+ search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
101
+ fn=load_dataset_from_hub,
102
+ inputs=search_in,
103
+ show_progress=True,
104
  ).then(fn=get_splits, outputs=split_dropdown).then(
105
  fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
106
  )
demo.py CHANGED
@@ -4,20 +4,20 @@ from datasets import load_dataset
4
  from model2vec import StaticModel
5
 
6
  # Load a model from the HuggingFace hub (in this case the potion-base-8M model)
7
- model_name = "minishlab/M2V_multilingual_output"
8
  model = StaticModel.from_pretrained(model_name)
9
 
10
  # Make embeddings
11
  ds = load_dataset("fka/awesome-chatgpt-prompts")
12
  df = ds["train"].to_polars()
13
- embeddings = model.encode(df["prompt"])
14
  df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
15
- vector = model.encode("vector search", show_progress_bar=True)
16
  duckdb.sql(
17
  query=f"""
18
  SELECT *
19
  FROM df
20
- ORDER BY array_distance(embeddings, {vector.tolist()}::FLOAT[256])
21
- LIMIT 1
22
  """
23
  ).show()
 
4
  from model2vec import StaticModel
5
 
6
  # Load a model from the HuggingFace hub (in this case the potion-base-8M model)
7
+ model_name = "minishlab/potion-base-8M"
8
  model = StaticModel.from_pretrained(model_name)
9
 
10
  # Make embeddings
11
  ds = load_dataset("fka/awesome-chatgpt-prompts")
12
  df = ds["train"].to_polars()
13
+ embeddings = model.encode(df["act"])
14
  df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
15
+ vector = model.encode("An Ethereum Developer", show_progress_bar=True)
16
  duckdb.sql(
17
  query=f"""
18
  SELECT *
19
  FROM df
20
+ ORDER BY array_cosine_distance(embeddings, {vector.tolist()}::FLOAT[256])
21
+ LIMIT 10
22
  """
23
  ).show()