# HRT / app.py
# Author: kidwaiaun
# Last commit: 484cd86 "Update app.py" (verified)
import hmac
import os
import uuid
from collections import deque

import faiss
import gradio as gr
import numpy as np
import PyPDF2
from huggingface_hub import login
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
# Hugging Face Hub authentication.
# login(token=None) falls back to an interactive prompt, which cannot be
# answered on a headless server, so only log in when a token is configured.
_hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Load the chat LLM with 4-bit bitsandbytes quantization (double quantization,
# float16 compute dtype); device_map="auto" lets transformers place the weights.
model_name = "Qwen/Qwen2.5-7B-Instruct-1M"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype="float16",
    bnb_4bit_use_double_quant=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config, device_map="auto")
text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
# Sentence embedding model, shared by document chunks and chat responses.
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# FAISS indexes (exact L2 search). The dimension is read from the model
# (384 for all-MiniLM-L6-v2) so the indexes stay consistent if the model
# is ever swapped, instead of relying on a hard-coded constant.
embedding_dim = embedding_model.get_sentence_embedding_dimension()
chat_index = faiss.IndexFlatL2(embedding_dim)
doc_index = faiss.IndexFlatL2(embedding_dim)
# Text of each indexed document chunk: entry i is the text for doc_index id i.
doc_texts = []

# Session-based chat memory (reset whenever a new session starts).
chat_sessions = {}
current_session_id = None
SESSION_HISTORY_LIMIT = 5
def get_embedding(text):
    """Encode *text* with the module's sentence-embedding model.

    The vector is L2-normalized (``normalize_embeddings=True``).
    """
    vector = embedding_model.encode(text, normalize_embeddings=True)
    return vector
# App access control
SECRET_PASSWORD = os.getenv("APP_SECRET_PASSWORD")
# NOTE(review): this flag is module-global, so it is shared by every client of
# the app: one successful login unlocks the app for all users, and one failed
# attempt locks it again for all. Per-user isolation would need per-session
# state (e.g. gr.State) threaded through the handlers.
authenticated = False


def verify_password(password):
    """Check *password* against APP_SECRET_PASSWORD and set the access flag.

    Fails closed: when the secret is unset or empty, or *password* is not a
    string, access is denied. The comparison is constant-time.

    Args:
        password: The password typed by the user.

    Returns:
        "Access Granted!" on a match, otherwise "Invalid Password!".
    """
    global authenticated
    if not SECRET_PASSWORD or not isinstance(password, str):
        # Previously an unset secret and a None password compared equal.
        authenticated = False
    else:
        # Compare bytes: compare_digest only accepts ASCII-only str arguments.
        authenticated = hmac.compare_digest(
            password.encode("utf-8"), SECRET_PASSWORD.encode("utf-8")
        )
    return "Access Granted!" if authenticated else "Invalid Password!"
# Chat session management: starting a session discards all earlier ones.
def start_new_session():
    """Begin a fresh chat session and drop every previously stored session.

    Returns:
        The new session id, a UUID4 rendered as a string.
    """
    global current_session_id, chat_sessions
    new_id = str(uuid.uuid4())
    current_session_id = new_id
    # Rebinding the dict (rather than adding a key) forgets all prior sessions.
    fresh_history = deque(maxlen=SESSION_HISTORY_LIMIT)
    chat_sessions = {new_id: fresh_history}
    return new_id
def store_chat_in_session(user_input, response):
    """Record one (user_input, response) exchange for the current session.

    Starts a session on first use, appends the exchange to that session's
    bounded history, then adds the embedding of *response* to ``chat_index``.
    The history append and the index add happen once each, in that order.
    """
    if current_session_id is None:
        start_new_session()
    session_history = chat_sessions[current_session_id]
    session_history.append((user_input, response))
    response_vector = get_embedding(response)
    chat_index.add(np.array([response_vector]))
def get_recent_chat_history():
    """Format the current session's stored exchanges as a plain transcript.

    Each exchange becomes "User: ...\\nAI: ..."; exchanges are newline-joined.
    Returns "" when there is no current session.
    """
    if current_session_id not in chat_sessions:
        return ""
    transcript = []
    for question, answer in chat_sessions[current_session_id]:
        transcript.append(f"User: {question}\nAI: {answer}")
    return "\n".join(transcript)
# Document processing
def process_pdf(pdf_file):
    """Extract text from an uploaded HR policy PDF and index it for retrieval.

    The text of every page is joined into one string and split into rough
    sentence chunks on ". ". Empty chunks are dropped. The remaining chunks
    are embedded and appended to ``doc_index``, and their text is appended to
    ``doc_texts`` so that FAISS ids map back to text.

    Args:
        pdf_file: The uploaded file as passed by the Gradio File component
            (anything ``PyPDF2.PdfReader`` accepts), or None when nothing
            was uploaded.

    Returns:
        A short status message for the UI.
    """
    if not authenticated:
        return "Access Denied!"
    if pdf_file is None:
        return "No PDF uploaded."
    pdf_reader = PyPDF2.PdfReader(pdf_file)
    page_texts = []
    for page in pdf_reader.pages:
        # Call extract_text() once per page (previously called twice).
        text = page.extract_text()
        if text:
            page_texts.append(text.replace("\n", " "))
    document_text = " ".join(page_texts)
    # Drop empty / whitespace-only fragments so they are never embedded.
    text_chunks = [chunk.strip() for chunk in document_text.split(". ") if chunk.strip()]
    if not text_chunks:
        return "No extractable text found in the PDF."
    # One batched encode call; the result has one row per chunk.
    embeddings = np.asarray(get_embedding(text_chunks), dtype="float32")
    doc_index.add(embeddings)
    doc_texts.extend(text_chunks)
    return "Doc Processed."
# Retrieve relevant HR policy passages
def retrieve_relevant_passage(query, top_k=3):
    """Return the indexed document chunks nearest to *query*.

    Args:
        query: The user's question.
        top_k: Maximum number of chunks to look up.

    Returns:
        Matching chunks as "- "-prefixed lines joined by newlines,
        "No relevant document found." when nothing matches, or
        "Access Denied!" when the user is not authenticated.
    """
    if not authenticated:
        return "Access Denied!"
    query_vector = np.array([get_embedding(query)])
    _, neighbor_ids = doc_index.search(query_vector, top_k)
    # Skip FAISS's -1 padding and any id without stored text.
    hits = [doc_texts[idx] for idx in neighbor_ids[0] if 0 <= idx < len(doc_texts)]
    if not hits:
        return "No relevant document found."
    return "\n".join(f"- {passage}" for passage in hits)
# Retrieve chat context
def retrieve_chat_context(user_input, top_k=3):
    """Build the conversational context to include in the prompt.

    Combines the current session's recent transcript with the stored
    responses whose embeddings are nearest to *user_input*.

    Args:
        user_input: The user's new message.
        top_k: Number of nearest stored responses to look up.

    Returns:
        The recent transcript, a newline, then the retrieved responses one per
        line; "" when the user is not authenticated.
    """
    if not authenticated:
        return ""
    history = chat_sessions.get(current_session_id, ())
    retrieved_texts = []
    if chat_index.ntotal > 0 and history:
        query_embedding = get_embedding(user_input)
        _, neighbor_ids = chat_index.search(np.array([query_embedding]), top_k)
        # chat_index only grows (one vector per stored response), while the
        # session history keeps only the newest len(history) responses. This
        # assumes store_chat_in_session appends to both in lockstep, so those
        # responses are the last len(history) vectors in the index. FAISS ids
        # must be shifted by this offset, not used directly as history
        # positions.
        first_id_in_history = chat_index.ntotal - len(history)
        for faiss_id in neighbor_ids[0]:
            faiss_id = int(faiss_id)
            position = faiss_id - first_id_in_history
            # Skip FAISS's -1 padding and hits whose text has been evicted.
            if faiss_id >= 0 and 0 <= position < len(history):
                retrieved_texts.append(history[position][1])
        # NOTE(review): only the last SESSION_HISTORY_LIMIT responses are kept,
        # so every valid hit also appears in the recent transcript. Retrieving
        # older turns would require keeping their text alongside chat_index.
    past_chat_context = get_recent_chat_history()
    return f"{past_chat_context}\n" + "\n".join(retrieved_texts)
# AI chatbot response generation (Markdown-formatted answers)
def chat_with_pdf(user_input, chat_history=None):
    """Answer *user_input* from the indexed HR documents and chat context.

    Args:
        user_input: The user's question.
        chat_history: Chat history passed through from the UI and returned
            as-is. Defaults to a new empty list on each call.

    Returns:
        A ``(response, chat_history)`` tuple. ``response`` is the string
        "Access Denied!" when the user is not authenticated. Otherwise it is
        a generator that, when iterated, runs the model, stores the exchange
        in the session, and yields one Markdown string: the answer followed
        by the retrieved policy passages as a reference.
    """
    if chat_history is None:
        # Fresh list per call; a mutable default would be shared across calls.
        chat_history = []
    if not authenticated:
        return "Access Denied!", chat_history
    relevant_passage = retrieve_relevant_passage(user_input)
    past_chat_context = retrieve_chat_context(user_input)
    # The blank line keeps the two examples from running together.
    few_shot_examples = (
        "Example 1:\nUser: How many paid leaves do I have?\nAI: Full-time employees receive 18 paid leaves per year.\n\n"
        "Example 2:\nUser: How do I apply for maternity leave?\nAI: Submit a request via the HR portal. Eligible employees get up to 26 weeks."
    )
    prompt = (
        "You are an HR virtual assistant providing professional responses based on company policies. Ensure accuracy and clarity. If unsure, say this exactly 'Please contact the HR department for more details.'.\n\n"
        f"{few_shot_examples}\n"
        f"**Recent Chat:**\n{past_chat_context}\n**HR Policy Context:**\n{relevant_passage}\n**User Inquiry:** {user_input}\nAI Response:"
    )

    def response_generator():
        response = text_generator(
            prompt,
            max_new_tokens=1024,
            do_sample=True,
            temperature=0.3,
            top_p=0.85,
            repetition_penalty=1.2,
            eos_token_id=tokenizer.eos_token_id,
            # Return only newly generated text, so the answer does not have to
            # be cut out of the echoed prompt.
            return_full_text=False,
        )
        answer = response[0]["generated_text"].strip()
        store_chat_in_session(user_input, answer)
        yield f"{answer}\n\n*Reference:* _{relevant_passage}_"

    return response_generator(), chat_history
# Gradio interface with chatbot UI
with gr.Blocks() as chat_ui:
    gr.Markdown("# 📄 HR-Talk")
    with gr.Accordion("Authenticator", open=False):
        with gr.Row(equal_height=True):
            password_input = gr.Textbox(placeholder="Enter Password Here", type="password", interactive=True, scale=3, show_label=False)
            verify_button = gr.Button("✅ Verify", variant="primary", scale=1)
            access_status = gr.Label(value="Status", scale=2)
        verify_button.click(verify_password, inputs=[password_input], outputs=[access_status])
    with gr.Accordion("Document Feeder", open=False):
        with gr.Row(equal_height=True):
            file_upload = gr.File(label="📂 Upload PDF", file_types=[".pdf"], interactive=True, scale=5)
            upload_btn = gr.Button("📤 Process PDF", variant="primary", scale=2)
            status = gr.Label(value="Waiting for upload...", scale=3)
    chatbot = gr.Chatbot()
    with gr.Row(equal_height=True):
        user_input = gr.Textbox(placeholder="Type your message here...", show_label=False, scale=8)
        send_btn = gr.Button("Send", scale=2)
    upload_btn.click(process_pdf, inputs=[file_upload], outputs=[status])

    def stream_response(message, chat_history):
        """Stream the assistant's reply to *message* into the chatbot.

        Yields the prior exchanges plus the in-progress reply after each
        chunk, then the completed history.
        """
        history = list(chat_history or [])
        response_stream, history = chat_with_pdf(message, history)
        if isinstance(response_stream, str):
            # A rejected request (e.g. "Access Denied!") comes back as a plain
            # string; show it whole instead of iterating its characters.
            response_stream = [response_stream]
        full_response = ""
        for piece in response_stream:
            full_response += piece
            # Keep every earlier exchange visible; the new one is not in
            # history yet, so nothing must be sliced off.
            yield history + [(message, full_response)]
        history.append((message, full_response))
        yield history

    send_btn.click(stream_response, inputs=[user_input, chatbot], outputs=[chatbot])

if __name__ == "__main__":
    chat_ui.launch()