Spaces:
Sleeping
Sleeping
import os | |
import dotenv | |
import logging | |
import gradio as gr | |
import glob | |
import concurrent.futures | |
from typing import List, Any | |
from tqdm import tqdm | |
# LangChain imports | |
from langchain_community.document_loaders import PyPDFLoader | |
from langchain_community.vectorstores import Chroma | |
from langchain_openai import OpenAIEmbeddings, ChatOpenAI | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.pydantic_v1 import BaseModel, Field | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain import hub | |
from langgraph.graph import END, StateGraph, START | |
from typing_extensions import TypedDict | |
from langchain_text_splitters import RecursiveCharacterTextSplitter | |
from langchain_community.callbacks import get_openai_callback | |
# Load environment variables | |
dotenv.load_dotenv() | |
# Set up logging | |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
logger = logging.getLogger(__name__) | |
# Check if OpenAI API key is set | |
if os.getenv("OPENAI_API_KEY") is None: | |
raise ValueError("OPENAI_API_KEY is not set in .env file") | |
# Initialize Retriever for all Markdown files in /MarkdownOutput | |
def initialize_retriever(): | |
from langchain_community.document_loaders import UnstructuredMarkdownLoader | |
from langchain_text_splitters import RecursiveCharacterTextSplitter | |
# Find all markdown files in /MarkdownOutput | |
markdown_files = glob.glob("./MarkdownOutput/**/*.md", recursive=True) | |
logger.info(f"Found {len(markdown_files)} markdown files in ./MarkdownOutput.") | |
# Load and split all markdown documents | |
all_doc_splits = [] | |
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) | |
for idx, md_path in enumerate(markdown_files, 1): | |
logger.info(f"Loading and splitting file {idx}/{len(markdown_files)}: {md_path}") | |
loader = UnstructuredMarkdownLoader(md_path) | |
docs = loader.load() | |
splits = text_splitter.split_documents(docs) | |
all_doc_splits.extend(splits) | |
logger.info(f"File {md_path} loaded and split into {len(splits)} chunks.") | |
logger.info(f"Total document splits: {len(all_doc_splits)}. Creating vector store...") | |
# Create vector store | |
vectorstore = Chroma.from_documents( | |
documents=all_doc_splits, | |
collection_name="rag-chroma", | |
embedding=OpenAIEmbeddings( | |
model="text-embedding-3-large", | |
dimensions=3072, | |
timeout=120, | |
), | |
persist_directory="./chroma_rag_cache" | |
) | |
logger.info("Vector store created and persisted to ./chroma_rag_cache.") | |
# Configure retriever | |
retriever = vectorstore.as_retriever( | |
search_type="mmr", | |
search_kwargs={ | |
"k": 40, | |
"fetch_k": 200, | |
"lambda_mult": 0.2, | |
"filter": None, | |
"score_threshold": 0.7, | |
} | |
) | |
logger.info("Retriever configured and ready to use.") | |
return retriever | |
# Define graders and components | |
def setup_components(retriever, model_choice): | |
# Data models for grading | |
class GradeDocuments(BaseModel): | |
"""Binary score for relevance check on retrieved documents.""" | |
binary_score: str = Field( | |
description="Documents are relevant to the question, 'yes' or 'no'" | |
) | |
class GradeHallucinations(BaseModel): | |
"""Binary score for hallucination present in generation answer.""" | |
binary_score: str = Field( | |
description="Answer is grounded in the facts, 'yes' or 'no'" | |
) | |
class GradeAnswer(BaseModel): | |
"""Binary score to assess answer addresses question.""" | |
binary_score: str = Field( | |
description="Answer addresses the question, 'yes' or 'no'" | |
) | |
# LLM models | |
llm = ChatOpenAI(model=model_choice, temperature=0) | |
doc_grader = llm.with_structured_output(GradeDocuments) | |
hallucination_grader_llm = llm.with_structured_output(GradeHallucinations) | |
answer_grader_llm = llm.with_structured_output(GradeAnswer) | |
# Prompts | |
# Document grading prompt | |
system_doc = """You are a grader assessing relevance of a retrieved document to a user question. \n | |
It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n | |
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n | |
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.""" | |
grade_prompt = ChatPromptTemplate.from_messages( | |
[ | |
("system", system_doc), | |
("human", "Retrieved document: \n\n {document} \n\n User question: {question}"), | |
] | |
) | |
retrieval_grader = grade_prompt | doc_grader | |
# Hallucination grading prompt | |
system_hallucination = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n | |
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts.""" | |
hallucination_prompt = ChatPromptTemplate.from_messages( | |
[ | |
("system", system_hallucination), | |
("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"), | |
] | |
) | |
hallucination_grader = hallucination_prompt | hallucination_grader_llm | |
# Answer grading prompt | |
system_answer = """You are a grader assessing whether an answer addresses / resolves a question \n | |
Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question.""" | |
answer_prompt = ChatPromptTemplate.from_messages( | |
[ | |
("system", system_answer), | |
("human", "User question: \n\n {question} \n\n LLM generation: {generation}"), | |
] | |
) | |
answer_grader = answer_prompt | answer_grader_llm | |
# Question rewriter prompt | |
system_rewrite = """You a question re-writer that converts an input question to a better version that is optimized \n | |
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning.""" | |
re_write_prompt = ChatPromptTemplate.from_messages( | |
[ | |
("system", system_rewrite), | |
( | |
"human", | |
"Here is the initial question: \n\n {question} \n Formulate an improved question.", | |
), | |
] | |
) | |
question_rewriter = re_write_prompt | llm | StrOutputParser() | |
# RAG generation prompt and chain | |
prompt = hub.pull("rlm/rag-prompt") | |
rag_chain = prompt | llm | StrOutputParser() | |
return { | |
"retriever": retriever, | |
"retrieval_grader": retrieval_grader, | |
"hallucination_grader": hallucination_grader, | |
"answer_grader": answer_grader, | |
"question_rewriter": question_rewriter, | |
"rag_chain": rag_chain | |
} | |
# Build the RAG graph | |
def build_rag_graph(components): | |
# Define graph state | |
class GraphState(TypedDict): | |
"""Represents the state of our graph.""" | |
question: str | |
generation: str | |
documents: List[str] | |
# Node functions | |
def retrieve(state): | |
"""Retrieve documents""" | |
question = state["question"] | |
documents = components["retriever"].get_relevant_documents(question) | |
return {"documents": documents, "question": question} | |
def generate(state): | |
"""Generate answer""" | |
question = state["question"] | |
documents = state["documents"] | |
generation = components["rag_chain"].invoke({"context": documents, "question": question}) | |
return {"documents": documents, "question": question, "generation": generation} | |
def grade_documents(state): | |
"""Determines whether the retrieved documents are relevant to the question.""" | |
question = state["question"] | |
documents = state["documents"] | |
# Score each doc | |
filtered_docs = [] | |
for d in documents: | |
score = components["retrieval_grader"].invoke( | |
{"question": question, "document": d.page_content} | |
) | |
grade = score.binary_score | |
if grade == "yes": | |
filtered_docs.append(d) | |
return {"documents": filtered_docs, "question": question} | |
def transform_query(state): | |
"""Transform the query to produce a better question.""" | |
question = state["question"] | |
documents = state["documents"] | |
better_question = components["question_rewriter"].invoke({"question": question}) | |
return {"documents": documents, "question": better_question} | |
# Edge functions | |
def decide_to_generate(state): | |
"""Determines whether to generate an answer, or re-generate a question.""" | |
filtered_documents = state["documents"] | |
if not filtered_documents: | |
# All documents have been filtered out | |
return "transform_query" | |
else: | |
# We have relevant documents, so generate answer | |
return "generate" | |
def grade_generation_v_documents_and_question(state): | |
"""Determines whether the generation is grounded in the document and answers question.""" | |
question = state["question"] | |
documents = state["documents"] | |
generation = state["generation"] | |
score = components["hallucination_grader"].invoke( | |
{"documents": documents, "generation": generation} | |
) | |
grade = score.binary_score | |
# Check hallucination | |
if grade == "yes": | |
# Check question-answering | |
score = components["answer_grader"].invoke({"question": question, "generation": generation}) | |
grade = score.binary_score | |
if grade == "yes": | |
return "useful" | |
else: | |
return "not useful" | |
else: | |
return "not supported" | |
# Build the graph | |
workflow = StateGraph(GraphState) | |
# Add nodes | |
workflow.add_node("retrieve", retrieve) | |
workflow.add_node("grade_documents", grade_documents) | |
workflow.add_node("generate", generate) | |
workflow.add_node("transform_query", transform_query) | |
# Add edges | |
workflow.add_edge(START, "retrieve") | |
workflow.add_edge("retrieve", "grade_documents") | |
workflow.add_conditional_edges( | |
"grade_documents", | |
decide_to_generate, | |
{ | |
"transform_query": "transform_query", | |
"generate": "generate", | |
}, | |
) | |
workflow.add_edge("transform_query", "retrieve") | |
workflow.add_conditional_edges( | |
"generate", | |
grade_generation_v_documents_and_question, | |
{ | |
"not supported": "generate", | |
"useful": END, | |
"not useful": "transform_query", | |
}, | |
) | |
# Compile the graph | |
return workflow.compile() | |
# Initialize global variables | |
retriever = None | |
rag_app = None | |
components = None | |
current_model_choice = "gpt-4.1" # Default | |
# Run PDF processing and RAG setup ONCE at startup, with default model | |
retriever = initialize_retriever() | |
if retriever is not None: | |
components = setup_components(retriever, current_model_choice) | |
rag_app = build_rag_graph(components) | |
else: | |
logger.error("No retriever could be initialized. Please add PDF files to the Data directory.") | |
# Processing function for Gradio | |
def process_query(question, display_logs=False, model_choice="gpt-4.1"): | |
logs = [] | |
answer = "" | |
token_usage = {} | |
try: | |
global retriever, rag_app, components, current_model_choice | |
if retriever is None: | |
logs.append("Error: No PDF files found. Please add PDF files to the Data directory and restart the app.") | |
return "Error: No PDF files found. Please add PDF files to the Data directory.", "\n".join(logs), token_usage | |
# If model_choice changed, re-initialize components and rag_app | |
if model_choice != current_model_choice: | |
logs.append(f"Switching model to {model_choice} ...") | |
components = setup_components(retriever, model_choice) | |
rag_app = build_rag_graph(components) | |
current_model_choice = model_choice | |
logs.append("Processing query: " + question) | |
logs.append(f"Using model: {model_choice}") | |
logs.append("Starting RAG pipeline...") | |
final_output = None | |
with get_openai_callback() as cb: | |
for i, output in enumerate(rag_app.stream({"question": question})): | |
step_info = f"Step {i+1}: " | |
if 'retrieve' in output: | |
step_info += f"Retrieved {len(output['retrieve']['documents'])} documents" | |
elif 'grade_documents' in output: | |
step_info += f"Graded documents, {len(output['grade_documents']['documents'])} deemed relevant" | |
elif 'transform_query' in output: | |
step_info += f"Transformed query to: {output['transform_query']['question']}" | |
elif 'generate' in output: | |
step_info += "Generated answer" | |
final_output = output | |
logs.append(step_info) | |
# Store token usage information | |
token_usage = { | |
"total_tokens": cb.total_tokens, | |
"prompt_tokens": cb.prompt_tokens, | |
"completion_tokens": cb.completion_tokens, | |
"total_cost": cb.total_cost | |
} | |
logs.append(f"Token usage: {token_usage}") | |
if final_output and 'generate' in final_output: | |
answer = final_output['generate']['generation'] | |
logs.append("Final answer generated successfully") | |
else: | |
answer = "No answer could be generated. Please try rephrasing your question." | |
logs.append("Failed to generate answer") | |
except Exception as e: | |
logs.append(f"Error: {str(e)}") | |
answer = f"An error occurred: {str(e)}" | |
return answer, "\n".join(logs) if display_logs else "", token_usage | |
# Create Gradio interface | |
with gr.Blocks(title="Self-RAG Document Assistant", theme=gr.themes.Base()) as demo: | |
with gr.Row(): | |
gr.Markdown("# Self-RAG Document Assistant") | |
with gr.Row(): | |
gr.Markdown("""This application uses a Self-RAG (Retrieval Augmented Generation) system to | |
provide accurate answers by: | |
1. Retrieving relevant documents from your PDF database | |
2. Grading document relevance to your question | |
3. Generating answers grounded in these documents | |
4. Self-checking for hallucinations and question addressing""") | |
with gr.Row(): | |
with gr.Column(scale=3): | |
query_input = gr.Textbox( | |
label="Your Question", | |
placeholder="Ask a question about your documents...", | |
lines=4 | |
) | |
with gr.Column(scale=1): | |
model_choice_input = gr.Dropdown( | |
label="Model", | |
choices=["gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano"], | |
value="gpt-4.1" | |
) | |
show_logs = gr.Checkbox(label="Show Debugging Logs", value=False) | |
submit_btn = gr.Button("Submit", variant="primary") | |
with gr.Row(): | |
with gr.Column(): | |
answer_output = gr.Textbox( | |
label="Answer", | |
lines=10, | |
placeholder="Your answer will appear here...", | |
) | |
with gr.Row(): | |
logs_output = gr.Textbox( | |
label="Process Logs", | |
lines=15, | |
visible=False | |
) | |
with gr.Row(): | |
token_usage_output = gr.JSON( | |
label="Token Usage Statistics", | |
visible=True | |
) | |
# Event handlers | |
submit_btn.click( | |
fn=process_query, | |
inputs=[query_input, show_logs, model_choice_input], | |
outputs=[answer_output, logs_output, token_usage_output] | |
) | |
show_logs.change( | |
fn=lambda x: gr.update(visible=x), | |
inputs=[show_logs], | |
outputs=[logs_output] | |
) | |
# Launch the app | |
if __name__ == "__main__": | |
demo.launch(share=False) | |