File size: 2,548 Bytes
384acb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f240b4f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import duckdb
import gradio as gr
from gradio_client import Client
from sentence_transformers import CrossEncoder
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
from huggingface_hub import get_token
import pandas as pd


# Query encoder: a static (model2vec) embedding model wrapped as a
# SentenceTransformer. Its output dimension is used below to cast the stored
# dataset embeddings, so it must match the dimension they were built with.
static_embedding = StaticEmbedding.from_model2vec("minishlab/potion-base-8M")
model = SentenceTransformer(modules=[static_embedding])
# Re-ranker: this must be a checkpoint trained as a cross-encoder. Bi-encoder
# checkpoints such as "sentence-transformers/all-MiniLM-L12-v2" get a freshly
# initialised (random) classification head when loaded through CrossEncoder,
# so their scores carry no relevance signal.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
embedding_dimensions = model.get_sentence_embedding_dimension()
dataset_name = "cyrilzakka/clinical-trials-embeddings"
embedding_column = "embedding"
# Fixed-size FLOAT[N] copy of the embedding column; the vss HNSW index and the
# FLOAT[N] query literal in similarity_search operate on this column.
embedding_column_float = f"{embedding_column}_float"
table_name = "clinical_trials"

# Load every parquet shard of the dataset into a table on DuckDB's default
# connection, add the FLOAT[N] embedding column, and build a cosine-metric
# HNSW index on it (vss extension) to accelerate nearest-neighbour queries.
duckdb.sql(query=f"""
    INSTALL vss;
    LOAD vss;
    CREATE TABLE {table_name} AS 
    SELECT *, {embedding_column}::float[{embedding_dimensions}] as {embedding_column_float} 
    FROM 'hf://datasets/{dataset_name}/**/*.parquet';
    CREATE INDEX my_hnsw_index ON {table_name} USING HNSW ({embedding_column_float}) WITH (metric = 'cosine');
""")

def similarity_search(query: str, k: int = 5):
    """Return the ``k`` clinical trials whose embeddings are closest to ``query``.

    The query is embedded with the static embedding model and compared with
    the stored trial embeddings by cosine distance (smaller = more similar).

    Args:
        query: Free-text search query.
        k: Number of results to return. Fractional values (e.g. from a UI
            slider) are truncated to an integer.

    Returns:
        A DataFrame with the matching trial rows (raw embedding columns
        removed) plus a ``distance`` column, ordered by ascending distance.

    Raises:
        ValueError: If ``k`` is negative or is a string that is not an integer.
    """
    # Coerce and validate before interpolating into SQL: Gradio sliders may
    # deliver floats and MCP clients can send arbitrary values, so only a
    # checked int may reach the LIMIT clause.
    k = int(k)
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    embedding = model.encode(query).tolist()
    # The query vector is produced by the model (not user text) and is inlined
    # as a constant FLOAT[N] literal. Presumably this keeps the ORDER BY ...
    # LIMIT shape that the HNSW index can serve; verify before switching to a
    # bound parameter.
    df = duckdb.sql(
        query=f"""
        SELECT *, array_cosine_distance({embedding_column_float}, {embedding}::FLOAT[{embedding_dimensions}]) as distance 
        FROM {table_name}
        ORDER BY distance 
        LIMIT {k};
    """
    ).to_df()
    # Raw embedding vectors are bulky and not useful to readers of the results.
    return df.drop(columns=[embedding_column, embedding_column_float])

def rerank(
    query: str, documents: pd.DataFrame, text_column: str = "briefSummary"
) -> pd.DataFrame:
    """Re-order ``documents`` by cross-encoder relevance to ``query``.

    Args:
        query: The search query the documents are scored against.
        documents: Candidate rows, e.g. the output of ``similarity_search``.
            The input frame is not modified.
        text_column: Column whose text is paired with ``query`` and scored by
            the cross-encoder. Defaults to ``"briefSummary"``.

    Returns:
        A new DataFrame with rows sharing the same ``text_column`` value
        collapsed to one, plus a ``rank`` column holding the cross-encoder
        score (higher = more relevant), sorted by that score descending.
        An empty input yields an empty frame that still has a ``rank`` column.
    """
    # Work on a copy and drop duplicate texts so the same summary is not
    # scored (and shown) twice.
    ranked = documents.drop_duplicates(text_column).copy()
    if ranked.empty:
        # CrossEncoder.predict cannot handle an empty batch; skip the model
        # call but keep the output schema consistent.
        ranked["rank"] = pd.Series(dtype="float64")
        return ranked
    # Missing texts (NaN/None) would break tokenisation; score them as "".
    texts = ranked[text_column].fillna("").astype(str)
    ranked["rank"] = reranker.predict([[query, text] for text in texts])
    return ranked.sort_values(by="rank", ascending=False)

with gr.Blocks() as demo:
    gr.Markdown("""# RAG - Clinical Trials (clinicaltrials.gov)
                Executes vector search and re-ranking top of [clinical-trials-embeddings](https://huggingface.co/datasets/cyrilzakka/clinical-trials-embeddings).
                
                Part of the [Therapeutics Actionability Challenge](https://sail.health/event/sail-2025/program/) Demo.""")
    query = gr.Textbox(label="Query")
    # step=1 keeps the slider on whole numbers. Without it, Gradio derives a
    # fractional step from the range, and non-integer values would reach the
    # SQL LIMIT clause in similarity_search.
    k = gr.Slider(1, 50, value=5, step=1, label="Number of results")
    btn = gr.Button("Search")
    # These headers are only placeholders shown before the first search; the
    # DataFrame returned by similarity_search replaces them with its columns.
    results = gr.Dataframe(headers=["url", "chunk", "distance"], wrap=True)
    # NOTE(review): only similarity_search is wired to the button (and thus
    # exposed as the MCP tool); rerank() is defined in this file but is not
    # applied to these results, despite the header text.
    btn.click(fn=similarity_search, inputs=[query, k], outputs=[results])


demo.launch(mcp_server=True)