# RAG_MED / app.py
import os
import dotenv
import logging
import gradio as gr
import glob
import concurrent.futures
from typing import List, Any
from tqdm import tqdm
# LangChain imports
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.output_parsers import StrOutputParser
from langchain import hub
from langgraph.graph import END, StateGraph, START
from typing_extensions import TypedDict
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.callbacks import get_openai_callback
# Load environment variables
dotenv.load_dotenv()
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Check if OpenAI API key is set
if os.getenv("OPENAI_API_KEY") is None:
    raise ValueError("OPENAI_API_KEY is not set in .env file")
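# The key can come from the .env file loaded above or from the shell environment.
# Illustrative shell command (placeholder value):
#   export OPENAI_API_KEY="sk-..."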
# Initialize Retriever
def initialize_retriever():
    # Ensure data directory exists
    pdf_directory = "Data/"
    os.makedirs(pdf_directory, exist_ok=True)

    # Discover PDF files
    pdf_files = glob.glob(os.path.join(pdf_directory, "*.pdf"))
    logger.info(f"Found {len(pdf_files)} PDF files to process")
    if not pdf_files:
        logger.warning("No PDF files found in the Data directory. Please add PDF files.")
        return None

    # Document processing function
    def process_pdf(file_path: str) -> List[Any]:
        try:
            loader = PyPDFLoader(file_path)
            docs = loader.load()  # Each doc is a page
            # Split each page into smaller chunks
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1024,  # or 512, adjust as needed
                chunk_overlap=100
            )
            split_docs = []
            for doc in docs:
                for chunk_idx, chunk in enumerate(text_splitter.split_text(doc.page_content)):
                    new_doc = doc.__class__(
                        page_content=chunk,
                        metadata=doc.metadata.copy()
                    )
                    new_doc.metadata["source_file"] = os.path.basename(file_path)
                    new_doc.metadata["file_path"] = file_path
                    new_doc.metadata["chunk_size"] = len(chunk)
                    # Include the chunk index so each chunk gets a unique id
                    new_doc.metadata["chunk_id"] = (
                        f"{os.path.basename(file_path)}-page-{doc.metadata.get('page', '0')}-chunk-{chunk_idx}"
                    )
                    if "page" in doc.metadata:
                        new_doc.metadata["page_num"] = doc.metadata["page"]
                    split_docs.append(new_doc)
            logger.info(f"Processed {file_path}: extracted {len(split_docs)} chunks")
            return split_docs
        except Exception as e:
            logger.error(f"Error processing {file_path}: {str(e)}")
            return []

    # Process PDFs and extract chunks in parallel
    all_doc_splits = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=min(16, len(pdf_files))) as executor:
        for result in tqdm(executor.map(process_pdf, pdf_files), total=len(pdf_files), desc="Processing PDF files"):
            all_doc_splits.extend(result)
    logger.info(f"Total chunks extracted: {len(all_doc_splits)}")
    if not all_doc_splits:
        return None

    # Create vector store
    vectorstore = Chroma.from_documents(
        documents=all_doc_splits,
        collection_name="rag-chroma",
        embedding=OpenAIEmbeddings(
            model="text-embedding-3-large",
            dimensions=3072,
            timeout=120,
        ),
        persist_directory="./chroma_rag_cache"
    )

    # Configure retriever (MMR search; note that score_threshold is only honored by the
    # "similarity_score_threshold" search type, so it is omitted here)
    retriever = vectorstore.as_retriever(
        search_type="mmr",
        search_kwargs={
            "k": 40,
            "fetch_k": 200,
            "lambda_mult": 0.2,
            "filter": None,
        }
    )
    return retriever
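# Illustrative sketch (assumed usage, not executed as part of the app): the object
# returned above is a standard LangChain retriever, so it can be queried directly.
# The question text below is a placeholder.
#
#   retriever = initialize_retriever()
#   if retriever is not None:
#       docs = retriever.invoke("example question about the indexed PDFs")
#       print(len(docs), docs[0].metadata.get("source_file"))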
# Define graders and components
def setup_components(retriever, model_choice):
    # Data models for grading
    class GradeDocuments(BaseModel):
        """Binary score for relevance check on retrieved documents."""
        binary_score: str = Field(
            description="Documents are relevant to the question, 'yes' or 'no'"
        )

    class GradeHallucinations(BaseModel):
        """Binary score for hallucination present in the generated answer."""
        binary_score: str = Field(
            description="Answer is grounded in the facts, 'yes' or 'no'"
        )

    class GradeAnswer(BaseModel):
        """Binary score to assess whether the answer addresses the question."""
        binary_score: str = Field(
            description="Answer addresses the question, 'yes' or 'no'"
        )

    # LLM and structured-output graders
    llm = ChatOpenAI(model=model_choice, temperature=0)
    doc_grader = llm.with_structured_output(GradeDocuments)
    hallucination_grader_llm = llm.with_structured_output(GradeHallucinations)
    answer_grader_llm = llm.with_structured_output(GradeAnswer)

    # Document grading prompt
    system_doc = """You are a grader assessing the relevance of a retrieved document to a user question. \n
    It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
    If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
    Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question."""
    grade_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_doc),
            ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
        ]
    )
    retrieval_grader = grade_prompt | doc_grader

    # Hallucination grading prompt
    system_hallucination = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
    Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
    hallucination_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_hallucination),
            ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
        ]
    )
    hallucination_grader = hallucination_prompt | hallucination_grader_llm

    # Answer grading prompt
    system_answer = """You are a grader assessing whether an answer addresses / resolves a question. \n
    Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""
    answer_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_answer),
            ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
        ]
    )
    answer_grader = answer_prompt | answer_grader_llm

    # Question rewriter prompt
    system_rewrite = """You are a question re-writer that converts an input question to a better version that is optimized \n
    for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
    re_write_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_rewrite),
            (
                "human",
                "Here is the initial question: \n\n {question} \n Formulate an improved question.",
            ),
        ]
    )
    question_rewriter = re_write_prompt | llm | StrOutputParser()

    # RAG generation prompt and chain
    prompt = hub.pull("rlm/rag-prompt")
    rag_chain = prompt | llm | StrOutputParser()

    return {
        "retriever": retriever,
        "retrieval_grader": retrieval_grader,
        "hallucination_grader": hallucination_grader,
        "answer_grader": answer_grader,
        "question_rewriter": question_rewriter,
        "rag_chain": rag_chain,
    }
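# Illustrative sketch (assumed usage): each grader is a prompt | structured-output chain,
# so it can be exercised on its own, e.g. to spot-check the document grader.
#
#   components = setup_components(retriever, "gpt-4.1")
#   verdict = components["retrieval_grader"].invoke(
#       {"question": "example question", "document": "example passage"}
#   )
#   print(verdict.binary_score)  # 'yes' or 'no'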
# Build the RAG graph
def build_rag_graph(components):
    # Define graph state
    class GraphState(TypedDict):
        """Represents the state of our graph."""
        question: str
        generation: str
        documents: List[str]

    # Node functions
    def retrieve(state):
        """Retrieve documents."""
        question = state["question"]
        documents = components["retriever"].get_relevant_documents(question)
        return {"documents": documents, "question": question}

    def generate(state):
        """Generate an answer from the retrieved documents."""
        question = state["question"]
        documents = state["documents"]
        generation = components["rag_chain"].invoke({"context": documents, "question": question})
        return {"documents": documents, "question": question, "generation": generation}

    def grade_documents(state):
        """Determine whether the retrieved documents are relevant to the question."""
        question = state["question"]
        documents = state["documents"]
        # Score each doc and keep only the relevant ones
        filtered_docs = []
        for d in documents:
            score = components["retrieval_grader"].invoke(
                {"question": question, "document": d.page_content}
            )
            grade = score.binary_score
            if grade == "yes":
                filtered_docs.append(d)
        return {"documents": filtered_docs, "question": question}

    def transform_query(state):
        """Transform the query to produce a better question."""
        question = state["question"]
        documents = state["documents"]
        better_question = components["question_rewriter"].invoke({"question": question})
        return {"documents": documents, "question": better_question}

    # Edge functions
    def decide_to_generate(state):
        """Determine whether to generate an answer or rewrite the question."""
        filtered_documents = state["documents"]
        if not filtered_documents:
            # All documents have been filtered out
            return "transform_query"
        else:
            # We have relevant documents, so generate an answer
            return "generate"

    def grade_generation_v_documents_and_question(state):
        """Determine whether the generation is grounded in the documents and answers the question."""
        question = state["question"]
        documents = state["documents"]
        generation = state["generation"]
        score = components["hallucination_grader"].invoke(
            {"documents": documents, "generation": generation}
        )
        grade = score.binary_score
        # Check hallucination
        if grade == "yes":
            # Check question-answering
            score = components["answer_grader"].invoke({"question": question, "generation": generation})
            grade = score.binary_score
            if grade == "yes":
                return "useful"
            else:
                return "not useful"
        else:
            return "not supported"

    # Build the graph
    workflow = StateGraph(GraphState)

    # Add nodes
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("grade_documents", grade_documents)
    workflow.add_node("generate", generate)
    workflow.add_node("transform_query", transform_query)

    # Add edges
    workflow.add_edge(START, "retrieve")
    workflow.add_edge("retrieve", "grade_documents")
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {
            "transform_query": "transform_query",
            "generate": "generate",
        },
    )
    workflow.add_edge("transform_query", "retrieve")
    workflow.add_conditional_edges(
        "generate",
        grade_generation_v_documents_and_question,
        {
            "not supported": "generate",
            "useful": END,
            "not useful": "transform_query",
        },
    )

    # Compile the graph
    return workflow.compile()
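# Illustrative sketch (assumed usage): the compiled graph is a LangGraph runnable, so
# .invoke() runs the full retrieve -> grade -> (rewrite) -> generate loop and returns the
# final state, while .stream() (used below in process_query) yields per-node updates.
#
#   app = build_rag_graph(components)
#   final_state = app.invoke({"question": "example question"})
#   print(final_state["generation"])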
# Initialize global variables
retriever = None
rag_app = None
components = None
current_model_choice = "gpt-4.1" # Default
# Run PDF processing and RAG setup ONCE at startup, with default model
retriever = initialize_retriever()
if retriever is not None:
    components = setup_components(retriever, current_model_choice)
    rag_app = build_rag_graph(components)
else:
    logger.error("No retriever could be initialized. Please add PDF files to the Data directory.")
# Processing function for Gradio
def process_query(question, display_logs=False, model_choice="gpt-4.1"):
    logs = []
    answer = ""
    token_usage = {}
    try:
        global retriever, rag_app, components, current_model_choice
        if retriever is None:
            logs.append("Error: No PDF files found. Please add PDF files to the Data directory and restart the app.")
            return "Error: No PDF files found. Please add PDF files to the Data directory.", "\n".join(logs), token_usage

        # If model_choice changed, re-initialize components and rag_app
        if model_choice != current_model_choice:
            logs.append(f"Switching model to {model_choice} ...")
            components = setup_components(retriever, model_choice)
            rag_app = build_rag_graph(components)
            current_model_choice = model_choice

        logs.append("Processing query: " + question)
        logs.append(f"Using model: {model_choice}")
        logs.append("Starting RAG pipeline...")

        final_output = None
        with get_openai_callback() as cb:
            for i, output in enumerate(rag_app.stream({"question": question})):
                step_info = f"Step {i+1}: "
                if 'retrieve' in output:
                    step_info += f"Retrieved {len(output['retrieve']['documents'])} documents"
                elif 'grade_documents' in output:
                    step_info += f"Graded documents, {len(output['grade_documents']['documents'])} deemed relevant"
                elif 'transform_query' in output:
                    step_info += f"Transformed query to: {output['transform_query']['question']}"
                elif 'generate' in output:
                    step_info += "Generated answer"
                    final_output = output
                logs.append(step_info)

            # Store token usage information
            token_usage = {
                "total_tokens": cb.total_tokens,
                "prompt_tokens": cb.prompt_tokens,
                "completion_tokens": cb.completion_tokens,
                "total_cost": cb.total_cost,
            }
            logs.append(f"Token usage: {token_usage}")

        if final_output and 'generate' in final_output:
            answer = final_output['generate']['generation']
            logs.append("Final answer generated successfully")
        else:
            answer = "No answer could be generated. Please try rephrasing your question."
            logs.append("Failed to generate answer")
    except Exception as e:
        logs.append(f"Error: {str(e)}")
        answer = f"An error occurred: {str(e)}"
    return answer, "\n".join(logs) if display_logs else "", token_usage
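# Illustrative sketch (assumed usage): process_query can also be called outside Gradio,
# e.g. as a quick smoke test from a Python shell (question text is a placeholder).
#
#   answer, log_text, usage = process_query("example question", display_logs=True)
#   print(answer)
#   print(usage.get("total_tokens"))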
# Create Gradio interface
with gr.Blocks(title="Self-RAG Document Assistant", theme=gr.themes.Base()) as demo:
    with gr.Row():
        gr.Markdown("# Self-RAG Document Assistant")
    with gr.Row():
        gr.Markdown("""This application uses a Self-RAG (Retrieval Augmented Generation) system to
provide accurate answers by:
1. Retrieving relevant documents from your PDF database
2. Grading document relevance to your question
3. Generating answers grounded in these documents
4. Self-checking for hallucinations and that the question is addressed""")
    with gr.Row():
        with gr.Column(scale=3):
            query_input = gr.Textbox(
                label="Your Question",
                placeholder="Ask a question about your documents...",
                lines=4
            )
        with gr.Column(scale=1):
            model_choice_input = gr.Dropdown(
                label="Model",
                choices=["gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano"],
                value="gpt-4.1"
            )
            show_logs = gr.Checkbox(label="Show Debugging Logs", value=False)
            submit_btn = gr.Button("Submit", variant="primary")
    with gr.Row():
        with gr.Column():
            answer_output = gr.Textbox(
                label="Answer",
                lines=10,
                placeholder="Your answer will appear here...",
            )
    with gr.Row():
        logs_output = gr.Textbox(
            label="Process Logs",
            lines=15,
            visible=False
        )
    with gr.Row():
        token_usage_output = gr.JSON(
            label="Token Usage Statistics",
            visible=True
        )

    # Event handlers
    submit_btn.click(
        fn=process_query,
        inputs=[query_input, show_logs, model_choice_input],
        outputs=[answer_output, logs_output, token_usage_output]
    )
    show_logs.change(
        fn=lambda x: gr.update(visible=x),
        inputs=[show_logs],
        outputs=[logs_output]
    )
# Launch the app
if __name__ == "__main__":
    demo.launch(share=False)
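# To run locally (sketch): `python app.py`, then open the Gradio UI at its default
# address, http://127.0.0.1:7860.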