import os

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from dataset_loader import load_medmcqa_subset
from tqdm import tqdm  # Progress bar for better visibility during indexing
from dotenv import load_dotenv

# Pull GOOGLE_API_KEY from a local .env file (if present) and fail fast
# with a clear message when it is missing.
load_dotenv()
if not os.getenv("GOOGLE_API_KEY"):
    raise RuntimeError("GOOGLE_API_KEY is not set; add it to your environment or .env file.")


def retrieve_relevant_docs():
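    """Load the saved FAISS index and return it as a top-3 similarity retriever."""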
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    db = FAISS.load_local("data/medmcqa_index", embeddings, allow_dangerous_deserialization=True)
    return db.as_retriever(search_type="similarity", search_kwargs={"k": 3})


def create_vector_store():
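    """Embed the MedMCQA subset with MiniLM and save a FAISS index to data/medmcqa_index."""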
    examples = load_medmcqa_subset()
    # Format each entry into a LangChain Document (tqdm shows progress;
    # the actual embedding happens later, inside FAISS.from_documents)
    docs = [
        Document(
            page_content=e["formatted"],
            metadata={"question": e["question"]},
        )
        for e in tqdm(examples, desc="Formatting MedMCQA examples")
    ]
    # Create embedding model
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    # Build and save FAISS index
    db = FAISS.from_documents(docs, embeddings)
    os.makedirs("data", exist_ok=True)
    db.save_local("data/medmcqa_index")
    print("✅ Vector DB saved at data/medmcqa_index")


def load_vector_store():
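    """Reload the persisted FAISS index using the same embedding model it was built with."""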
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return FAISS.load_local("data/medmcqa_index", embeddings, allow_dangerous_deserialization=True)


if __name__ == "__main__":
    from langchain_core.prompts import PromptTemplate
    from langchain_core.output_parsers import StrOutputParser
    from langchain_google_genai import ChatGoogleGenerativeAI

    # Load DB
    db = load_vector_store()
    query = "What is the treatment for asthma?"
    docs = db.similarity_search(query, k=4)

    # Prompt Template
    prompt_template = PromptTemplate.from_template(
        """
You are a helpful medical assistant. Use only the dataset context below to answer.

Context:
{context}

Question:
{question}

If you are unsure, say 'Sorry, I couldn't find an answer based on the dataset.'
"""
    )

    # LLM
    llm = ChatGoogleGenerativeAI(model="models/gemini-1.5-flash", temperature=0.3)
    chain = prompt_template | llm | StrOutputParser()
    answer = chain.invoke({"context": "\n\n".join(d.page_content for d in docs), "question": query})
    print("\n\n🧠 Answer:\n", answer)
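
    # --- Optional retriever-based variant ------------------------------------
    # A minimal LCEL sketch, shown for illustration rather than as the app's
    # original wiring: it reuses retrieve_relevant_docs() from above to feed
    # the same prompt and LLM, and assumes the index at data/medmcqa_index
    # already exists on disk.
    from langchain_core.runnables import RunnablePassthrough

    retriever = retrieve_relevant_docs()
    rag_chain = (
        {
            # Fetch the top-k Documents, then join their text into one context string
            "context": retriever | (lambda ds: "\n\n".join(d.page_content for d in ds)),
            "question": RunnablePassthrough(),
        }
        | prompt_template
        | llm
        | StrOutputParser()
    )
    print("\nRetriever-chain answer:\n", rag_chain.invoke(query))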