Spaces:
Sleeping
Sleeping
import duckdb | |
import gradio as gr | |
from gradio_client import Client | |
from sentence_transformers import CrossEncoder | |
from sentence_transformers import SentenceTransformer | |
from sentence_transformers.models import StaticEmbedding | |
from huggingface_hub import get_token | |
import pandas as pd | |
# Query embedding model. Assumed to match the model that produced the dataset's
# stored embeddings (the cast below requires the dimensions to agree) — confirm.
static_embedding = StaticEmbedding.from_model2vec("minishlab/potion-base-8M")
model = SentenceTransformer(modules=[static_embedding])

# Re-ranker for retrieved hits. CrossEncoder needs a checkpoint trained as a
# cross-encoder (scores a [query, document] pair directly). Loading a bi-encoder
# such as sentence-transformers/all-MiniLM-L12-v2 here attaches a randomly
# initialised classification head and produces meaningless relevance scores.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")

# Output size of `model`; used to cast embeddings to fixed-size FLOAT arrays,
# the column type the vss HNSW index is built over.
embedding_dimensions = model.get_sentence_embedding_dimension()

dataset_name = "cyrilzakka/clinical-trials-embeddings"
embedding_column = "embedding"  # column holding the dataset's stored embeddings
embedding_column_float = f"{embedding_column}_float"  # fixed-size FLOAT[] copy
table_name = "clinical_trials"

# Load every Parquet file of the dataset from the Hugging Face Hub into a table
# on DuckDB's default (in-memory) connection, add a fixed-size float copy of the
# embeddings, and build an HNSW index with the cosine metric over it.
duckdb.sql(query=f"""
INSTALL vss;
LOAD vss;
CREATE TABLE {table_name} AS
SELECT *, {embedding_column}::float[{embedding_dimensions}] as {embedding_column_float}
FROM 'hf://datasets/{dataset_name}/**/*.parquet';
CREATE INDEX my_hnsw_index ON {table_name} USING HNSW ({embedding_column_float}) WITH (metric = 'cosine');
""")
def similarity_search(query: str, k: int = 5) -> pd.DataFrame:
    """Return the k clinical trials whose embeddings are closest to the query.

    Args:
        query: Free-text search query.
        k: Number of results to return; must be at least 1. Converted with
            int(), so float values (e.g. from a slider or an MCP client) are
            accepted and truncated.

    Returns:
        A DataFrame of the matching rows, with both embedding columns removed
        and an added ``distance`` column (cosine distance, ascending).

    Raises:
        ValueError: If ``k`` cannot be converted to int or is less than 1.
    """
    # Coerce before interpolating into SQL so LIMIT always receives a plain
    # integer literal, whatever type the caller (UI or MCP client) passed.
    k = int(k)
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")
    # The query vector is a list of floats produced by our own model, so it is
    # safe to inline. It is kept as a constant literal, matching the
    # ORDER BY <distance> LIMIT form shown in the vss HNSW examples.
    embedding = model.encode(query).tolist()
    df = duckdb.sql(
        query=f"""
        SELECT *, array_cosine_distance({embedding_column_float}, {embedding}::FLOAT[{embedding_dimensions}]) as distance
        FROM {table_name}
        ORDER BY distance
        LIMIT {k};
        """
    ).to_df()
    # Raw vectors are large and not useful to display or return to clients.
    return df.drop(columns=[embedding_column, embedding_column_float])
def rerank(
    query: str,
    documents: pd.DataFrame,
    *,
    model=None,
    text_column: str = "briefSummary",
) -> pd.DataFrame:
    """Re-order search hits by cross-encoder relevance to the query.

    Rows with duplicate ``text_column`` values are dropped (the first is kept).
    Each remaining row's text is scored against ``query``, and the rows are
    sorted by that score, highest first. The input DataFrame is not modified.

    Args:
        query: Free-text search query.
        documents: Candidate rows, e.g. the output of ``similarity_search``.
        model: Object with a ``predict(pairs)`` method returning one score per
            ``[query, text]`` pair. Defaults to the module-level ``reranker``.
        text_column: Column whose text is scored against the query.

    Returns:
        A new DataFrame with a ``rank`` column holding the relevance score,
        sorted by it in descending order.
    """
    scorer = reranker if model is None else model
    # Explicit copy so adding the score column never touches the caller's frame.
    documents = documents.drop_duplicates(text_column).copy()
    if documents.empty:
        # Nothing to score: avoid calling predict() with an empty batch.
        documents["rank"] = pd.Series(dtype="float64")
        return documents
    # Missing summaries (None/NaN) cannot be tokenized; score them as empty text.
    texts = documents[text_column].fillna("").astype(str)
    documents["rank"] = scorer.predict([[query, text] for text in texts])
    return documents.sort_values(by="rank", ascending=False)
def search(query: str, k: int = 5) -> pd.DataFrame:
    """Search clinical trials by vector similarity, then re-rank the hits.

    Args:
        query: Free-text search query.
        k: Maximum number of trials to return (at least 1).

    Returns:
        Up to k trials, most relevant first, including the vector-search
        ``distance`` and the cross-encoder ``rank`` score.
    """
    k = int(k)
    # Fetch a larger candidate pool than requested. rerank() drops rows with
    # duplicate summaries, and the cross-encoder may promote hits that the
    # vector search placed just below the top k.
    candidates = similarity_search(query, k * 3)
    return rerank(query, candidates).head(k)


with gr.Blocks() as demo:
    gr.Markdown("""# RAG - Clinical Trials (clinicaltrials.gov)
Executes vector search and re-ranking on top of [clinical-trials-embeddings](https://huggingface.co/datasets/cyrilzakka/clinical-trials-embeddings).
Part of the [Therapeutics Actionability Challenge](https://sail.health/event/sail-2025/program/) Demo.""")
    query = gr.Textbox(label="Query")
    # Whole numbers only: the value ends up in a SQL LIMIT clause.
    k = gr.Slider(1, 50, value=5, step=1, label="Number of results")
    btn = gr.Button("Search")
    # Column headers come from the returned DataFrame (dataset fields plus
    # distance and rank), so none are hard-coded here.
    results = gr.Dataframe(wrap=True)
    # api_name keeps the previously exposed endpoint / MCP tool name stable.
    btn.click(
        fn=search,
        inputs=[query, k],
        outputs=[results],
        api_name="similarity_search",
    )
demo.launch(mcp_server=True)