import os
import dotenv
import logging
import gradio as gr
import glob
from typing import List
# LangChain imports
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.output_parsers import StrOutputParser
from langchain import hub
from langgraph.graph import END, StateGraph, START
from typing_extensions import TypedDict
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.callbacks import get_openai_callback
# Load environment variables
dotenv.load_dotenv()
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Check if OpenAI API key is set
if os.getenv("OPENAI_API_KEY") is None:
raise ValueError("OPENAI_API_KEY is not set in .env file")
# Initialize Retriever for all Markdown files in /MarkdownOutput
def initialize_retriever():
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Find all markdown files in /MarkdownOutput
markdown_files = glob.glob("./MarkdownOutput/**/*.md", recursive=True)
logger.info(f"Found {len(markdown_files)} markdown files in ./MarkdownOutput.")
# Load and split all markdown documents
all_doc_splits = []
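    # ~1000-character chunks with 200-character overlap so context carries across chunk boundaries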
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
for idx, md_path in enumerate(markdown_files, 1):
logger.info(f"Loading and splitting file {idx}/{len(markdown_files)}: {md_path}")
loader = UnstructuredMarkdownLoader(md_path)
docs = loader.load()
splits = text_splitter.split_documents(docs)
all_doc_splits.extend(splits)
logger.info(f"File {md_path} loaded and split into {len(splits)} chunks.")
logger.info(f"Total document splits: {len(all_doc_splits)}. Creating vector store...")
# Create vector store
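    # text-embedding-3-large at its full 3072 dimensions; the Chroma collection is persisted to ./chroma_rag_cache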
vectorstore = Chroma.from_documents(
documents=all_doc_splits,
collection_name="rag-chroma",
embedding=OpenAIEmbeddings(
model="text-embedding-3-large",
dimensions=3072,
timeout=120,
),
persist_directory="./chroma_rag_cache"
)
logger.info("Vector store created and persisted to ./chroma_rag_cache.")
# Configure retriever
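    # MMR (maximal marginal relevance) re-ranks for diversity: fetch_k candidates are
    # fetched, then k are kept; lambda_mult closer to 0 favors diversity over similarity.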
retriever = vectorstore.as_retriever(
search_type="mmr",
search_kwargs={
"k": 40,
"fetch_k": 200,
"lambda_mult": 0.2,
"filter": None,
"score_threshold": 0.7,
}
)
logger.info("Retriever configured and ready to use.")
return retriever
# Define graders and components
def setup_components(retriever, model_choice):
# Data models for grading
class GradeDocuments(BaseModel):
"""Binary score for relevance check on retrieved documents."""
binary_score: str = Field(
description="Documents are relevant to the question, 'yes' or 'no'"
)
class GradeHallucinations(BaseModel):
"""Binary score for hallucination present in generation answer."""
binary_score: str = Field(
description="Answer is grounded in the facts, 'yes' or 'no'"
)
class GradeAnswer(BaseModel):
"""Binary score to assess answer addresses question."""
binary_score: str = Field(
description="Answer addresses the question, 'yes' or 'no'"
)
# LLM models
llm = ChatOpenAI(model=model_choice, temperature=0)
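    # Bind each grading schema so the model returns a parsed Pydantic object
    # (binary_score) instead of free-form text.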
doc_grader = llm.with_structured_output(GradeDocuments)
hallucination_grader_llm = llm.with_structured_output(GradeHallucinations)
answer_grader_llm = llm.with_structured_output(GradeAnswer)
# Prompts
# Document grading prompt
system_doc = """You are a grader assessing relevance of a retrieved document to a user question. \n
It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
    Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question."""
grade_prompt = ChatPromptTemplate.from_messages(
[
("system", system_doc),
("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
]
)
retrieval_grader = grade_prompt | doc_grader
# Hallucination grading prompt
system_hallucination = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
hallucination_prompt = ChatPromptTemplate.from_messages(
[
("system", system_hallucination),
("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
]
)
hallucination_grader = hallucination_prompt | hallucination_grader_llm
# Answer grading prompt
system_answer = """You are a grader assessing whether an answer addresses / resolves a question \n
Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question."""
answer_prompt = ChatPromptTemplate.from_messages(
[
("system", system_answer),
("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
]
)
answer_grader = answer_prompt | answer_grader_llm
# Question rewriter prompt
system_rewrite = """You a question re-writer that converts an input question to a better version that is optimized \n
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
re_write_prompt = ChatPromptTemplate.from_messages(
[
("system", system_rewrite),
(
"human",
"Here is the initial question: \n\n {question} \n Formulate an improved question.",
),
]
)
question_rewriter = re_write_prompt | llm | StrOutputParser()
# RAG generation prompt and chain
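    # "rlm/rag-prompt" is a standard RAG prompt from LangChain Hub; it takes
    # "context" and "question" inputs and asks for a concise, context-grounded answer.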
prompt = hub.pull("rlm/rag-prompt")
rag_chain = prompt | llm | StrOutputParser()
return {
"retriever": retriever,
"retrieval_grader": retrieval_grader,
"hallucination_grader": hallucination_grader,
"answer_grader": answer_grader,
"question_rewriter": question_rewriter,
"rag_chain": rag_chain
}
# Build the RAG graph
def build_rag_graph(components):
# Define graph state
    class GraphState(TypedDict):
        """Represents the state of our graph.

        Attributes:
            question: the current (possibly rewritten) user question
            generation: the LLM-generated answer
            documents: the list of retrieved documents
        """
        question: str
        generation: str
        documents: List[str]
# Node functions
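    # Each node takes the current GraphState and returns the state keys it updates.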
def retrieve(state):
"""Retrieve documents"""
question = state["question"]
documents = components["retriever"].get_relevant_documents(question)
return {"documents": documents, "question": question}
def generate(state):
"""Generate answer"""
question = state["question"]
documents = state["documents"]
generation = components["rag_chain"].invoke({"context": documents, "question": question})
return {"documents": documents, "question": question, "generation": generation}
def grade_documents(state):
"""Determines whether the retrieved documents are relevant to the question."""
question = state["question"]
documents = state["documents"]
# Score each doc
filtered_docs = []
for d in documents:
score = components["retrieval_grader"].invoke(
{"question": question, "document": d.page_content}
)
grade = score.binary_score
if grade == "yes":
filtered_docs.append(d)
return {"documents": filtered_docs, "question": question}
def transform_query(state):
"""Transform the query to produce a better question."""
question = state["question"]
documents = state["documents"]
better_question = components["question_rewriter"].invoke({"question": question})
return {"documents": documents, "question": better_question}
# Edge functions
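    # Conditional-edge functions inspect the state and return a label that is mapped
    # to the next node when the graph is wired below.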
def decide_to_generate(state):
"""Determines whether to generate an answer, or re-generate a question."""
filtered_documents = state["documents"]
if not filtered_documents:
# All documents have been filtered out
return "transform_query"
else:
# We have relevant documents, so generate answer
return "generate"
def grade_generation_v_documents_and_question(state):
"""Determines whether the generation is grounded in the document and answers question."""
question = state["question"]
documents = state["documents"]
generation = state["generation"]
score = components["hallucination_grader"].invoke(
{"documents": documents, "generation": generation}
)
grade = score.binary_score
# Check hallucination
if grade == "yes":
# Check question-answering
score = components["answer_grader"].invoke({"question": question, "generation": generation})
grade = score.binary_score
if grade == "yes":
return "useful"
else:
return "not useful"
else:
return "not supported"
# Build the graph
workflow = StateGraph(GraphState)
# Add nodes
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
# Add edges
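    # Flow: START -> retrieve -> grade_documents -> (generate | transform_query);
    # transform_query loops back to retrieve with a rewritten question.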
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
"grade_documents",
decide_to_generate,
{
"transform_query": "transform_query",
"generate": "generate",
},
)
workflow.add_edge("transform_query", "retrieve")
workflow.add_conditional_edges(
"generate",
grade_generation_v_documents_and_question,
{
"not supported": "generate",
"useful": END,
"not useful": "transform_query",
},
)
# Compile the graph
return workflow.compile()
# Initialize global variables
retriever = None
rag_app = None
components = None
current_model_choice = "gpt-4.1" # Default
# Build the Markdown retriever and RAG graph ONCE at startup, with the default model
retriever = initialize_retriever()
if retriever is not None:
components = setup_components(retriever, current_model_choice)
rag_app = build_rag_graph(components)
else:
logger.error("No retriever could be initialized. Please add PDF files to the Data directory.")
# Processing function for Gradio
def process_query(question, display_logs=False, model_choice="gpt-4.1"):
logs = []
answer = ""
token_usage = {}
try:
global retriever, rag_app, components, current_model_choice
if retriever is None:
logs.append("Error: No PDF files found. Please add PDF files to the Data directory and restart the app.")
return "Error: No PDF files found. Please add PDF files to the Data directory.", "\n".join(logs), token_usage
# If model_choice changed, re-initialize components and rag_app
if model_choice != current_model_choice:
logs.append(f"Switching model to {model_choice} ...")
components = setup_components(retriever, model_choice)
rag_app = build_rag_graph(components)
current_model_choice = model_choice
logs.append("Processing query: " + question)
logs.append(f"Using model: {model_choice}")
logs.append("Starting RAG pipeline...")
final_output = None
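        # get_openai_callback aggregates token counts and estimated cost for every
        # OpenAI call made while the graph streams.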
with get_openai_callback() as cb:
for i, output in enumerate(rag_app.stream({"question": question})):
step_info = f"Step {i+1}: "
if 'retrieve' in output:
step_info += f"Retrieved {len(output['retrieve']['documents'])} documents"
elif 'grade_documents' in output:
step_info += f"Graded documents, {len(output['grade_documents']['documents'])} deemed relevant"
elif 'transform_query' in output:
step_info += f"Transformed query to: {output['transform_query']['question']}"
elif 'generate' in output:
step_info += "Generated answer"
final_output = output
logs.append(step_info)
# Store token usage information
token_usage = {
"total_tokens": cb.total_tokens,
"prompt_tokens": cb.prompt_tokens,
"completion_tokens": cb.completion_tokens,
"total_cost": cb.total_cost
}
logs.append(f"Token usage: {token_usage}")
if final_output and 'generate' in final_output:
answer = final_output['generate']['generation']
logs.append("Final answer generated successfully")
else:
answer = "No answer could be generated. Please try rephrasing your question."
logs.append("Failed to generate answer")
except Exception as e:
logs.append(f"Error: {str(e)}")
answer = f"An error occurred: {str(e)}"
return answer, "\n".join(logs) if display_logs else "", token_usage
# Create Gradio interface
with gr.Blocks(title="Self-RAG Document Assistant", theme=gr.themes.Base()) as demo:
with gr.Row():
gr.Markdown("# Self-RAG Document Assistant")
with gr.Row():
gr.Markdown("""This application uses a Self-RAG (Retrieval Augmented Generation) system to
provide accurate answers by:
1. Retrieving relevant documents from your PDF database
2. Grading document relevance to your question
3. Generating answers grounded in these documents
4. Self-checking for hallucinations and question addressing""")
with gr.Row():
with gr.Column(scale=3):
query_input = gr.Textbox(
label="Your Question",
placeholder="Ask a question about your documents...",
lines=4
)
with gr.Column(scale=1):
model_choice_input = gr.Dropdown(
label="Model",
choices=["gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano"],
value="gpt-4.1"
)
show_logs = gr.Checkbox(label="Show Debugging Logs", value=False)
submit_btn = gr.Button("Submit", variant="primary")
with gr.Row():
with gr.Column():
answer_output = gr.Textbox(
label="Answer",
lines=10,
placeholder="Your answer will appear here...",
)
with gr.Row():
logs_output = gr.Textbox(
label="Process Logs",
lines=15,
visible=False
)
with gr.Row():
token_usage_output = gr.JSON(
label="Token Usage Statistics",
visible=True
)
# Event handlers
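    # Submit runs the full pipeline; the checkbox toggles visibility of the log panel.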
submit_btn.click(
fn=process_query,
inputs=[query_input, show_logs, model_choice_input],
outputs=[answer_output, logs_output, token_usage_output]
)
show_logs.change(
fn=lambda x: gr.update(visible=x),
inputs=[show_logs],
outputs=[logs_output]
)
# Launch the app
if __name__ == "__main__":
demo.launch(share=False)