Kal1510 committed
Commit 610b317 · verified · 1 Parent(s): 5f3ca8f

Update app.py

Files changed (1): app.py (+85 -157)
app.py CHANGED
@@ -1,40 +1,32 @@
  import os
- import torch
  import gradio as gr
  from PyPDF2 import PdfReader
- from transformers import (
-     AutoTokenizer, pipeline,
-     AutoModelForCausalLM, AutoConfig,
-     BitsAndBytesConfig
- )
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain_community.vectorstores import FAISS
  from langchain.prompts import PromptTemplate
  from langchain.chains import LLMChain
  from langchain.embeddings import HuggingFaceEmbeddings
  from langchain.schema import Document
- from langchain import HuggingFacePipeline
-
- from huggingface_hub import login
+ from llama_cpp import Llama
+ import warnings
+ warnings.filterwarnings("ignore")

- api_key=os.getenv("api_key")

- try:
-     login(token=api_key)
-     print("login!")
- except Exception as e:
-     print(f"Login failed: {e}")
+ import subprocess

- # ------------------------------
- # Device setup
- # ------------------------------
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ subprocess.run([
+     "huggingface-cli", "download",
+     "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
+     "mistral-7b-instruct-v0.1.Q2_K.gguf",
+     "--local-dir", "./models",
+     "--local-dir-use-symlinks", "False"
+ ], check=True)

  # ------------------------------
- # Embedding model config
+ # Device and Embedding Setup (CPU optimized)
  # ------------------------------
  modelPath = "sentence-transformers/all-mpnet-base-v2"
- model_kwargs = {"device": str(device)}
+ model_kwargs = {"device": "cpu"}  # Force CPU usage
  encode_kwargs = {"normalize_embedding": False}

  embeddings = HuggingFaceEmbeddings(
@@ -44,50 +36,59 @@ embeddings = HuggingFaceEmbeddings(
  )

  # ------------------------------
- # Load Mistral model in 4bit
+ # Load Mistral GGUF via llama.cpp (CPU optimized)
  # ------------------------------
- model_name = "mistralai/Mistral-7B-Instruct-v0.1"
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
- tokenizer.pad_token = tokenizer.eos_token
- tokenizer.padding_side = "right"
-
- # 4-bit quantization config
- bnb_config = BitsAndBytesConfig(
-     load_in_4bit=True,
-     bnb_4bit_quant_type="nf4",
-     bnb_4bit_use_double_quant=True,
-     bnb_4bit_compute_dtype=torch.float16
+ llm_cpp = Llama(
+     model_path="./models/mistral-7b-instruct-v0.1.Q2_K.gguf",
+     n_ctx=2048,
+     n_threads=4,       # Adjust based on your CPU cores
+     n_gpu_layers=0,    # Force CPU-only
+     temperature=0.7,
+     top_p=0.9,
+     repeat_penalty=1.1
  )

- # Load model
- model = AutoModelForCausalLM.from_pretrained(
-     model_name,
-     # quantization_config=bnb_config,
-     device_map="auto"
- )
+ # ------------------------------
+ # LangChain-compatible wrapper
+ # ------------------------------
+ def mistral_llm(prompt):
+     output = llm_cpp(
+         prompt,
+         max_tokens=512,  # Reduced for CPU performance
+         stop=["</s>", "[INST]", "[/INST]"]
+     )
+     return output["choices"][0]["text"].strip()

  # ------------------------------
- # Improved Text Generation Pipeline
+ # Prompt Template (unchanged)
  # ------------------------------
- text_generation = pipeline(
-     model=model,
-     tokenizer=tokenizer,
-     task="text-generation",
-     temperature=0.7,
-     top_p=0.9,
-     top_k=50,
-     repetition_penalty=1.1,
-     return_full_text=False,
-     max_new_tokens=2000,
-     do_sample=True,
-     eos_token_id=tokenizer.eos_token_id,
- )
+ def get_qa_prompt():
+     template = """<s>[INST] \
+ You are a helpful, knowledgeable AI assistant. Answer the user's question based on the provided context.
+
+ Guidelines:
+ - Respond in a natural, conversational tone
+ - Be detailed but concise
+ - Use paragraphs and bullet points when appropriate
+ - If you don't know, say so
+ - Maintain a friendly and professional demeanor

- # Wrap in LangChain interface
- mistral_llm = HuggingFacePipeline(pipeline=text_generation)
+ Conversation History:
+ {chat_history}
+
+ Relevant Context:
+ {context}
+
+ Current Question: {question}
+
+ Provide a helpful response: [/INST]"""
+     return PromptTemplate(
+         template=template,
+         input_variables=["context", "question", "chat_history"]
+     )

  # ------------------------------
- # PDF Processing Functions
+ # PDF and Chat Logic (optimized for CPU)
  # ------------------------------
  def pdf_text(pdf_docs):
      text = ""
@@ -101,8 +102,8 @@ def pdf_text(pdf_docs):

  def get_chunks(text):
      splitter = RecursiveCharacterTextSplitter(
-         chunk_size=1000,
-         chunk_overlap=200,
+         chunk_size=800,   # Smaller chunks for CPU
+         chunk_overlap=100,
          length_function=len
      )
      chunks = splitter.split_text(text)
@@ -112,96 +113,50 @@ def get_vectorstore(documents):
      db = FAISS.from_documents(documents, embedding=embeddings)
      db.save_local("faiss_index")

- # ------------------------------
- # Conversational Prompt Template
- # ------------------------------
- def get_qa_prompt():
-     prompt_template = """<s>[INST]
- You are a helpful, knowledgeable AI assistant. Answer the user's question based on the provided context.
-
- Guidelines:
- - Respond in a natural, conversational tone
- - Be detailed but concise
- - Use paragraphs and bullet points when appropriate
- - If you don't know, say so
- - Maintain a friendly and professional demeanor
-
- Conversation History:
- {chat_history}
-
- Relevant Context:
- {context}
-
- Current Question: {question}
-
- Provide a helpful response: [/INST]"""
-
-     return PromptTemplate(
-         template=prompt_template,
-         input_variables=["context", "question", "chat_history"]
-     )
+ def format_chat_history(history):
+     return "\n".join([f"User: {q}\nAssistant: {a}" for q, a in history[-2:]])  # Shorter history

- # ------------------------------
- # Chat Handling Functions
- # ------------------------------
  def handle_pdf_upload(pdf_files):
+     if not pdf_files:
+         return "⚠️ Upload at least one PDF"
      try:
-         if not pdf_files:
-             return "⚠️ Please upload at least one PDF file"
-
          text = pdf_text(pdf_files)
          if not text.strip():
-             return "⚠️ Could not extract text from PDFs - please try different files"
-
+             return "⚠️ Could not extract text"
          chunks = get_chunks(text)
          get_vectorstore(chunks)
-         return f"✅ Processed {len(pdf_files)} PDF(s) with {len(chunks)} text chunks"
+         return f"✅ Processed {len(pdf_files)} PDF(s) with {len(chunks)} chunks"
      except Exception as e:
          return f"❌ Error: {str(e)}"

- def format_chat_history(chat_history):
-     return "\n".join([f"User: {q}\nAssistant: {a}" for q, a in chat_history[-3:]])
-
  def user_query(msg, chat_history):
      if not os.path.exists("faiss_index"):
-         chat_history.append((msg, "Please upload PDF documents first so I can help you."))
+         chat_history.append((msg, "Please upload PDF documents first."))
          return "", chat_history
-
+
      try:
-         # Load vector store
          db = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
-         retriever = db.as_retriever(search_kwargs={"k": 3})
-
-         # Get relevant context
+         retriever = db.as_retriever(search_kwargs={"k": 2})  # Fewer documents for CPU
          docs = retriever.get_relevant_documents(msg)
-         context = "\n\n".join([d.page_content for d in docs])
-
-         # Generate response
+         context = "\n\n".join([d.page_content for d in docs][:2])  # Limit context
+
          prompt = get_qa_prompt()
-         chain = LLMChain(llm=mistral_llm, prompt=prompt)
-
-         response = chain.run({
-             "question": msg,
-             "context": context,
-             "chat_history": format_chat_history(chat_history)
-         })
-
-         # Clean response
-         response = response.strip()
-         for end_token in ["</s>", "[INST]", "[/INST]"]:
-             if response.endswith(end_token):
-                 response = response[:-len(end_token)].strip()
-
+         final_prompt = prompt.format(
+             context=context[:1500],  # Further limit context size
+             question=msg,
+             chat_history=format_chat_history(chat_history)
+         )
+
+         response = mistral_llm(final_prompt)
          chat_history.append((msg, response))
          return "", chat_history
-
      except Exception as e:
          error_msg = f"Sorry, I encountered an error: {str(e)}"
          chat_history.append((msg, error_msg))
         return "", chat_history

  # ------------------------------
- # Gradio Interface
+ # Gradio Interface (your exact requested format)
  # ------------------------------
  with gr.Blocks(theme=gr.themes.Soft(), title="PDF Chat Assistant") as demo:
      with gr.Row():
@@ -209,7 +164,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="PDF Chat Assistant") as demo:
          # 📚 PDF Chat Assistant
          ### Have natural conversations with your documents
          """)
-
+
      with gr.Row():
          with gr.Column(scale=1, min_width=300):
              gr.Markdown("### Document Upload")
@@ -227,7 +182,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="PDF Chat Assistant") as demo:
              2. Click Process Documents
              3. Start chatting in the right panel
              """)
-
+
          with gr.Column(scale=2):
              chatbot = gr.Chatbot(
                  height=600,
@@ -260,37 +215,10 @@ with gr.Blocks(theme=gr.themes.Soft(), title="PDF Chat Assistant") as demo:
                  label="Example Questions"
              )

-     # Event handlers
-     upload_btn.click(
-         fn=handle_pdf_upload,
-         inputs=pdf_input,
-         outputs=status_box
-     )
-
-     submit_btn.click(
-         fn=user_query,
-         inputs=[message, chatbot],
-         outputs=[message, chatbot]
-     )
-
-     message.submit(
-         fn=user_query,
-         inputs=[message, chatbot],
-         outputs=[message, chatbot]
-     )
-
-     clear_chat.click(
-         lambda: [],
-         None,
-         chatbot,
-         queue=False
-     )
+     upload_btn.click(handle_pdf_upload, inputs=pdf_input, outputs=status_box)
+     submit_btn.click(user_query, inputs=[message, chatbot], outputs=[message, chatbot])
+     message.submit(user_query, inputs=[message, chatbot], outputs=[message, chatbot])
+     clear_chat.click(lambda: [], None, chatbot, queue=False)

- # Launch the app
  if __name__ == "__main__":
-     demo.launch(
-         server_name="0.0.0.0",
-         server_port=7861,
-         share=True,
-         debug=True
-     )
+     demo.launch(server_name="0.0.0.0", server_port=7862, share=True)  # serve on 0.0.0.0:7862 with a public share link