import os
import glob
import logging
import concurrent.futures
from typing import List, Any

import dotenv
import gradio as gr
from tqdm import tqdm

from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.output_parsers import StrOutputParser
from langchain import hub
from langgraph.graph import END, StateGraph, START
from typing_extensions import TypedDict
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.callbacks import get_openai_callback

dotenv.load_dotenv()

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

if os.getenv("OPENAI_API_KEY") is None:
    raise ValueError("OPENAI_API_KEY is not set in .env file")
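

# initialize_retriever() builds the retrieval side of the pipeline: it loads
# every PDF found under Data/, splits pages into overlapping chunks, embeds
# them with OpenAI embeddings, and exposes the persistent Chroma store as an
# MMR retriever.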
def initialize_retriever():
    """Load PDFs from Data/, split them into chunks, and return an MMR retriever."""
    pdf_directory = "Data/"
    os.makedirs(pdf_directory, exist_ok=True)

    pdf_files = glob.glob(os.path.join(pdf_directory, "*.pdf"))
    logger.info(f"Found {len(pdf_files)} PDF files to process")

    if not pdf_files:
        logger.warning("No PDF files found in the Data directory. Please add PDF files.")
        return None

    def process_pdf(file_path: str) -> List[Any]:
        try:
            loader = PyPDFLoader(file_path)
            docs = loader.load()

            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1024,
                chunk_overlap=100
            )
            split_docs = []
            for doc in docs:
                for chunk_index, chunk in enumerate(text_splitter.split_text(doc.page_content)):
                    new_doc = doc.__class__(
                        page_content=chunk,
                        metadata=doc.metadata.copy()
                    )
                    new_doc.metadata["source_file"] = os.path.basename(file_path)
                    new_doc.metadata["file_path"] = file_path
                    new_doc.metadata["chunk_size"] = len(chunk)
                    new_doc.metadata["chunk_id"] = (
                        f"{os.path.basename(file_path)}-page-{doc.metadata.get('page', '0')}-chunk-{chunk_index}"
                    )
                    if "page" in doc.metadata:
                        new_doc.metadata["page_num"] = doc.metadata["page"]
                    split_docs.append(new_doc)
            logger.info(f"Processed {file_path}: extracted {len(split_docs)} chunks")
            return split_docs
        except Exception as e:
            logger.error(f"Error processing {file_path}: {str(e)}")
            return []

    all_doc_splits = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=min(16, len(pdf_files))) as executor:
        for result in tqdm(executor.map(process_pdf, pdf_files), total=len(pdf_files), desc="Processing PDF files"):
            all_doc_splits.extend(result)

    logger.info(f"Total chunks extracted: {len(all_doc_splits)}")

    if not all_doc_splits:
        return None

    vectorstore = Chroma.from_documents(
        documents=all_doc_splits,
        collection_name="rag-chroma",
        embedding=OpenAIEmbeddings(
            model="text-embedding-3-large",
            dimensions=3072,
            timeout=120,
        ),
        persist_directory="./chroma_rag_cache"
    )

    # MMR retrieval: fetch a broad candidate pool (fetch_k), then keep a
    # smaller, diversity-weighted top-k (lambda_mult near 0 favors diversity).
    retriever = vectorstore.as_retriever(
        search_type="mmr",
        search_kwargs={
            "k": 40,
            "fetch_k": 200,
            "lambda_mult": 0.2,
            "filter": None,
        }
    )

    return retriever
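

# setup_components() wires up the LLM-side pieces for a given chat model:
# structured-output graders for document relevance, hallucination, and answer
# quality, a question re-writer, and the final RAG generation chain.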
def setup_components(retriever, model_choice):
    """Create the graders, question re-writer, and RAG chain for the chosen model."""

    class GradeDocuments(BaseModel):
        """Binary score for relevance check on retrieved documents."""
        binary_score: str = Field(
            description="Documents are relevant to the question, 'yes' or 'no'"
        )

    class GradeHallucinations(BaseModel):
        """Binary score for hallucination present in generation answer."""
        binary_score: str = Field(
            description="Answer is grounded in the facts, 'yes' or 'no'"
        )

    class GradeAnswer(BaseModel):
        """Binary score to assess whether the answer addresses the question."""
        binary_score: str = Field(
            description="Answer addresses the question, 'yes' or 'no'"
        )

    llm = ChatOpenAI(model=model_choice, temperature=0)
    doc_grader = llm.with_structured_output(GradeDocuments)
    hallucination_grader_llm = llm.with_structured_output(GradeHallucinations)
    answer_grader_llm = llm.with_structured_output(GradeAnswer)

    # Document relevance grader
    system_doc = """You are a grader assessing the relevance of a retrieved document to a user question. \n
        It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
        Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question."""
    grade_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_doc),
            ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
        ]
    )
    retrieval_grader = grade_prompt | doc_grader

    # Hallucination grader
    system_hallucination = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
        Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
    hallucination_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_hallucination),
            ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
        ]
    )
    hallucination_grader = hallucination_prompt | hallucination_grader_llm

    # Answer grader
    system_answer = """You are a grader assessing whether an answer addresses / resolves a question. \n
        Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""
    answer_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_answer),
            ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
        ]
    )
    answer_grader = answer_prompt | answer_grader_llm

    # Question re-writer
    system_rewrite = """You are a question re-writer that converts an input question to a better version that is optimized \n
        for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
    re_write_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_rewrite),
            (
                "human",
                "Here is the initial question: \n\n {question} \n Formulate an improved question.",
            ),
        ]
    )
    question_rewriter = re_write_prompt | llm | StrOutputParser()

    # Generation chain
    prompt = hub.pull("rlm/rag-prompt")
    rag_chain = prompt | llm | StrOutputParser()

    return {
        "retriever": retriever,
        "retrieval_grader": retrieval_grader,
        "hallucination_grader": hallucination_grader,
        "answer_grader": answer_grader,
        "question_rewriter": question_rewriter,
        "rag_chain": rag_chain
    }
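

# build_rag_graph() assembles the Self-RAG control flow as a LangGraph state
# machine: retrieve -> grade documents -> (generate | rewrite query), with the
# generation re-checked for grounding and usefulness before the graph ends.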
def build_rag_graph(components):
    """Compile the Self-RAG workflow as a LangGraph state machine."""

    class GraphState(TypedDict):
        """Represents the state of our graph."""
        question: str
        generation: str
        documents: List[str]

    def retrieve(state):
        """Retrieve documents."""
        question = state["question"]
        documents = components["retriever"].get_relevant_documents(question)
        return {"documents": documents, "question": question}

    def generate(state):
        """Generate an answer from the retrieved documents."""
        question = state["question"]
        documents = state["documents"]
        generation = components["rag_chain"].invoke({"context": documents, "question": question})
        return {"documents": documents, "question": question, "generation": generation}

    def grade_documents(state):
        """Keep only the retrieved documents graded as relevant to the question."""
        question = state["question"]
        documents = state["documents"]

        filtered_docs = []
        for d in documents:
            score = components["retrieval_grader"].invoke(
                {"question": question, "document": d.page_content}
            )
            grade = score.binary_score
            if grade == "yes":
                filtered_docs.append(d)

        return {"documents": filtered_docs, "question": question}

    def transform_query(state):
        """Re-write the query to produce a better question for retrieval."""
        question = state["question"]
        documents = state["documents"]
        better_question = components["question_rewriter"].invoke({"question": question})
        return {"documents": documents, "question": better_question}

    def decide_to_generate(state):
        """Decide whether to generate an answer or re-write the question."""
        filtered_documents = state["documents"]

        if not filtered_documents:
            # Every document was filtered out: re-write the question and retry retrieval.
            return "transform_query"
        else:
            return "generate"

    def grade_generation_v_documents_and_question(state):
        """Check that the generation is grounded in the documents and answers the question."""
        question = state["question"]
        documents = state["documents"]
        generation = state["generation"]

        score = components["hallucination_grader"].invoke(
            {"documents": documents, "generation": generation}
        )
        grade = score.binary_score

        if grade == "yes":
            # Grounded in the documents; now check whether it actually answers the question.
            score = components["answer_grader"].invoke({"question": question, "generation": generation})
            grade = score.binary_score
            if grade == "yes":
                return "useful"
            else:
                return "not useful"
        else:
            return "not supported"

    workflow = StateGraph(GraphState)

    workflow.add_node("retrieve", retrieve)
    workflow.add_node("grade_documents", grade_documents)
    workflow.add_node("generate", generate)
    workflow.add_node("transform_query", transform_query)

    workflow.add_edge(START, "retrieve")
    workflow.add_edge("retrieve", "grade_documents")
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {
            "transform_query": "transform_query",
            "generate": "generate",
        },
    )
    workflow.add_edge("transform_query", "retrieve")
    workflow.add_conditional_edges(
        "generate",
        grade_generation_v_documents_and_question,
        {
            "not supported": "generate",
            "useful": END,
            "not useful": "transform_query",
        },
    )

    return workflow.compile()
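

# Module-level state: the retriever, components, and compiled graph are built
# once at import time so the Gradio handler can reuse them across requests.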
retriever = None
rag_app = None
components = None
current_model_choice = "gpt-4.1"

retriever = initialize_retriever()
if retriever is not None:
    components = setup_components(retriever, current_model_choice)
    rag_app = build_rag_graph(components)
else:
    logger.error("No retriever could be initialized. Please add PDF files to the Data directory.")
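

# process_query() is the Gradio callback: it rebuilds the components if the
# selected model changed, streams the graph step by step for logging, and
# reports OpenAI token usage via get_openai_callback().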
def process_query(question, display_logs=False, model_choice="gpt-4.1"):
    """Run a question through the RAG graph and return (answer, logs, token usage)."""
    logs = []
    answer = ""
    token_usage = {}
    try:
        global retriever, rag_app, components, current_model_choice
        if retriever is None:
            logs.append("Error: No PDF files found. Please add PDF files to the Data directory and restart the app.")
            return "Error: No PDF files found. Please add PDF files to the Data directory.", "\n".join(logs), token_usage

        # Rebuild the LLM components only when the selected model changes.
        if model_choice != current_model_choice:
            logs.append(f"Switching model to {model_choice} ...")
            components = setup_components(retriever, model_choice)
            rag_app = build_rag_graph(components)
            current_model_choice = model_choice

        logs.append("Processing query: " + question)
        logs.append(f"Using model: {model_choice}")
        logs.append("Starting RAG pipeline...")
        final_output = None

        with get_openai_callback() as cb:
            for i, output in enumerate(rag_app.stream({"question": question})):
                step_info = f"Step {i+1}: "
                if 'retrieve' in output:
                    step_info += f"Retrieved {len(output['retrieve']['documents'])} documents"
                elif 'grade_documents' in output:
                    step_info += f"Graded documents, {len(output['grade_documents']['documents'])} deemed relevant"
                elif 'transform_query' in output:
                    step_info += f"Transformed query to: {output['transform_query']['question']}"
                elif 'generate' in output:
                    step_info += "Generated answer"
                    final_output = output
                logs.append(step_info)

        token_usage = {
            "total_tokens": cb.total_tokens,
            "prompt_tokens": cb.prompt_tokens,
            "completion_tokens": cb.completion_tokens,
            "total_cost": cb.total_cost
        }
        logs.append(f"Token usage: {token_usage}")

        if final_output and 'generate' in final_output:
            answer = final_output['generate']['generation']
            logs.append("Final answer generated successfully")
        else:
            answer = "No answer could be generated. Please try rephrasing your question."
            logs.append("Failed to generate answer")
    except Exception as e:
        logs.append(f"Error: {str(e)}")
        answer = f"An error occurred: {str(e)}"
    return answer, "\n".join(logs) if display_logs else "", token_usage
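

# Gradio UI: a question box, model selector, optional debug log view, and a
# token-usage panel wired to process_query().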
with gr.Blocks(title="Self-RAG Document Assistant", theme=gr.themes.Base()) as demo:
    with gr.Row():
        gr.Markdown("# Self-RAG Document Assistant")

    with gr.Row():
        gr.Markdown("""This application uses a Self-RAG (Retrieval Augmented Generation) system to
provide accurate answers by:
1. Retrieving relevant documents from your PDF database
2. Grading document relevance to your question
3. Generating answers grounded in these documents
4. Self-checking for hallucinations and question addressing""")

    with gr.Row():
        with gr.Column(scale=3):
            query_input = gr.Textbox(
                label="Your Question",
                placeholder="Ask a question about your documents...",
                lines=4
            )

        with gr.Column(scale=1):
            model_choice_input = gr.Dropdown(
                label="Model",
                choices=["gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano"],
                value="gpt-4.1"
            )
            show_logs = gr.Checkbox(label="Show Debugging Logs", value=False)
            submit_btn = gr.Button("Submit", variant="primary")

    with gr.Row():
        with gr.Column():
            answer_output = gr.Textbox(
                label="Answer",
                lines=10,
                placeholder="Your answer will appear here...",
            )

    with gr.Row():
        logs_output = gr.Textbox(
            label="Process Logs",
            lines=15,
            visible=False
        )

    with gr.Row():
        token_usage_output = gr.JSON(
            label="Token Usage Statistics",
            visible=True
        )

    submit_btn.click(
        fn=process_query,
        inputs=[query_input, show_logs, model_choice_input],
        outputs=[answer_output, logs_output, token_usage_output]
    )

    show_logs.change(
        fn=lambda x: gr.update(visible=x),
        inputs=[show_logs],
        outputs=[logs_output]
    )
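

# Run this file directly to start the app locally; passing share=True to
# demo.launch() would expose a temporary public Gradio link instead.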
if __name__ == "__main__":
    demo.launch(share=False)