File size: 3,737 Bytes
4b6ea6b
 
 
 
 
c2318c3
4b6ea6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de81f6b
4b6ea6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de81f6b
4b6ea6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import sqlite3
import sqlite_vec

from datasets import load_dataset

# Load the recipe corpus at import time. trust_remote_code=True permits the
# dataset repository's own loading script to run.
dataset = load_dataset("m3hrdadfi/recipe_nlg_lite", trust_remote_code=True)
# The "name" column of the train split: these strings are what gets embedded
# and what is shown back to the user as search results.
recipe_names = dataset["train"]["name"]

from sentence_transformers import SentenceTransformer

# Sentence-embedding model shared by the corpus embeddings and the user's query.
tfm_base = SentenceTransformer("all-MiniLM-L6-v2")
# Embed every recipe name once at import time.
# NOTE(review): presumably a 2-D array with one row per recipe name — confirm.
X_tfm = tfm_base.encode(recipe_names)
# Embedding width (size of the second axis); used to dimension the vec0 columns.
n_feats = X_tfm.shape[1]

import polars as pl
import solara

@solara.component
def Display_Full(query, db, limit):
    """Show the nearest recipe names according to the ``vec_sents`` table.

    Args:
        query: Query embedding; it is serialized with
            ``sqlite_vec.serialize_float32`` and bound to the ``MATCH`` clause.
        db: SQLite connection that has a ``vec_sents`` table. Result rows are
            converted with ``dict(row)``, so the connection needs a
            mapping-style row factory such as ``sqlite3.Row``.
        limit: Number of neighbours requested; interpolated into the SQL text
            as both ``k`` and ``LIMIT``.
    """
    knn_sql = f"""
              SELECT
                rowid,
                distance
              FROM vec_sents
              WHERE embedding MATCH ? AND k={limit}
              ORDER BY distance
              LIMIT {limit}
            """
    with db:
        cursor = db.execute(knn_sql, [sqlite_vec.serialize_float32(query)])
        hits = cursor.fetchall()

    # Each rowid is used directly as an index into recipe_names.
    matched_names = []
    for hit in hits:
        matched_names.append(recipe_names[dict(hit)["rowid"]])
    results_frame = pl.DataFrame({"results": matched_names})

    with solara.Column():
        solara.Markdown("## Full precision")
        solara.DataFrame(results_frame, items_per_page=10)

@solara.component
def Display_Binary(query, db, limit):
    """Show the nearest recipe names according to the ``bin_vec_sents`` table.

    The float query is bound as-is and quantized inside SQL with
    ``vec_quantize_binary`` before being matched.

    Args:
        query: Query embedding; it is serialized with
            ``sqlite_vec.serialize_float32`` before binding.
        db: SQLite connection that has a ``bin_vec_sents`` table. Result rows
            are converted with ``dict(row)``, so the connection needs a
            mapping-style row factory such as ``sqlite3.Row``.
        limit: Number of neighbours requested; interpolated into the SQL text
            as both ``k`` and ``LIMIT``.
    """
    binary_knn_sql = f"""
              SELECT
                rowid,
                distance
              FROM bin_vec_sents
              WHERE embedding MATCH vec_quantize_binary(?) AND k={limit}
              ORDER BY distance
              LIMIT {limit}
            """
    with db:
        params = [sqlite_vec.serialize_float32(query)]
        neighbours = db.execute(binary_knn_sql, params).fetchall()

    # Each rowid is used directly as an index into recipe_names.
    row_ids = [dict(neighbour)["rowid"] for neighbour in neighbours]
    names = [recipe_names[row_id] for row_id in row_ids]
    table = pl.DataFrame({"results": names})

    with solara.Column():
        solara.Markdown("## Binary quantization")
        solara.DataFrame(table, items_per_page=10)


def _build_vector_db():
    """Create and populate the in-memory SQLite database used for searching.

    Two ``vec0`` virtual tables are created, each ``n_feats`` wide:
    ``vec_sents`` stores the float embeddings from ``X_tfm`` and
    ``bin_vec_sents`` stores the same embeddings passed through
    ``vec_quantize_binary``. The rowid of each entry is its position in
    ``X_tfm``.

    Returns:
        An ``sqlite3.Connection`` with the sqlite-vec extension loaded and
        ``row_factory`` set to ``sqlite3.Row``.
    """
    # The connection is memoized and therefore outlives the render that
    # created it. check_same_thread=False keeps it usable if a later render
    # runs on a different thread; it is only read from after this function.
    db = sqlite3.connect(":memory:", check_same_thread=False)
    db.enable_load_extension(True)
    sqlite_vec.load(db)
    # Extension loading is only needed for sqlite-vec itself; turn it back off.
    db.enable_load_extension(False)
    db.row_factory = sqlite3.Row

    # Serialize each embedding once and reuse the blobs for both tables.
    serialized_rows = [
        (rowid, sqlite_vec.serialize_float32(vector))
        for rowid, vector in enumerate(X_tfm)
    ]

    db.execute(f"create virtual table vec_sents using vec0(embedding float[{n_feats}])")
    with db:
        db.executemany(
            "INSERT INTO vec_sents(rowid, embedding) VALUES (?, ?)",
            serialized_rows,
        )

    db.execute(f"create virtual table bin_vec_sents using vec0(embedding bit[{n_feats}])")
    with db:
        db.executemany(
            "INSERT INTO bin_vec_sents(rowid, embedding) VALUES (?, vec_quantize_binary(?))",
            serialized_rows,
        )

    return db


@solara.component
def Page():
    """Top-level page: a query box plus side-by-side search results.

    The text in the input box is embedded with ``tfm_base`` on every render
    and handed to ``Display_Full`` and ``Display_Binary`` together with the
    vector database. The database is built once per component instance and
    reused across re-renders rather than being rebuilt on each keystroke.
    """
    with solara.Column(margin=10):
        with solara.Head():
            solara.Title("Recipe finder")
        solara.Markdown("# Recipe finder")
        solara.Markdown("I built this tool to help me get a feeling of binary embedding quantization in [sqlite-vec](https://alexgarcia.xyz/sqlite-vec/). For any given text, it gives the top 10 results. The dataset I'm using is [m3hrdadfi/recipe_nlg_lite](https://hf.co/datasets/m3hrdadfi/recipe_nlg_lite) which consists of 6,119 recipes. Inspired by [Exploring SQLite-vec](https://www.youtube.com/watch?v=wYU66AjRIAc) by [@fishnets88](https://twitter.com/fishnets88)")
        q = solara.use_reactive("I would like to have some vegetable soup")
        solara.InputText("Enter a query", value=q, continuous_update=True)
        query = tfm_base.encode([q.value])[0]
        limit = 10
        # An explicit empty dependency list means "compute once"; the corpus
        # embeddings never change while the app is running.
        db = solara.use_memo(_build_vector_db, dependencies=[])

        with solara.Row():
            Display_Full(query, db, limit)
            Display_Binary(query, db, limit)