cast42 commited on
Commit
edf8d32
·
verified ·
1 Parent(s): 8ebf516

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import duckdb
2
+ import gradio as gr
3
+ from sentence_transformers import SentenceTransformer
4
+ from sentence_transformers.models import StaticEmbedding
5
+
6
+ static_embedding = StaticEmbedding.from_model2vec("minishlab/potion-base-8M")
7
+ model = SentenceTransformer(modules=[static_embedding])
8
+ embedding_dimensions = model.get_sentence_embedding_dimension()
9
+ dataset_name = "cast42/x_likes_embeddings_potion_base_8M"
10
+ embedding_column = "embeddings"
11
+ embedding_column_float = f"{embedding_column}_float"
12
+ table_name = "fineweb"
13
+
14
+ duckdb.sql(
15
+ query=f"""
16
+ INSTALL vss;
17
+ LOAD vss;
18
+ CREATE TABLE {table_name} AS
19
+ SELECT *, {embedding_column}::float[{embedding_dimensions}] as {embedding_column_float}
20
+ FROM 'hf://datasets/{dataset_name}/**/*.parquet';
21
+ CREATE INDEX my_hnsw_index ON {table_name} USING HNSW ({embedding_column_float}) WITH (metric = 'cosine');
22
+ """
23
+ )
24
+
25
+
26
+ def similarity_search(query: str, k: int = 5):
27
+ embedding = model.encode(query).tolist()
28
+ df = duckdb.sql(
29
+ query=f"""
30
+ SELECT *, array_cosine_distance({embedding_column_float}, {embedding}::FLOAT[{embedding_dimensions}]) as distance
31
+ FROM {table_name}
32
+ ORDER BY distance
33
+ LIMIT {k};
34
+ """
35
+ ).to_df()
36
+ df = df.drop(columns=[embedding_column, embedding_column_float])
37
+ return df
38
+
39
+
40
+ with gr.Blocks() as demo:
41
+ gr.Markdown("""# RAG - retrieve
42
+ Executes vector search on top of [x_likes_embeddings_potion_base_8M](https://huggingface.co/datasets/cast42/x_likes_embeddings_potion_base_8M) using DuckDB.
43
+
44
+ Part of [AI blueprint](https://github.com/huggingface/ai-blueprint) - a blueprint for AI development, focusing on practical examples of RAG, information extraction, analysis and fine-tuning in the age of LLMs. """)
45
+ query = gr.Textbox(label="Query")
46
+ k = gr.Slider(1, 50, value=5, label="Number of results")
47
+ btn = gr.Button("Search")
48
+ results = gr.Dataframe(headers=["url", "chunk", "distance"], wrap=True)
49
+ btn.click(fn=similarity_search, inputs=[query, k], outputs=[results])
50
+
51
+
52
+ demo.launch()