Spaces:
Sleeping
Sleeping
File size: 3,737 Bytes
4b6ea6b c2318c3 4b6ea6b de81f6b 4b6ea6b de81f6b 4b6ea6b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import sqlite3

import polars as pl
import solara
import sqlite_vec
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

# Load the recipe dataset once at import time. Only the recipe names are used;
# the vector tables use each name's position in this sequence as its rowid.
# NOTE(review): trust_remote_code=True lets `datasets` run the dataset repo's
# loading script; keep it only for datasets you trust.
dataset = load_dataset("m3hrdadfi/recipe_nlg_lite", trust_remote_code=True)
recipe_names = dataset["train"]["name"]

# Embed every recipe name once. X_tfm holds one embedding row per recipe, and
# n_feats (its second dimension) sizes the vec0 columns declared in Page.
tfm_base = SentenceTransformer("all-MiniLM-L6-v2")
X_tfm = tfm_base.encode(recipe_names)
n_feats = X_tfm.shape[1]
@solara.component
def Display_Full(query, db, limit):
    """Render the nearest recipes found in the full-precision vector index.

    Args:
        query: Query embedding; serialized with ``sqlite_vec.serialize_float32``
            and matched against the ``vec_sents`` float vec0 table.
        db: sqlite3 connection with sqlite-vec loaded and ``vec_sents`` populated.
        limit: Number of nearest neighbours to request (the KNN ``k``).
    """
    # k and LIMIT are bound as parameters rather than formatted into the SQL.
    # A read-only SELECT needs no `with db:` transaction wrapper.
    rows = db.execute(
        """
        SELECT
            rowid,
            distance
        FROM vec_sents
        WHERE embedding MATCH ? AND k = ?
        ORDER BY distance
        LIMIT ?
        """,
        [sqlite_vec.serialize_float32(query), limit, limit],
    ).fetchall()
    # Each row is (rowid, distance). The rowid is the recipe's index in
    # recipe_names, so it maps directly back to the recipe name.
    df1 = pl.DataFrame({"results": [recipe_names[rowid] for rowid, _distance in rows]})
    with solara.Column():
        solara.Markdown("## Full precision")
        solara.DataFrame(df1, items_per_page=10)
@solara.component
def Display_Binary(query, db, limit):
    """Render the nearest recipes found in the binary-quantized vector index.

    Args:
        query: Query embedding; serialized with ``sqlite_vec.serialize_float32``
            and binary-quantized in SQL with ``vec_quantize_binary`` before
            being matched against the ``bin_vec_sents`` bit vec0 table.
        db: sqlite3 connection with sqlite-vec loaded and ``bin_vec_sents`` populated.
        limit: Number of nearest neighbours to request (the KNN ``k``).
    """
    # k and LIMIT are bound as parameters rather than formatted into the SQL.
    # A read-only SELECT needs no `with db:` transaction wrapper.
    rows = db.execute(
        """
        SELECT
            rowid,
            distance
        FROM bin_vec_sents
        WHERE embedding MATCH vec_quantize_binary(?) AND k = ?
        ORDER BY distance
        LIMIT ?
        """,
        [sqlite_vec.serialize_float32(query), limit, limit],
    ).fetchall()
    # Each row is (rowid, distance). The rowid is the recipe's index in
    # recipe_names, so it maps directly back to the recipe name.
    df2 = pl.DataFrame({"results": [recipe_names[rowid] for rowid, _distance in rows]})
    with solara.Column():
        solara.Markdown("## Binary quantization")
        solara.DataFrame(df2, items_per_page=10)
def _build_db():
    """Create and populate an in-memory sqlite-vec database of recipe embeddings.

    Two vec0 tables are filled from ``X_tfm``, each recipe's index in
    ``recipe_names`` being used as its rowid:

    * ``vec_sents`` stores the embeddings as ``float[n_feats]``;
    * ``bin_vec_sents`` stores them binary-quantized via ``vec_quantize_binary``.

    Returns:
        The open ``sqlite3.Connection``, with ``row_factory`` set to ``sqlite3.Row``.
    """
    # The connection outlives the render that creates it and may be used by later
    # renders on a different thread, so disable sqlite3's same-thread check.
    db = sqlite3.connect(":memory:", check_same_thread=False)
    db.enable_load_extension(True)
    sqlite_vec.load(db)
    db.enable_load_extension(False)
    db.row_factory = sqlite3.Row
    db.execute(f"create virtual table vec_sents using vec0(embedding float[{n_feats}])")
    db.execute(f"create virtual table bin_vec_sents using vec0(embedding bit[{n_feats}])")
    # Serialize each embedding once and reuse the blobs for both tables.
    rows = [(i, sqlite_vec.serialize_float32(vec)) for i, vec in enumerate(X_tfm)]
    with db:
        db.executemany("INSERT INTO vec_sents(rowid, embedding) VALUES (?, ?)", rows)
        db.executemany(
            "INSERT INTO bin_vec_sents(rowid, embedding) VALUES (?, vec_quantize_binary(?))",
            rows,
        )
    return db


@solara.component
def Page():
    """Top-level page: a query box above side-by-side full-precision and binary results."""
    q = solara.use_reactive("I would like to have some vegetable soup")
    # Build the vector database once per Page instance instead of on every render.
    # Rebuilding it per render re-inserted every embedding into both tables on each
    # keystroke (the input uses continuous_update=True) and leaked a connection each time.
    db = solara.use_memo(_build_db, dependencies=[])
    limit = 10
    with solara.Column(margin=10):
        with solara.Head():
            solara.Title("Recipe finder")
        solara.Markdown("# Recipe finder")
        solara.Markdown("I built this tool to help me get a feeling of binary embedding quantization in [sqlite-vec](https://alexgarcia.xyz/sqlite-vec/). For any given text, it gives the top 10 results. The dataset I'm using is [m3hrdadfi/recipe_nlg_lite](https://hf.co/datasets/m3hrdadfi/recipe_nlg_lite) which consists of 6,119 recipes. Inspired by [Exploring SQLite-vec](https://www.youtube.com/watch?v=wYU66AjRIAc) by [@fishnets88](https://twitter.com/fishnets88)")
        solara.InputText("Enter a query", value=q, continuous_update=True)
        query = tfm_base.encode([q.value])[0]
        with solara.Row():
            Display_Full(query, db, limit)
            Display_Binary(query, db, limit)
|