# Hugging Face Spaces status shown alongside this source: Runtime error
import logging
import os

import gradio as gr
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.vectorstores import Chroma
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.llms import LlamaCpp
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Hugging Face API token, read from the environment (e.g. a Space secret).
# It is deliberately not written back into os.environ: when the variable is
# unset, os.getenv() returns None and os.environ rejects non-str values with
# TypeError, which would abort the app at import time. When it is set, any
# library that reads the variable already finds it in os.environ.
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
# Paths for PDFs and model (upload these to the Hugging Face Space)
PDF_DIR = "./Data"  # Replace with the path where you upload your PDFs
MODEL_PATH = "./BioMistral-7B.Q4_K_M.gguf"  # Replace with the model's path in the Space
# Load every PDF found under PDF_DIR and split the text into overlapping
# 300-character chunks (50 characters of overlap) for retrieval.
loader = PyPDFDirectoryLoader(PDF_DIR)
docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
chunks = text_splitter.split_documents(docs)
# Fail fast with an actionable message: with no chunks the index would be
# empty (retrieval useless), and Chroma may reject an empty batch with a far
# less obvious error.
if not chunks:
    raise RuntimeError(
        f"No text could be extracted from PDFs in {PDF_DIR!r}; "
        "upload the source PDFs to that directory before starting the app."
    )
# Embed the chunks with NeuML/pubmedbert-base-embeddings and index them in a
# Chroma vector store (no persist_directory is given here).
embeddings = SentenceTransformerEmbeddings(model_name="NeuML/pubmedbert-base-embeddings")
vectorstore = Chroma.from_documents(chunks, embeddings)
# Retriever returning the 5 most similar chunks for a query.
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
# Initialize the local GGUF model through llama.cpp.
# n_ctx is set explicitly because LangChain's LlamaCpp defaults to a 512-token
# context window, which leaves little room for the prompt plus a long answer.
# NOTE(review): 2048 is assumed to match the model's training context —
# confirm for the GGUF file actually deployed.
llm = LlamaCpp(
    model_path=MODEL_PATH,
    n_ctx=2048,
    temperature=0.2,
    max_tokens=2048,  # upper bound on generated tokens
    top_p=1,
)
def _format_docs(docs):
    """Join the page text of retrieved chunks into one string for the prompt.

    Args:
        docs: Documents returned by the retriever (each exposes ``page_content``).

    Returns:
        The chunk texts separated by blank lines.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# Prompt laid out with <|context|> / <|user|> / <|assistant|> sections, each
# closed by </s>. The retrieved passages fill the {context} slot so the model
# actually answers from them.
template = """
<|context|>
You are a Medical Assistant that follows instructions and generates accurate responses based on the query and the context provided.
Please be truthful and give direct answers.

Context:
{context}
</s>
<|user|>
{query}
</s>
<|assistant|>
"""
# LlamaCpp is a plain completion LLM: PromptTemplate renders the string
# verbatim, whereas ChatPromptTemplate would wrap it in a chat message that is
# flattened with a "Human: " prefix before reaching the model.
prompt = PromptTemplate.from_template(template)
# RAG chain: the user query goes both to the retriever (whose documents are
# joined into the context string) and, unchanged, to the "query" slot; the
# filled prompt is sent to the LLM and its output returned as a plain string.
rag_chain = (
    {"context": retriever | _format_docs, "query": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
# Module logger so failures inside the UI callback keep their traceback.
logger = logging.getLogger(__name__)


def chatbot_ui(user_query):
    """Answer one question from the Gradio text box via the RAG chain.

    Args:
        user_query: Text typed by the user; ``None`` or blank input is rejected.

    Returns:
        The chain's answer; "Please enter a valid query." for empty input; or
        ``"Error: <message>"`` if the chain raises.
    """
    if not user_query or not user_query.strip():
        return "Please enter a valid query."
    try:
        return rag_chain.invoke(user_query)
    except Exception as e:  # UI boundary: report the error instead of failing the request
        # The query itself is not logged: it may contain personal medical details.
        logger.exception("RAG chain invocation failed")
        return f"Error: {str(e)}"
def _build_interface():
    """Assemble the Gradio front end: one question text box in, one answer box out."""
    example_queries = [
        "What are the symptoms of diabetes?",
        "Explain the risk factors of heart disease.",
        "How can I reduce cholesterol levels naturally?",
    ]
    query_box = gr.Textbox(label="Enter your medical query:", placeholder="Ask a medical question here...")
    answer_box = gr.Textbox(label="Chatbot Response")
    return gr.Interface(
        fn=chatbot_ui,
        inputs=query_box,
        outputs=answer_box,
        title="Medical Assistant Chatbot",
        description="A chatbot designed for heart patients, providing accurate and reliable medical information.",
        # Gradio expects one inner list per example (one value per input component).
        examples=[[question] for question in example_queries],
    )


interface = _build_interface()

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()