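"""HR-Talk: a password-gated Gradio chatbot that answers questions about HR
policy PDFs via retrieval-augmented generation (FAISS + Qwen2.5)."""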
import os
import uuid
from collections import deque

import faiss
import gradio as gr
import numpy as np
import PyPDF2
from huggingface_hub import login
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig

# Authenticate with the Hugging Face Hub (needed to download gated models).
login(token=os.getenv("HUGGINGFACEHUB_API_TOKEN"))

# Load the chat model with 4-bit quantization to reduce GPU memory usage.
model_name = "Qwen/Qwen2.5-7B-Instruct-1M"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype="float16",
    bnb_4bit_use_double_quant=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",
)
text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
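# Note: 4-bit loading assumes the bitsandbytes package and a CUDA-capable GPU
# are available; without them, drop quantization_config to load in full precision.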

# Sentence-embedding model for retrieval (384-dimensional output).
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

embedding_dim = 384  # matches all-MiniLM-L6-v2's output dimension
chat_index = faiss.IndexFlatL2(embedding_dim)  # embeddings of this session's responses
doc_index = faiss.IndexFlatL2(embedding_dim)   # embeddings of HR document chunks
doc_texts = []   # chunk texts, parallel to doc_index rows
chat_texts = []  # response texts, parallel to chat_index rows

chat_sessions = {}
current_session_id = None
SESSION_HISTORY_LIMIT = 5
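# IndexFlatL2 assigns row ids in insertion order and stores only vectors, so the
# parallel doc_texts / chat_texts lists are the mapping back to raw text. Since
# get_embedding() returns unit-normalized vectors, L2 distance ranks neighbours
# the same as cosine similarity.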


def get_embedding(text):
    """Converts text to a vector embedding."""
    return embedding_model.encode(text, normalize_embeddings=True)


# Simple shared-password gate; the secret is read from the environment.
SECRET_PASSWORD = os.getenv("APP_SECRET_PASSWORD")
authenticated = False


def verify_password(password):
    """Verifies user authentication."""
    global authenticated
    # Fail closed if APP_SECRET_PASSWORD is unset.
    authenticated = SECRET_PASSWORD is not None and password == SECRET_PASSWORD
    return "Access Granted!" if authenticated else "Invalid Password!"


def start_new_session():
    """Starts a new session and resets chat memory."""
    global current_session_id, chat_sessions
    current_session_id = str(uuid.uuid4())
    chat_sessions = {current_session_id: deque(maxlen=SESSION_HISTORY_LIMIT)}
    chat_index.reset()  # drop embeddings carried over from earlier sessions
    chat_texts.clear()
    return current_session_id


def store_chat_in_session(user_input, response):
    """Stores a user-chat exchange in the session and the FAISS index."""
    if current_session_id is None:
        start_new_session()
    chat_sessions[current_session_id].append((user_input, response))
    chat_index.add(np.array([get_embedding(response)]))
    chat_texts.append(response)  # keep the text list aligned with FAISS row ids


def get_recent_chat_history():
    """Retrieves recent exchanges within the session only."""
    if current_session_id in chat_sessions:
        return "\n".join(f"User: {q}\nAI: {r}" for q, r in chat_sessions[current_session_id])
    return ""


def process_pdf(pdf_file):
    """Extracts text from an HR policy document and stores it in FAISS."""
    if not authenticated:
        return "Access Denied!"
    pdf_reader = PyPDF2.PdfReader(pdf_file)
    # extract_text() can return None for image-only pages; call it once per page.
    page_texts = (page.extract_text() for page in pdf_reader.pages)
    document_text = " ".join(text.replace("\n", " ") for text in page_texts if text)
    # Naive sentence-level chunking; adequate for short policy documents.
    text_chunks = [chunk.strip() for chunk in document_text.split(". ") if chunk.strip()]
    if not text_chunks:
        return "No extractable text found."
    embeddings = np.array([get_embedding(chunk) for chunk in text_chunks])
    doc_index.add(embeddings)
    doc_texts.extend(text_chunks)
    return "Doc Processed."


def retrieve_relevant_passage(query, top_k=3):
    """Retrieves the most relevant HR document excerpts for a query."""
    if not authenticated:
        return "Access Denied!"
    query_embedding = get_embedding(query)
    # FAISS pads with -1 when the index holds fewer than top_k vectors.
    D, I = doc_index.search(np.array([query_embedding]), top_k)
    valid_indices = [i for i in I[0] if 0 <= i < len(doc_texts)]
    if valid_indices:
        return "\n".join(f"- {doc_texts[i]}" for i in valid_indices)
    return "No relevant document found."


def retrieve_chat_context(user_input, top_k=3):
    """Retrieves relevant past interactions from the current session only."""
    if not authenticated:
        return ""
    query_embedding = get_embedding(user_input)
    retrieved_texts = []
    if chat_index.ntotal > 0:
        D, I = chat_index.search(np.array([query_embedding]), top_k)
        # chat_texts mirrors chat_index, so FAISS row ids index it directly.
        retrieved_texts = [chat_texts[i] for i in I[0] if 0 <= i < len(chat_texts)]
    past_chat_context = get_recent_chat_history()
    return "\n".join(filter(None, [past_chat_context, *retrieved_texts]))


def chat_with_pdf(user_input, chat_history=None):
    """Processes user input and generates AI responses based on HR documentation."""
    chat_history = chat_history if chat_history is not None else []
    if not authenticated:
        return iter(["Access Denied!"]), chat_history
    relevant_passage = retrieve_relevant_passage(user_input)
    past_chat_context = retrieve_chat_context(user_input)
    few_shot_examples = (
        "Example 1:\nUser: How many paid leaves do I have?\n"
        "AI: Full-time employees receive 18 paid leaves per year.\n\n"
        "Example 2:\nUser: How do I apply for maternity leave?\n"
        "AI: Submit a request via the HR portal. Eligible employees get up to 26 weeks."
    )
    prompt = (
        "You are an HR virtual assistant providing professional responses based on "
        "company policies. Ensure accuracy and clarity. If unsure, reply exactly: "
        "'Please contact the HR department for more details.'\n\n"
        f"{few_shot_examples}\n"
        f"**Recent Chat:**\n{past_chat_context}\n"
        f"**HR Policy Context:**\n{relevant_passage}\n"
        f"**User Inquiry:** {user_input}\nAI Response:"
    )

    def response_generator():
        response = text_generator(
            prompt,
            max_new_tokens=1024,
            do_sample=True,
            temperature=0.3,
            top_p=0.85,
            repetition_penalty=1.2,
            eos_token_id=tokenizer.eos_token_id,
        )
        # The pipeline echoes the prompt, so keep only the text after the final marker.
        answer = response[0]["generated_text"].split("AI Response:")[-1].strip()
        store_chat_in_session(user_input, answer)
        yield f"{answer}\n\n*Reference:* _{relevant_passage}_"

    return response_generator(), chat_history


with gr.Blocks() as chat_ui:
    gr.Markdown("# HR-Talk")

    with gr.Accordion("Authenticator", open=False):
        with gr.Row(equal_height=True):
            password_input = gr.Textbox(placeholder="Enter Password Here", type="password", interactive=True, scale=3, show_label=False)
            verify_button = gr.Button("Verify", variant="primary", scale=1)
            access_status = gr.Label(value="Status", scale=2)
        verify_button.click(verify_password, inputs=[password_input], outputs=[access_status])
with gr.Accordion("Document Feeder", open=False): |
|
with gr.Row(equal_height=True): |
|
file_upload = gr.File(label="π Upload PDF", file_types=[".pdf"], interactive=True, scale=5) |
|
upload_btn = gr.Button("π€ Process PDF", variant="primary", scale=2) |
|
status = gr.Label(value="Waiting for upload...", scale=3) |
|
|
|
|
|
|
|

    chatbot = gr.Chatbot()
    with gr.Row(equal_height=True):
        user_input = gr.Textbox(placeholder="Type your message here...", show_label=False, scale=8)
        send_btn = gr.Button("Send", scale=2)

    upload_btn.click(process_pdf, inputs=[file_upload], outputs=[status])

    def stream_response(user_input, chat_history):
        response_gen, chat_history = chat_with_pdf(user_input, chat_history)
        full_response = ""
        for chunk in response_gen:
            full_response += chunk
            # Append the in-progress turn without dropping earlier exchanges.
            yield chat_history + [(user_input, full_response)]
        chat_history.append((user_input, full_response))
        yield chat_history

    send_btn.click(stream_response, inputs=[user_input, chatbot], outputs=[chatbot])


if __name__ == "__main__":
    chat_ui.launch()