ritampatra committed on
Commit
879e1ad
·
verified ·
1 Parent(s): 5ab0b92

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -49
app.py CHANGED
@@ -1,84 +1,96 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModel, pipeline
3
- from langchain.vectorstores import FAISS
4
- from langchain.document_loaders import PyPDFLoader
5
- from langchain.chains.question_answering import load_qa_chain
6
- from langchain.llms import HuggingFaceHub
7
  import torch
 
8
 
9
- # Function to load and process the document (PDF)
10
  def load_document(file):
11
- loader = PyPDFLoader(file.name)
12
- documents = loader.load()
13
- return documents
 
 
 
14
 
15
- # Function to embed documents using Hugging Face model directly
16
- def embed_documents(documents):
17
- # Load tokenizer and model
18
  tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
19
  model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
 
 
 
 
 
 
 
20
 
21
- # Get document texts
22
- document_texts = [doc.page_content for doc in documents]
 
 
23
 
24
- # Create embeddings for each document
25
- embeddings = []
26
- for text in document_texts:
27
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
28
- with torch.no_grad():
29
- model_output = model(**inputs)
30
- embedding = model_output.last_hidden_state.mean(dim=1) # Mean pool the embeddings
31
- embeddings.append(embedding.squeeze().numpy())
32
 
33
- # Store embeddings in FAISS vector store
34
- vector_store = FAISS.from_embeddings(embeddings, documents)
35
- return vector_store
 
36
 
37
- # Function to handle chatbot queries
38
- def chat_with_document(query, vector_store):
39
- retriever = vector_store.as_retriever()
40
- llm = HuggingFaceHub(repo_id="google/flan-t5-large", model_kwargs={"temperature": 0.2})
41
- chain = load_qa_chain(llm, chain_type="stuff")
42
- results = retriever.get_relevant_documents(query)
43
- answer = chain.run(input_documents=results, question=query)
44
- return answer
45
 
46
- # Function to build the Gradio interface
 
 
 
 
 
 
47
  def chatbot_interface():
48
- vector_store = None
 
49
 
50
- # Function to handle file upload and document embedding
51
  def upload_file(file):
52
- nonlocal vector_store
53
- documents = load_document(file)
54
- vector_store = embed_documents(documents)
55
- return "Document uploaded and processed. You can now ask questions."
56
 
57
  # Function to handle user queries
58
  def ask_question(query):
59
- if vector_store:
60
- return chat_with_document(query, vector_store)
61
  return "Please upload a document first."
62
 
63
- # Gradio interface components
64
  upload = gr.File(label="Upload a PDF document")
65
  question = gr.Textbox(label="Ask a question about the document")
66
  answer = gr.Textbox(label="Answer", readonly=True)
67
 
68
- # Linking the functions to Gradio interface
69
- upload_button = gr.Interface(fn=upload_file, inputs=upload, outputs="text")
70
- chat_box = gr.Interface(fn=ask_question, inputs=question, outputs=answer)
71
-
72
  # Gradio app layout
73
  with gr.Blocks() as demo:
74
  gr.Markdown("# Document Chatbot")
75
  with gr.Row():
76
- upload_button.render()
77
  with gr.Row():
78
  question.render()
79
  answer.render()
80
 
81
- # Launch the Gradio app
 
 
 
82
  demo.launch()
83
 
84
  # Start the chatbot interface
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModel
3
+ import faiss
4
+ import numpy as np
 
 
5
  import torch
6
+ from PyPDF2 import PdfReader
7
 
8
+ # Load PDF and extract text from it
9
  def load_document(file):
10
+ pdf = PdfReader(file)
11
+ text = ''
12
+ for page_num in range(len(pdf.pages)):
13
+ page = pdf.pages[page_num]
14
+ text += page.extract_text()
15
+ return text
16
 
17
+ # Embed the document using Hugging Face model
18
_EMBEDDING_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
_EMBEDDER_CACHE = {}


def _get_embedder():
    """Return the ``(tokenizer, model)`` pair, loading it only on first use."""
    if "pair" not in _EMBEDDER_CACHE:
        # Load both into locals first so a failure half-way through never
        # leaves a partially populated cache behind.
        tokenizer = AutoTokenizer.from_pretrained(_EMBEDDING_MODEL_NAME)
        model = AutoModel.from_pretrained(_EMBEDDING_MODEL_NAME)
        _EMBEDDER_CACHE["pair"] = (tokenizer, model)
    return _EMBEDDER_CACHE["pair"]


def embed_text(text):
    """Embed ``text`` with the Hugging Face model and return a NumPy array.

    The tokenizer and model are loaded once and reused; this function is
    called once per document chunk and once per query, so reloading the
    model on every call (as before) dominated the run time.

    Args:
        text: The text to embed; it is handed to the tokenizer as-is, with
            padding and truncation enabled.

    Returns:
        numpy.ndarray: The mean of ``last_hidden_state`` over dim 1, with
        size-1 dimensions squeezed out.
    """
    tokenizer, model = _get_embedder()

    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)

    # Mean pooling over dim 1 of the last hidden state.
    pooled = outputs.last_hidden_state.mean(dim=1)
    return pooled.squeeze().numpy()
29
 
30
+ # Initialize FAISS index
31
def initialize_faiss(embedding_size):
    """Create an empty ``faiss.IndexFlatL2`` for vectors of ``embedding_size`` dims.

    Args:
        embedding_size: Dimensionality passed to ``faiss.IndexFlatL2``.

    Returns:
        The newly constructed index.
    """
    return faiss.IndexFlatL2(embedding_size)
34
 
35
+ # Add document embeddings to FAISS index
36
def add_to_index(index, embeddings):
    """Add ``embeddings`` to ``index`` as a contiguous 2-D float32 matrix.

    Args:
        index: Object exposing ``add(matrix)`` (a FAISS index in this file).
        embeddings: Array-like of vectors; a single 1-D vector is accepted
            and treated as one row.

    Returns:
        None. The index is updated in place.
    """
    # FAISS works on row-major float32 matrices; normalising here means a
    # float64 array, a nested list, or a lone vector is accepted too.
    matrix = np.ascontiguousarray(np.atleast_2d(embeddings), dtype=np.float32)
    index.add(matrix)
 
 
 
 
 
38
 
39
+ # Search the FAISS index for the best matching text
40
def search_index(index, query_embedding, texts, top_k=3):
    """Return the texts whose vectors are closest to ``query_embedding``.

    Args:
        index: Object exposing ``search(matrix, k)`` that returns
            ``(distances, indices)`` (a FAISS index in this file).
        query_embedding: A single query vector.
        texts: Sequence of texts, positionally aligned with the vectors
            that were added to ``index``.
        top_k: Maximum number of matches to return.

    Returns:
        list: Up to ``top_k`` entries of ``texts``, in the order the index
        returned them.
    """
    query = np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)
    _distances, indices = index.search(query, top_k)
    # When the index holds fewer than ``top_k`` vectors FAISS pads the result
    # with -1.  Indexing ``texts[-1]`` would silently return the last chunk
    # as a bogus match, so drop anything outside the valid range.
    return [texts[int(i)] for i in indices[0] if 0 <= int(i) < len(texts)]
43
 
44
+ # Process the document and build the FAISS index
45
def process_document(file, chunk_size=512):
    """Read a PDF, split its text into chunks and index their embeddings.

    Args:
        file: The uploaded PDF, forwarded to ``load_document``.
        chunk_size: Number of characters per chunk (default 512, the value
            that used to be hard-coded).

    Returns:
        tuple: ``(faiss_index, chunks)`` where ``chunks`` is the list of text
        slices and ``faiss_index`` holds one embedding per chunk, in order.

    Raises:
        ValueError: If ``chunk_size`` is not positive, or if no text could
            be extracted from the document.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer.")

    text = load_document(file)
    if not text:
        # Previously this surfaced as an opaque error from np.vstack([]).
        raise ValueError("No extractable text found in the document.")

    # Fixed-width character windows; the last chunk may be shorter.
    chunks = [text[start:start + chunk_size] for start in range(0, len(text), chunk_size)]
    embeddings = np.vstack([embed_text(chunk) for chunk in chunks])

    faiss_index = initialize_faiss(embeddings.shape[1])
    add_to_index(faiss_index, embeddings)
    return faiss_index, chunks
52
 
53
+ # Answer query by searching FAISS index
54
def query_document(query, faiss_index, document_chunks):
    """Return the document chunks that best match ``query``.

    Args:
        query: The user's question.
        faiss_index: Index that is searched via ``search_index``.
        document_chunks: Chunk texts handed to ``search_index`` alongside
            the index.

    Returns:
        str: The matching chunks joined by blank lines.
    """
    matches = search_index(faiss_index, embed_text(query), document_chunks)
    return "\n\n".join(matches)
58
+
59
+ # Gradio interface
60
  def chatbot_interface():
61
+ faiss_index = None
62
+ document_chunks = None
63
 
64
+ # Function to handle document upload
65
  def upload_file(file):
66
+ nonlocal faiss_index, document_chunks
67
+ faiss_index, document_chunks = process_document(file)
68
+ return "Document uploaded and indexed. You can now ask questions."
 
69
 
70
  # Function to handle user queries
71
  def ask_question(query):
72
+ if faiss_index and document_chunks:
73
+ return query_document(query, faiss_index, document_chunks)
74
  return "Please upload a document first."
75
 
76
+ # Gradio UI
77
  upload = gr.File(label="Upload a PDF document")
78
  question = gr.Textbox(label="Ask a question about the document")
79
  answer = gr.Textbox(label="Answer", readonly=True)
80
 
 
 
 
 
81
  # Gradio app layout
82
  with gr.Blocks() as demo:
83
  gr.Markdown("# Document Chatbot")
84
  with gr.Row():
85
+ upload.render()
86
  with gr.Row():
87
  question.render()
88
  answer.render()
89
 
90
+ # Bind upload and question functionality
91
+ upload.upload(upload_file)
92
+ question.submit(ask_question, inputs=question, outputs=answer)
93
+
94
  demo.launch()
95
 
96
  # Start the chatbot interface