""" | |
Universal RAG PDF Chatbot with Improved Context Management - SYNTAX FIXED | |
This is a general-purpose RAG (Retrieval-Augmented Generation) chatbot that can work with any set of PDF documents. | |
It automatically adapts to different domains and topics based on the content of the uploaded PDFs. | |
""" | |
import hashlib | |
import numpy as np | |
import os | |
import re | |
import streamlit as st | |
import torch | |
import glob | |
from huggingface_hub import HfFolder | |
from langchain_community.document_loaders import PyPDFLoader | |
from langchain.text_splitter import CharacterTextSplitter | |
from langchain_community.vectorstores import Chroma | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from sentence_transformers import SentenceTransformer, util | |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
# ====== CONFIGURATION SECTION ====== | |
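# Key tunables:
#   CHUNK_SIZE / CHUNK_OVERLAP    - character-based splitting of the PDF text
#   SEARCH_K                      - number of chunks the retriever returns per query
#   MIN_SIMILARITY_THRESHOLD     - cosine-similarity cutoff for keeping a retrieved chunk
#   MAX_NEW_TOKENS / TEMPERATURE - generation length cap and sampling temperature
#   MAX_CONVERSATION_HISTORY     - number of recent messages folded into the prompt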
APP_TITLE = "π Educational PDF Chatbot" | |
APP_LAYOUT = "wide" | |
MODEL_NAME="Qwen/Qwen2.5-14B-Instruct" | |
#MODEL_NAME = "mistralai/Ministral-8B-Instruct-2410" | |
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2" | |
CHUNK_SIZE = 2000 | |
CHUNK_OVERLAP = 300 | |
SEARCH_K = 7 | |
MIN_SIMILARITY_THRESHOLD = 0.3 | |
MAX_NEW_TOKENS = 600 | |
TEMPERATURE = 0.1 | |
MAX_CONVERSATION_HISTORY = 6 | |
PDF_SEARCH_PATHS = [ | |
"*.pdf", | |
"Data/*.pdf", | |
"documents/*.pdf", | |
"pdfs/*.pdf" | |
] | |
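# Example layout (hypothetical filenames):
#   Data/sustainable_development_goals.pdf
#   documents/civil_society_overview.pdf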
# ====== END CONFIGURATION SECTION ====== | |
# Set environment variables | |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" | |
os.environ["HF_HUB_DISABLE_EXPERIMENTAL_WARNING"] = "1" | |
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "0" | |
# NumPy compatibility: restore the float_ alias (removed in NumPy 2.0) for libraries that still reference it
np.float_ = np.float64
# Streamlit Page Config | |
st.set_page_config(page_title=APP_TITLE, layout=APP_LAYOUT) | |
# Initialize essential session state variables only | |
if "messages" not in st.session_state: | |
st.session_state.messages = [] | |
if "conversation_id" not in st.session_state: | |
st.session_state.conversation_id = 0 | |
if "model_loaded" not in st.session_state: | |
st.session_state.model_loaded = False | |
# Hugging Face API Details | |
HF_API_KEY = st.secrets.get("HF_TOKEN", os.getenv("HF_TOKEN"))
if HF_API_KEY:
    HfFolder.save_token(HF_API_KEY)
else:
    st.error("Hugging Face API key is missing! Please set HF_TOKEN in Streamlit secrets or environment variables.")
    st.stop()
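# Cache the loaded model and tokenizer for the lifetime of the process so that
# Streamlit reruns (triggered on every user interaction) do not reload the 14B model.
@st.cache_resource(show_spinner=False)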
def load_quantized_model(): | |
"""Load model with 4-bit quantization to save memory.""" | |
try: | |
# Configure 4-bit quantization | |
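        # NF4 weights with double quantization cut weight memory to roughly a quarter
        # of fp16, which is what lets a 14B model fit on a single 24 GB GPU such as an L4.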
quantization_config = BitsAndBytesConfig( | |
load_in_4bit=True, | |
bnb_4bit_compute_dtype=torch.float16, | |
bnb_4bit_use_double_quant=True, | |
bnb_4bit_quant_type="nf4" | |
) | |
tokenizer = AutoTokenizer.from_pretrained( | |
MODEL_NAME, | |
token=HF_API_KEY, | |
trust_remote_code=True, | |
            use_fast=True,  # fast tokenizer implementation for better performance
padding_side="left", | |
) | |
        # Ensure a padding token exists (some causal LM tokenizers ship without one)
if tokenizer.pad_token is None: | |
tokenizer.pad_token = tokenizer.eos_token | |
tokenizer.pad_token_id = tokenizer.eos_token_id | |
model = AutoModelForCausalLM.from_pretrained( | |
MODEL_NAME, | |
device_map="auto", | |
torch_dtype=torch.float16, | |
            quantization_config=quantization_config,
            token=HF_API_KEY,
            low_cpu_mem_usage=True,
trust_remote_code=True, | |
) | |
st.success("β Quantized model loaded correctly") | |
return model, tokenizer | |
except Exception as e: | |
st.error(f"Error loading quantized model: {str(e)}") | |
return None, None | |
# Load model; st.cache_resource keeps it in memory across Streamlit reruns
if not st.session_state.model_loaded:
    st.info("Initializing model... This may take a few minutes on first load.")
with st.spinner("Loading model..."):
    model, tokenizer = load_quantized_model()
if model is not None:
    st.session_state.model_loaded = True
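# Cache the sentence-transformer (or its fallback) across Streamlit reruns.
@st.cache_resource(show_spinner=False)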
def load_sentence_model(): | |
"""Load sentence transformer model for similarity checking.""" | |
try: | |
st.info(f"Loading sentence transformer model: {EMBEDDING_MODEL}") | |
return SentenceTransformer(EMBEDDING_MODEL, token=HF_API_KEY) | |
except Exception as e: | |
st.warning(f"Error loading sentence model: {str(e)}") | |
st.info("Using fallback sentence model...") | |
try: | |
return SentenceTransformer("sentence-transformers/all-mpnet-base-v2", token=HF_API_KEY) | |
except Exception as e2: | |
st.error(f"Fallback model failed: {str(e2)}") | |
            # Last-resort fallback: a crude hashed bag-of-words embedder exposing the
            # same 384-dimensional interface as all-MiniLM-L6-v2.
            class SimpleEmbedder:
                def encode(self, texts, convert_to_tensor=True):
                    if isinstance(texts, str):
                        texts = [texts]
                    embeddings = []
                    for text in texts:
                        embedding = np.zeros(384)
                        for word in set(text.lower().split()):
                            # Hash each word into one of 384 buckets
                            embedding[sum(ord(c) for c in word) % 384] += 1.0
                        norm = np.linalg.norm(embedding)
                        if norm > 0:
                            embedding = embedding / norm
                        embeddings.append(embedding)
                    if convert_to_tensor:
                        return torch.tensor(np.array(embeddings), dtype=torch.float32)
                    return np.array(embeddings)
            return SimpleEmbedder()
sentence_model = load_sentence_model() | |
def clean_document_text(text): | |
"""Clean document text to remove problematic characters""" | |
if not text: | |
return text | |
# Remove Arabic and other non-Latin scripts | |
import unicodedata | |
try: | |
# Normalize text | |
text = unicodedata.normalize('NFKD', text) | |
    except Exception:
pass | |
# Remove Arabic characters specifically | |
text = re.sub(r'[\u0600-\u06FF\u0750-\u077F\u08A0-\u08FF\uFB50-\uFDFF\uFE70-\uFEFF]', '', text) | |
# Keep only ASCII and common European characters | |
text = re.sub(r'[^\x00-\x7F\u00C0-\u00FF]', '', text) | |
# Clean up whitespace | |
text = re.sub(r'\s+', ' ', text) | |
return text.strip() | |
def get_pdf_files(): | |
"""Automatically discover all PDF files using configured search paths.""" | |
pdf_files = [] | |
for search_path in PDF_SEARCH_PATHS: | |
found_files = glob.glob(search_path) | |
pdf_files.extend(found_files) | |
# Remove duplicates and sort | |
pdf_files = list(set(pdf_files)) | |
pdf_files.sort() | |
return pdf_files | |
PDF_FILES = get_pdf_files() | |
if not PDF_FILES: | |
st.error("β οΈ No PDF files found. Please upload PDF files to use this chatbot.") | |
st.info("π The app will look for PDF files in these locations:") | |
for path in PDF_SEARCH_PATHS: | |
st.info(f"- {path}") | |
st.stop() | |
else: | |
st.success(f"π Found {len(PDF_FILES)} PDF file(s): {', '.join([os.path.basename(f) for f in PDF_FILES])}") | |
def load_and_index_pdfs(): | |
"""Load and process multiple PDFs into a single vector store.""" | |
try: | |
with st.spinner("Processing PDF documents..."): | |
documents = [] | |
for pdf in PDF_FILES: | |
if os.path.exists(pdf): | |
try: | |
loader = PyPDFLoader(pdf) | |
docs = loader.load() | |
for doc in docs: | |
doc.metadata["source"] = pdf | |
if "page" in doc.metadata: | |
doc.metadata["source"] = f"{os.path.basename(pdf)} (Page {doc.metadata['page']+1})" | |
doc.page_content = clean_document_text(doc.page_content) | |
documents.extend(docs) | |
except Exception as pdf_error: | |
st.error(f"Error loading {pdf}: {str(pdf_error)}") | |
if not documents: | |
st.error("No documents were successfully loaded!") | |
return None | |
# Split documents | |
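            # Note: CharacterTextSplitter splits on a single separator ("\n\n" by default),
            # so individual chunks can exceed CHUNK_SIZE when a paragraph runs long.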
text_splitter = CharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP) | |
splits = text_splitter.split_documents(documents) | |
# Create embeddings | |
try: | |
embeddings = HuggingFaceEmbeddings( | |
model_name=EMBEDDING_MODEL, | |
model_kwargs={"token": HF_API_KEY} | |
) | |
# Test embeddings | |
test_embed = embeddings.embed_query("test") | |
if not test_embed or len(test_embed) == 0: | |
raise ValueError("Embedding model returned empty embeddings") | |
except Exception as embed_error: | |
st.warning(f"Primary embedding model failed: {str(embed_error)}") | |
# Fallback embedding class | |
from langchain.embeddings.base import Embeddings | |
class BasicEmbeddings(Embeddings): | |
def embed_documents(self, texts): | |
return [self._basic_embed(text) for text in texts] | |
def embed_query(self, text): | |
return self._basic_embed(text) | |
def _basic_embed(self, text): | |
unique_words = set(text.lower().split()) | |
embedding = np.zeros(384) | |
for i, word in enumerate(unique_words): | |
hash_val = sum(ord(c) for c in word) % 384 | |
embedding[hash_val] += 1 | |
norm = np.linalg.norm(embedding) | |
if norm > 0: | |
embedding = embedding / norm | |
return embedding.tolist() | |
embeddings = BasicEmbeddings() | |
# Create vectorstore | |
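            # The index is persisted to ./chroma_db; re-running against an existing
            # directory may append duplicate chunks to the collection.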
vectorstore = Chroma.from_documents( | |
splits, | |
embedding=embeddings, | |
persist_directory="./chroma_db" | |
) | |
return vectorstore.as_retriever(search_kwargs={"k": SEARCH_K}) | |
except Exception as e: | |
st.error(f"Error processing PDFs: {str(e)}") | |
return None | |
retriever = load_and_index_pdfs() | |
def check_document_relevance(query, documents, min_similarity=MIN_SIMILARITY_THRESHOLD): | |
"""Check if retrieved documents are truly relevant using semantic similarity.""" | |
if not documents: | |
return [], [] | |
try: | |
query_embedding = sentence_model.encode(query, convert_to_tensor=True) | |
relevant_docs = [] | |
relevant_scores = [] | |
for doc in documents: | |
try: | |
doc_embedding = sentence_model.encode(doc.page_content, convert_to_tensor=True) | |
if hasattr(util, "pytorch_cos_sim"): | |
similarity = util.pytorch_cos_sim(query_embedding, doc_embedding).item() | |
else: | |
# Fallback similarity calculation | |
import torch.nn.functional as F | |
import torch | |
if not isinstance(query_embedding, torch.Tensor): | |
query_embedding = torch.tensor(query_embedding) | |
if not isinstance(doc_embedding, torch.Tensor): | |
doc_embedding = torch.tensor(doc_embedding) | |
if len(query_embedding.shape) == 1: | |
query_embedding = query_embedding.unsqueeze(0) | |
if len(doc_embedding.shape) == 1: | |
doc_embedding = doc_embedding.unsqueeze(0) | |
similarity = F.cosine_similarity(query_embedding, doc_embedding).item() | |
if similarity > min_similarity: | |
relevant_docs.append(doc) | |
relevant_scores.append(similarity) | |
except Exception as e: | |
print(f"Error calculating similarity: {str(e)}") | |
continue | |
# Sort by relevance | |
if relevant_docs: | |
sorted_pairs = sorted(zip(relevant_docs, relevant_scores), key=lambda x: x[1], reverse=True) | |
relevant_docs, relevant_scores = zip(*sorted_pairs) | |
return list(relevant_docs), list(relevant_scores) | |
else: | |
return [], [] | |
except Exception as e: | |
print(f"Error in relevance check: {str(e)}") | |
return documents, [0.5] * len(documents) | |
def clean_message_content(content): | |
"""Clean message content by removing sources and follow-up questions.""" | |
if not content: | |
return "" | |
    # Remove source citations (but preserve the main content)
    content = re.sub(r'📚 \*\*Source:\*\*.*?(?=\n|$)', '', content, flags=re.DOTALL)
    # Remove follow-up questions (but preserve the main content)
    content = re.sub(r'💡 \*\*Follow-up.*?(?=\n|$)', '', content, flags=re.DOTALL)
# Clean up extra whitespace | |
content = re.sub(r'\n{3,}', '\n\n', content) | |
return content.strip() | |
def get_last_assistant_response(): | |
"""Get the last assistant response for self-reference requests.""" | |
if len(st.session_state.messages) >= 2: | |
for msg in reversed(st.session_state.messages[:-1]): # Exclude current user message | |
if msg["role"] == "assistant": | |
return clean_message_content(msg["content"]) | |
return "" | |
def needs_pronoun_resolution(query): | |
"""Quick check if query contains pronouns that might need resolution.""" | |
query_lower = query.lower() | |
pronouns_to_check = ['they', 'them', 'their', 'it', 'its', 'this', 'that', 'these', 'those'] | |
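    # Pad the query with spaces so only whole-word pronoun matches count
    # (e.g. "item" must not trigger on the pronoun "it").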
return any(f' {pronoun} ' in f' {query_lower} ' or | |
query_lower.startswith(f'{pronoun} ') or | |
query_lower.endswith(f' {pronoun}') | |
for pronoun in pronouns_to_check) | |
def detect_pronouns_and_resolve(query, conversation_history): | |
"""Detect pronouns in query and resolve them using conversation context.""" | |
query_lower = query.lower() | |
# Common pronouns that need resolution | |
pronouns = { | |
'they': [], 'them': [], 'their': [], 'theirs': [], | |
'it': [], 'its': [], 'this': [], 'that': [], 'these': [], 'those': [], | |
'he': [], 'him': [], 'his': [], 'she': [], 'her': [], 'hers': [] | |
} | |
# Check if query contains pronouns | |
found_pronouns = [] | |
for pronoun in pronouns.keys(): | |
if f' {pronoun} ' in f' {query_lower} ' or query_lower.startswith(f'{pronoun} ') or query_lower.endswith(f' {pronoun}'): | |
found_pronouns.append(pronoun) | |
if not found_pronouns: | |
return query, False | |
# Look for potential referents in recent conversation | |
if len(conversation_history) < 2: | |
return query, False | |
# Get the last user question and assistant response | |
last_user_msg = "" | |
last_assistant_msg = "" | |
for msg in reversed(conversation_history): | |
if msg["role"] == "user" and not last_user_msg: | |
last_user_msg = msg["content"] | |
elif msg["role"] == "assistant" and not last_assistant_msg: | |
last_assistant_msg = clean_message_content(msg["content"]) | |
if last_user_msg and last_assistant_msg: | |
break | |
# Extract key entities/topics from the last question and response | |
potential_referents = [] | |
# Common patterns for entities/subjects | |
entity_patterns = [ | |
r'\b([A-Z][a-z]+ [A-Z][a-z]+)\b', # Proper nouns (Civil Society) | |
r'\b([a-z]+ [a-z]+(?:ies|tion|ment|ness|ity))\b', # Multi-word concepts | |
r'\b(organizations?|institutions?|companies?|governments?|agencies?|groups?)\b', | |
r'\b(students?|teachers?|researchers?|scientists?|experts?|professionals?)\b', | |
r'\b(countries?|nations?|regions?|communities?|populations?)\b' | |
] | |
# Look for entities in the last user question | |
combined_text = f"{last_user_msg} {last_assistant_msg}" | |
for pattern in entity_patterns: | |
matches = re.findall(pattern, combined_text, re.IGNORECASE) | |
potential_referents.extend(matches) | |
# Look for specific key terms that could be referents | |
key_terms = [] | |
# Extract noun phrases that could be subjects | |
words = last_user_msg.lower().split() | |
for i, word in enumerate(words): | |
if word in ['civil', 'sustainable', 'development', 'social', 'environmental', 'economic']: | |
# Try to capture 2-3 word phrases | |
if i < len(words) - 1: | |
phrase = f"{word} {words[i+1]}" | |
if i < len(words) - 2: | |
phrase3 = f"{phrase} {words[i+2]}" | |
if len(phrase3.split()) <= 3: | |
key_terms.append(phrase3) | |
key_terms.append(phrase) | |
# Add single important words | |
important_single_words = [ | |
'government', 'organization', 'institution', 'company', 'agency', | |
'community', 'society', 'country', 'nation', 'group', 'team', | |
'students', 'teachers', 'researchers', 'scientists', 'experts' | |
] | |
for word in important_single_words: | |
if word in last_user_msg.lower() or word in last_assistant_msg.lower(): | |
key_terms.append(word) | |
# Combine all potential referents | |
all_referents = potential_referents + key_terms | |
if not all_referents: | |
return query, False | |
# Find the most likely referent (prioritize multi-word terms) | |
best_referent = None | |
for ref in all_referents: | |
if len(ref.split()) > 1: # Prefer multi-word terms | |
best_referent = ref | |
break | |
if not best_referent and all_referents: | |
best_referent = all_referents[0] | |
if best_referent: | |
# Create expanded query | |
expanded_query = query | |
for pronoun in found_pronouns: | |
# Replace pronoun with referent | |
if pronoun in ['they', 'them', 'their', 'theirs']: | |
if pronoun == 'they': | |
expanded_query = re.sub(rf'\bthey\b', best_referent, expanded_query, flags=re.IGNORECASE) | |
elif pronoun == 'them': | |
expanded_query = re.sub(rf'\bthem\b', best_referent, expanded_query, flags=re.IGNORECASE) | |
elif pronoun == 'their': | |
expanded_query = re.sub(rf'\btheir\b', f"{best_referent}'s", expanded_query, flags=re.IGNORECASE) | |
elif pronoun == 'theirs': | |
expanded_query = re.sub(rf'\btheirs\b', f"{best_referent}'s", expanded_query, flags=re.IGNORECASE) | |
elif pronoun in ['it', 'its', 'this', 'that']: | |
if pronoun == 'it': | |
expanded_query = re.sub(rf'\bit\b', best_referent, expanded_query, flags=re.IGNORECASE) | |
elif pronoun == 'its': | |
expanded_query = re.sub(rf'\bits\b', f"{best_referent}'s", expanded_query, flags=re.IGNORECASE) | |
elif pronoun in ['this', 'that']: | |
expanded_query = re.sub(rf'\b{pronoun}\b', f"this {best_referent}", expanded_query, flags=re.IGNORECASE) | |
return expanded_query, True | |
return query, False | |
def handle_topic_questions(prompt): | |
"""Handle questions about available topics """ | |
prompt_lower = prompt.lower() | |
# More comprehensive pattern matching | |
topic_question_patterns = [ | |
'what are the other', 'what are the 3 other', 'what are all the topics', | |
'what topics', 'what information do you have', 'what can you help with', | |
'what documents', 'what subjects', 'what areas', 'list topics', | |
'show me the topics', 'what else do you know', 'what other topics' | |
] | |
# Check if any pattern matches | |
is_topic_question = any(pattern in prompt_lower for pattern in topic_question_patterns) | |
if is_topic_question: | |
topics = get_document_topics() | |
if topics: | |
response = f"I have information on these {len(topics)} topics:\n\n" | |
for i, topic in enumerate(topics, 1): | |
response += f"{i}. **{topic}**\n" | |
response += "\nWhich topic would you like to explore?" | |
return response, True | |
return None, False | |
def build_conversation_context(): | |
"""Build a clean conversation context from message history.""" | |
if len(st.session_state.messages) <= 1: | |
return "" | |
# Get recent messages (excluding the initial welcome message and current user input) | |
    start_idx = 1 if st.session_state.messages[0]["role"] == "assistant" else 0
    # Keep only the most recent MAX_CONVERSATION_HISTORY messages, excluding the
    # current user input (the last entry); they are already in chronological order.
    recent_messages = st.session_state.messages[start_idx:-1][-MAX_CONVERSATION_HISTORY:]
context_parts = [] | |
for msg in recent_messages: | |
role = msg["role"] | |
content = clean_message_content(msg["content"]) | |
if content: # Only include non-empty content | |
if role == "user": | |
context_parts.append(f"User: {content}") | |
elif role == "assistant": | |
context_parts.append(f"Assistant: {content}") | |
return "\n".join(context_parts) | |
def format_text(text): | |
"""Basic text formatting for better display.""" | |
# Basic symbol replacements only | |
    replacements = {
        'alpha': 'α',
        'beta': 'β',
        'pi': 'π',
        'sum': '∑',
        'leq': '≤',
        'geq': '≥',
        'neq': '≠',
        'approx': '≈'
    }
for latex, unicode_char in replacements.items(): | |
text = text.replace('\\' + latex, unicode_char) | |
return text | |
def is_self_reference_request(query): | |
"""Check if the query is asking about the assistant's own previous response.""" | |
query_lower = query.lower().strip() | |
# Direct patterns for self-reference | |
self_reference_patterns = [ | |
r'\b(your|that)\s+(answer|response|explanation)\b', | |
r'\bsummariz(e|ing)\s+(that|your|the)\s+(answer|response)\b', | |
r'\b(sum up|recap)\s+(that|your|the)\s+(answer|response)\b', | |
r'\bmake\s+(that|your|the)\s+(answer|response)\s+(shorter|brief|concise)\b', | |
r'\b(that|your)\s+(previous|last)\s+(answer|response)\b', | |
r'\bwhat\s+you\s+just\s+(said|explained|told)\b' | |
] | |
# Simple self-reference phrases | |
simple_self_ref = [ | |
"can you summarize", "can you summarise", "can you sum up", | |
"summarize that", "summarise that", "sum that up", | |
"make it shorter", "shorten it", "brief version", | |
"recap that", "condense that", "in summary" | |
] | |
# Check patterns | |
if any(re.search(pattern, query_lower) for pattern in self_reference_patterns): | |
return True | |
# Check simple phrases | |
if any(phrase in query_lower for phrase in simple_self_ref): | |
return True | |
# Additional check for standalone summarization requests | |
if query_lower in ["summarize", "summarise", "summary", "sum up", "recap", "brief"]: | |
return True | |
return False | |
def is_follow_up_request(query): | |
"""Check if the query is asking for more information (but not self-reference).""" | |
# Don't treat self-reference requests as follow-ups | |
if is_self_reference_request(query): | |
return False | |
query_lower = query.lower() | |
# Simple follow-up indicators (excluding summarization words that are self-reference) | |
follow_up_words = [ | |
"more", "elaborate", "explain", "clarify", "expand", "further", | |
"continue", "what else", "tell me more", "go on", "details", | |
"can you", "could you", "please", "also", "additionally" | |
] | |
return any(word in query_lower for word in follow_up_words) | |
def clean_model_output(raw_response): | |
"""Clean the model output to remove prompt instructions and artifacts.""" | |
# Remove common system artifacts | |
artifacts = [ | |
"You are an educational assistant", "GUIDELINES:", "DOCUMENT CONTENT:", | |
"RECENT CONVERSATION:", "Current question:", "Based on the provided", | |
"According to the document", "STRICT RULES:", "Use ONLY", "Do NOT use" | |
] | |
for artifact in artifacts: | |
raw_response = raw_response.replace(artifact, "").strip() | |
# Remove unnecessary apologies and general knowledge indicators | |
unwanted_patterns = [ | |
r'I apologize if.*?[.!]?\s*', | |
r'I\'m sorry if.*?[.!]?\s*', | |
r'I\'m here to help with.*?[.!]?\s*', | |
r'Based on my knowledge.*?[.!]?\s*', | |
r'Generally speaking.*?[.!]?\s*', | |
r'In general.*?[.!]?\s*', | |
r'Typically.*?[.!]?\s*', | |
r'It is widely known.*?[.!]?\s*' | |
] | |
for pattern in unwanted_patterns: | |
raw_response = re.sub(pattern, '', raw_response, flags=re.IGNORECASE) | |
# Remove lines starting with system indicators or general knowledge phrases | |
lines = raw_response.split("\n") | |
skip_patterns = [ | |
"answer this question", "question:", "you are an", "be concise", | |
"i apologize", "i'm sorry", "generally", "typically", "in general", | |
"it is known", "research shows", "studies indicate" | |
] | |
cleaned_lines = [ | |
line for line in lines | |
if not any(line.lower().strip().startswith(pattern) for pattern in skip_patterns) | |
] | |
cleaned_text = "\n".join(cleaned_lines) | |
cleaned_text = re.sub(r'\n{3,}', '\n\n', cleaned_text) | |
return cleaned_text.strip() | |
def validate_response_uses_documents(response, document_content): | |
"""Check if the response actually uses information from the documents.""" | |
if not document_content or not response: | |
return False | |
# Check if response contains phrases indicating it can't find info in documents | |
decline_phrases = [ | |
"cannot find", "not in the documents", "not available", | |
"not mentioned", "not specified", "not provided" | |
] | |
if any(phrase in response.lower() for phrase in decline_phrases): | |
return False | |
# Check for general knowledge responses (red flags) | |
general_knowledge_flags = [ | |
"generally", "typically", "usually", "commonly", "in general", | |
"as a rule", "it is known that", "it is widely accepted", | |
"research shows", "studies indicate", "experts believe" | |
] | |
if any(flag in response.lower() for flag in general_knowledge_flags): | |
return False | |
# Simple check: response should have reasonable overlap with document content | |
response_words = set(response.lower().split()) | |
doc_words = set(document_content.lower().split()) | |
# Remove common words for better comparison | |
common_words = { | |
'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', | |
'is', 'are', 'was', 'were', 'a', 'an', 'this', 'that', 'these', 'those', | |
'can', 'will', 'would', 'should', 'could', 'may', 'might', 'must' | |
} | |
response_words -= common_words | |
doc_words -= common_words | |
# Check if there's reasonable overlap (at least 15% of response content words in documents) | |
if len(response_words) > 0: | |
overlap = len(response_words.intersection(doc_words)) | |
overlap_ratio = overlap / len(response_words) | |
return overlap_ratio >= 0.15 | |
return False | |
def classify_query_type(prompt): | |
"""Determine how to handle different types of questions - IMPROVED""" | |
prompt_lower = prompt.lower() | |
# Meta questions about the chatbot itself - MORE PATTERNS | |
meta_patterns = [ | |
'what topics', 'what are the other', 'what are the 3 other', | |
'what information', 'what can you help', 'what documents', | |
'what subjects', 'what areas', 'list topics', 'show me the topics', | |
'what else do you know', 'what other topics' | |
] | |
if any(pattern in prompt_lower for pattern in meta_patterns): | |
return "meta_question" | |
# Summarization requests | |
if any(phrase in prompt_lower for phrase in [ | |
'summarize', 'summarise', 'summary', 'overview' | |
]): | |
return "summarization" | |
return "factual_question" | |
def generate_response_from_model(prompt, relevant_docs=None): | |
"""Generate response with improved context awareness.""" | |
if model is None or tokenizer is None: | |
return "Error: Model could not be loaded." | |
try: | |
with st.spinner("Generating response..."): | |
# Check if this is a self-reference request | |
is_self_ref = is_self_reference_request(prompt) | |
# Build conversation context | |
conversation_context = build_conversation_context() | |
# Get the last assistant response for self-reference requests | |
last_assistant_response = "" | |
if is_self_ref and len(st.session_state.messages) >= 2: | |
for msg in reversed(st.session_state.messages[:-1]): # Exclude current user message | |
if msg["role"] == "assistant": | |
last_assistant_response = clean_message_content(msg["content"]) | |
break | |
# Extract document content | |
document_content = "" | |
if relevant_docs: | |
doc_texts = [] | |
for doc in relevant_docs[:3]: | |
                    doc_texts.append(doc.page_content[:800])  # cap each chunk's contribution at 800 characters
document_content = "\n\n".join(doc_texts) | |
# Create improved system message with better context integration | |
if is_self_ref and last_assistant_response: | |
# Special handling for self-reference requests | |
system_message = f"""You are an educational assistant. The user is asking you to modify, summarize, or clarify your previous response. | |
YOUR PREVIOUS RESPONSE: | |
{last_assistant_response} | |
CONVERSATION CONTEXT: | |
{conversation_context} | |
INSTRUCTIONS: | |
- The user is asking about YOUR previous response shown above | |
- When they say "summarize", "sum up", "make it shorter", etc., they mean the response above | |
- Provide the requested modification (summary, clarification, etc.) of YOUR previous response | |
- Focus ONLY on the content you previously provided | |
- Be concise and direct in addressing their request | |
- Do NOT search for new information, just work with what you already said""" | |
elif is_self_ref and not last_assistant_response: | |
# Handle case where no previous response is found | |
system_message = f"""The user is asking you to summarize or modify a previous response, but I cannot find a recent response to reference. | |
CONVERSATION CONTEXT: | |
{conversation_context} | |
Please acknowledge this and ask the user to clarify which response they're referring to.""" | |
elif document_content: | |
system_message = f"""You are an educational assistant that answers questions using provided document content. You maintain conversation context and provide coherent, contextual responses. | |
CONVERSATION CONTEXT: | |
{conversation_context} | |
DOCUMENT CONTENT: | |
{document_content} | |
INSTRUCTIONS: | |
- Use ONLY the provided document content to answer questions | |
- Consider the conversation context when formulating your response | |
- If the current question relates to previous discussion, acknowledge that connection | |
- Be direct and educational while maintaining conversational flow | |
- If the answer is not in the documents, say "I cannot find this information in the provided documents" | |
- Reference specific details from the documents when possible""" | |
else: | |
system_message = f"""You are an educational assistant. The user's question does not match the available educational documents. | |
CONVERSATION CONTEXT: | |
{conversation_context} | |
Respond with: "I cannot find information about this topic in the provided educational documents. Please ask about topics covered in the uploaded materials." | |
Consider the conversation context to provide a helpful response that acknowledges previous discussion if relevant.""" | |
# Create user message | |
user_message = f"Question: {prompt}" | |
try: | |
model_device = next(model.parameters()).device | |
            except Exception:
model_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
# Generate response with improved prompt structure | |
if hasattr(tokenizer, "apply_chat_template"): | |
messages = [ | |
{"role": "system", "content": system_message}, | |
{"role": "user", "content": user_message} | |
] | |
inputs = tokenizer.apply_chat_template( | |
messages, | |
return_tensors="pt", | |
                    add_generation_prompt=True,  # append the assistant-turn marker expected by chat models
                    tokenize=True,
                    padding=False)
inputs = inputs.to(model.device) | |
input_length = inputs.shape[1] | |
else: | |
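                # Manual ChatML-style prompt (<|im_start|>/<|im_end|>), the format used by
                # Qwen-family chat models, for tokenizers without a chat template.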
formatted_prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant\n" | |
inputs = tokenizer(formatted_prompt, return_tensors="pt", padding=False) | |
inputs = inputs["input_ids"].to(model_device) | |
input_length = inputs.shape[1] | |
outputs = model.generate( | |
inputs, | |
max_new_tokens=MAX_NEW_TOKENS, | |
temperature=TEMPERATURE, | |
top_p=0.8, | |
do_sample=True, | |
eos_token_id=tokenizer.eos_token_id, | |
pad_token_id=tokenizer.eos_token_id, | |
repetition_penalty=1.1, | |
                attention_mask=None,  # inputs are unpadded, so the model can infer the mask
                use_cache=True  # reuse the KV cache for faster generation
) | |
raw_response = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True) | |
return raw_response.strip() | |
except Exception as e: | |
st.error(f"Error generating response: {str(e)}") | |
return "I'm sorry, there was an error generating a response." | |
def is_conversational_input(prompt): | |
"""Check if the user input is conversational rather than a document query.""" | |
conversational_patterns = [ | |
r'^(hi|hello|hey|greetings|howdy)[\s!.?]*$', | |
r'^(how are you|how\'s it going|what\'s up)[\s!.?]*$', | |
r'^(good morning|good afternoon|good evening)[\s!.?]*$', | |
r'^(thanks|thank you|thx|ty)[\s!.?]*$', | |
r'^(bye|goodbye|see you|farewell)[\s!.?]*$', | |
r'^(clear|reset|start over|new conversation)[\s!.?]*$', | |
        # Acknowledgments and short reactions
r'^(ok|okay|alright|sure|yes|yep|yeah|no|nope|got it|understood|i see)[\s!.?]*$', | |
r'^(cool|nice|great|awesome|perfect|fine|good)[\s!.?]*$', | |
r'^(hmm|hm|mhm|uh huh|aha|oh|ooh|wow)[\s!.?]*$' | |
] | |
prompt_lower = prompt.lower().strip() | |
return any(re.match(pattern, prompt_lower) for pattern in conversational_patterns) | |
def get_document_topics(): | |
"""Extract clean, meaningful topics from loaded documents.""" | |
if not PDF_FILES: | |
return [] | |
topics = [] | |
for pdf in PDF_FILES: | |
filename = os.path.basename(pdf).lower() | |
# Remove file extensions using simple string operations | |
clean_name = filename | |
if clean_name.endswith('.pdf'): | |
clean_name = clean_name[:-4] | |
elif clean_name.endswith('.docx'): | |
clean_name = clean_name[:-5] | |
elif clean_name.endswith('.doc'): | |
clean_name = clean_name[:-4] | |
elif clean_name.endswith('.txt'): | |
clean_name = clean_name[:-4] | |
# Replace separators with spaces | |
clean_name = re.sub(r'[_-]+', ' ', clean_name) | |
# Remove numbers at the start (like "21252030") | |
clean_name = re.sub(r'^\d+\s*', '', clean_name) | |
# Remove common non-meaningful words | |
stop_words = ['document', 'file', 'report', 'briefing', 'overview', 'web', 'pdf'] | |
words = [word for word in clean_name.split() if word not in stop_words and len(word) > 2] | |
if words: | |
# Take first 4 meaningful words max | |
clean_topic = ' '.join(words[:4]) | |
# Capitalize first letter of each word | |
clean_topic = ' '.join(word.capitalize() for word in clean_topic.split()) | |
topics.append(clean_topic) | |
# Remove duplicates and limit to 5 topics max | |
unique_topics = list(dict.fromkeys(topics))[:5] | |
return unique_topics | |
def generate_conversational_response(prompt): | |
"""Generate friendly conversational responses.""" | |
prompt_lower = prompt.lower().strip() | |
# Get topic hint for personalization | |
document_topics = get_document_topics() | |
topic_hint = "" | |
if document_topics: | |
if len(document_topics) == 1: | |
topic_hint = f" I can help with {document_topics[0]}." | |
elif len(document_topics) == 2: | |
topic_hint = f" I can help with {document_topics[0]} and {document_topics[1]}." | |
else: | |
topic_hint = f" I can help with {document_topics[0]}, {document_topics[1]}, and more." | |
conversational_patterns = { | |
r'^(hi|hello|hey|greetings|howdy)[\s!.?]*$': | |
(f"Hello! I'm your educational assistant.{topic_hint} What would you like to learn?", True), | |
r'^(how are you|how\'s it going|what\'s up)[\s!.?]*$': | |
(f"I'm ready to help you learn!{topic_hint} What topic interests you?", True), | |
r'^(good morning|good afternoon|good evening)[\s!.?]*$': | |
(f"{prompt.capitalize()}!{topic_hint} What would you like to explore?", True), | |
r'^(thanks|thank you|thx|ty)[\s!.?]*$': | |
("You're welcome! Would you like to explore another topic?", True), | |
r'^(bye|goodbye|see you|farewell)[\s!.?]*$': | |
("Goodbye! Feel free to return anytime!", False), | |
r'^(clear|reset|start over|new conversation)[\s!.?]*$': | |
("Starting fresh! Your conversation history has been cleared.", True), | |
        # Responses to acknowledgments and short reactions
r'^(ok|okay|alright|sure|got it|understood|i see)[\s!.?]*$': | |
("Great! Is there anything else you'd like to know?", True), | |
r'^(yes|yep|yeah)[\s!.?]*$': | |
("Excellent! What would you like to explore next?", True), | |
r'^(no|nope)[\s!.?]*$': | |
("No problem! Feel free to ask if you change your mind.", True), | |
r'^(cool|nice|great|awesome|perfect)[\s!.?]*$': | |
("I'm glad you found that helpful! What else can I help with?", True), | |
r'^(fine|good)[\s!.?]*$': | |
("Sounds good! What would you like to learn about next?", True), | |
r'^(hmm|hm|mhm|uh huh|aha|oh|ooh|wow)[\s!.?]*$': | |
("Is there something specific you'd like to explore further?", True) | |
} | |
for pattern, (response, continue_flag) in conversational_patterns.items(): | |
if re.match(pattern, prompt_lower): | |
return response, continue_flag | |
return f"I'm here to help you learn.{topic_hint} What specific topic interests you?", True | |
def generate_contextual_guidance(prompt): | |
"""Generate contextual guidance based on document topics.""" | |
document_topics = get_document_topics() | |
if not document_topics: | |
return "What topics from the documents would you like to explore?" | |
# Try to match user intent with available topics | |
prompt_lower = prompt.lower() | |
relevant_topics = [ | |
topic for topic in document_topics | |
if any(word.lower() in prompt_lower for word in topic.split() if len(word) > 3) | |
] | |
if relevant_topics: | |
if len(relevant_topics) == 1: | |
return f"Would you like to explore {relevant_topics[0]}?" | |
else: | |
topic_list = " or ".join(relevant_topics[:2]) | |
return f"Would you like to explore {topic_list}?" | |
else: | |
# Show available topics in a clean way | |
if len(document_topics) == 1: | |
return f"I can help with {document_topics[0]}. What would you like to know?" | |
elif len(document_topics) == 2: | |
return f"I can help with {document_topics[0]} or {document_topics[1]}. What interests you?" | |
else: | |
# Show first 2 topics + count of others | |
return f"I can help with {document_topics[0]}, {document_topics[1]}, and {len(document_topics)-2} other topic{'s' if len(document_topics)-2 > 1 else ''}. What interests you?" | |
def generate_follow_up_question(context, conversation_length, prompt=None): | |
"""Generate a simple follow-up question.""" | |
# If user asked for summary, suggest elaboration | |
if prompt and any(word in prompt.lower() for word in ["summary", "summarize", "sum up"]): | |
return "Would you like me to elaborate on any specific part?" | |
# Don't generate follow-ups for self-reference requests | |
if prompt and is_self_reference_request(prompt): | |
return None | |
# Simple context-based questions | |
context_lower = context.lower() | |
if "process" in context_lower or "step" in context_lower: | |
return "What are the key steps in this process?" | |
elif "method" in context_lower or "approach" in context_lower: | |
return "How is this method applied in practice?" | |
elif "benefit" in context_lower or "advantage" in context_lower: | |
return "What challenges might arise with this approach?" | |
elif "goal" in context_lower or "target" in context_lower: | |
return "How might these goals be implemented?" | |
# Default questions | |
simple_questions = [ | |
"What aspect of this interests you most?", | |
"Would you like to explore related concepts?", | |
"Are there specific examples you'd like to see?", | |
"How does this connect to your studies?", | |
"Would you like more details on any part?" | |
] | |
return simple_questions[conversation_length % len(simple_questions)] | |
def process_query(prompt, context_docs): | |
"""Complete query processing function with improved context handling.""" | |
# DEBUG: Uncomment this line to see what's happening | |
# debug_topic_classification(prompt) | |
# Handle conversational inputs first | |
if is_conversational_input(prompt): | |
response, should_continue = generate_conversational_response(prompt) | |
reset_pattern = r'^(clear|reset|start over|new conversation)[\s!.?]*$' | |
if re.match(reset_pattern, prompt.lower().strip()): | |
return response, None, True, None | |
return response, None, False, None | |
# Classify the query type | |
query_type = classify_query_type(prompt) | |
# Handle meta questions WITHOUT searching documents - PRIORITY | |
if query_type == "meta_question": | |
response, handled = handle_topic_questions(prompt) | |
if handled: | |
return response, None, False, None | |
# Check for self-reference requests (asking about previous assistant response) | |
is_self_ref = is_self_reference_request(prompt) | |
if is_self_ref: | |
# Generate response based on previous conversation, no document retrieval needed | |
with st.spinner("Understanding your request about my previous response..."): | |
raw_response = generate_response_from_model(prompt, relevant_docs=None) | |
clean_response = clean_model_output(raw_response) | |
clean_response = format_text(clean_response) | |
# Don't add follow-up questions or sources for self-reference requests | |
return clean_response, None, False, None | |
# Check for pronoun resolution needs | |
original_prompt = prompt | |
if needs_pronoun_resolution(prompt): | |
expanded_prompt, was_expanded = detect_pronouns_and_resolve(prompt, st.session_state.messages) | |
if was_expanded: | |
prompt = expanded_prompt | |
# Show user what we understood (will be displayed before response) | |
st.info(f"π‘ I understood: '{prompt}'") | |
# Check for follow-up requests | |
is_followup = is_follow_up_request(prompt) | |
# Get relevant documents for new questions | |
relevant_docs, similarity_scores = check_document_relevance(prompt, context_docs, min_similarity=MIN_SIMILARITY_THRESHOLD) | |
# Extract sources | |
sources = set() | |
for doc in relevant_docs: | |
if hasattr(doc, "metadata") and "source" in doc.metadata: | |
sources.add(doc.metadata["source"]) | |
# Generate response | |
if relevant_docs: | |
# Generate response from model with context | |
raw_response = generate_response_from_model(prompt, relevant_docs) | |
# Clean and format | |
clean_response = clean_model_output(raw_response) | |
clean_response = format_text(clean_response) | |
# Extract document content for validation | |
document_content = "\n\n".join([doc.page_content for doc in relevant_docs[:3]]) | |
# Additional validation: check if response is suspiciously long compared to document content | |
if len(clean_response.split()) > len(document_content.split()) * 0.8: | |
# Response is too long relative to source material | |
contextual_guidance = generate_contextual_guidance(prompt) | |
decline_response = "I can only provide information that's directly available in the documents." | |
return f"{decline_response}\n\nπ‘ {contextual_guidance}", None, False, None | |
# Validate that response actually uses document content | |
if not validate_response_uses_documents(clean_response, document_content): | |
# Response doesn't use documents - decline to answer | |
contextual_guidance = generate_contextual_guidance(prompt) | |
decline_response = "I cannot find specific information about this question in the provided documents." | |
return f"{decline_response}\n\nπ‘ {contextual_guidance}", None, False, None | |
# Add follow-up question (but be more selective about when to add them) | |
if not is_followup and len(st.session_state.messages) % 3 == 0: # Less frequent follow-ups | |
follow_up = generate_follow_up_question(clean_response, len(st.session_state.messages), prompt) | |
if follow_up: # Only add if follow_up is not None | |
clean_response += f"\n\nπ‘ **Follow-up:** {follow_up}" | |
# Add sources | |
if sources: | |
clean_response += f"\n\nπ **Source:** {', '.join(sorted(sources))}" | |
return clean_response, ", ".join(sorted(sources)), False, None | |
else: | |
# No relevant docs - strictly decline to answer from general knowledge | |
contextual_guidance = generate_contextual_guidance(prompt) | |
decline_response = "I cannot find information about this topic in the provided educational documents." | |
return f"{decline_response}\n\nπ‘ {contextual_guidance}", None, False, None | |
# Test function to verify self-reference detection (you can call this in the sidebar for debugging) | |
def test_self_reference_detection(): | |
"""Test self-reference detection with sample phrases.""" | |
test_phrases = [ | |
"can you summarise your answer", | |
"summarize that", | |
"make it shorter", | |
"what you just said", | |
"sum up your response", | |
"give me a brief version", | |
"recap", | |
"summary" | |
] | |
st.sidebar.write("**Self-Reference Detection Test:**") | |
for phrase in test_phrases: | |
is_detected = is_self_reference_request(phrase) | |
emoji = "β " if is_detected else "β" | |
st.sidebar.write(f"{emoji} '{phrase}': {is_detected}") | |
def test_pronoun_resolution(): | |
"""Test pronoun resolution with sample phrases.""" | |
test_queries = [ | |
"What are the obstacles they face?", | |
"How do they implement it?", | |
"What challenges do they encounter?", | |
"How does this affect them?", | |
"What is their role?", | |
"Can you explain their process?" | |
] | |
st.sidebar.write("**Pronoun Resolution Test:**") | |
# Use current conversation for context | |
for query in test_queries: | |
needs_resolution = needs_pronoun_resolution(query) | |
emoji = "π" if needs_resolution else "β" | |
st.sidebar.write(f"{emoji} '{query}'") | |
if needs_resolution and len(st.session_state.messages) > 1: | |
expanded, was_expanded = detect_pronouns_and_resolve(query, st.session_state.messages) | |
if was_expanded: | |
st.sidebar.write(f" β '{expanded}'") | |
# MAIN STREAMLIT INTERFACE | |
st.title(APP_TITLE) | |
# Sidebar info | |
st.sidebar.title("System Info") | |
st.sidebar.info("Educational Assistant") | |
st.sidebar.write("Documents loaded:") | |
for pdf in PDF_FILES: | |
display_name = os.path.basename(pdf) | |
st.sidebar.write(f"- {display_name}") | |
# Add debugging section | |
if st.sidebar.checkbox("Show Debug Info"): | |
test_self_reference_detection() | |
st.sidebar.markdown("---") | |
test_pronoun_resolution() | |
# Initialize welcome message | |
if not st.session_state.messages: | |
# Create a dynamic welcome message based on available documents | |
document_topics = get_document_topics() | |
if document_topics: | |
if len(document_topics) == 1: | |
topic_preview = f" I have information about {document_topics[0]}." | |
elif len(document_topics) == 2: | |
topic_preview = f" I have information about {document_topics[0]} and {document_topics[1]}." | |
elif len(document_topics) <= 4: | |
topic_list = ", ".join(document_topics[:-1]) + f", and {document_topics[-1]}" | |
topic_preview = f" I have information about {topic_list}." | |
else: | |
topic_preview = f" I have information about {document_topics[0]}, {document_topics[1]}, and {len(document_topics)-2} other topics." | |
else: | |
topic_preview = "" | |
welcome_msg = f"Hello! I'm your educational assistant.{topic_preview} What would you like to explore today?" | |
st.session_state.messages.append({"role": "assistant", "content": welcome_msg}) | |
# Clear conversation button | |
col1, col2 = st.columns([4, 1]) | |
with col2: | |
if st.button("New Conversation"): | |
# Increment conversation ID to ensure clean state | |
st.session_state.conversation_id += 1 | |
st.session_state.messages = [] | |
# Create new welcome message | |
document_topics = get_document_topics() | |
if document_topics: | |
if len(document_topics) == 1: | |
topic_preview = f" I have information about {document_topics[0]}." | |
elif len(document_topics) == 2: | |
topic_preview = f" I have information about {document_topics[0]} and {document_topics[1]}." | |
elif len(document_topics) <= 4: | |
topic_list = ", ".join(document_topics[:-1]) + f", and {document_topics[-1]}" | |
topic_preview = f" I have information about {topic_list}." | |
else: | |
topic_preview = f" I have information about {document_topics[0]}, {document_topics[1]}, and {len(document_topics)-2} other topics." | |
else: | |
topic_preview = "" | |
welcome_msg = f"Starting a new conversation.{topic_preview} What would you like to learn about today?" | |
st.session_state.messages.append({"role": "assistant", "content": welcome_msg}) | |
st.rerun() | |
if retriever: | |
# Display chat messages | |
for message in st.session_state.messages: | |
with st.chat_message(message["role"]): | |
st.markdown(message["content"]) | |
# User input | |
if prompt := st.chat_input("What would you like to learn today?"): | |
# Add user message to history | |
st.session_state.messages.append({"role": "user", "content": prompt}) | |
with st.chat_message("user"): | |
st.markdown(prompt) | |
# Generate response | |
with st.chat_message("assistant"): | |
with st.spinner("Thinking..."): | |
try: | |
# Process query | |
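                    # get_relevant_documents is the classic retriever call; newer
                    # LangChain releases also expose the equivalent retriever.invoke(prompt).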
retrieved_docs = retriever.get_relevant_documents(prompt) | |
answer, sources, should_reset, new_follow_up = process_query(prompt, retrieved_docs) | |
# Handle conversation reset if needed | |
if should_reset: | |
st.session_state.conversation_id += 1 | |
st.session_state.messages = [] | |
st.session_state.messages.append({"role": "assistant", "content": answer}) | |
st.rerun() | |
# Store response in chat history | |
st.session_state.messages.append({"role": "assistant", "content": answer}) | |
# Display the response | |
st.markdown(answer) | |
except Exception as e: | |
error_msg = f"An error occurred: {str(e)}" | |
st.error(error_msg) | |
st.session_state.messages.append({"role": "assistant", "content": error_msg}) | |
else: | |
st.error("Failed to load document retrieval system.") |