oussamatahkoubit committed (verified)
Commit 5e8d8f8 · 1 parent: 4ab6353

Upload 4 files

Files changed (5):
  1. .gitattributes +1 -0
  2. Dockerfile +26 -4
  3. app.py +279 -51
  4. requirements.txt +14 -5
  5. sample_documents/test_image.jpg +3 -0
.gitattributes CHANGED

@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 test_image.jpg filter=lfs diff=lfs merge=lfs -text
+sample_documents/test_image.jpg filter=lfs diff=lfs merge=lfs -text
Dockerfile CHANGED

@@ -1,8 +1,30 @@
 FROM python:3.10-slim
-WORKDIR /code
+
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    git \
+    build-essential \
+    python3-dev \
+    libgl1-mesa-glx \
+    libglib2.0-0 \
+    poppler-utils \
+    tesseract-ocr \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install Python dependencies
 COPY requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt
-COPY app.py .
-COPY test_image.jpg .
+
+# Create sample documents directory
+RUN mkdir -p /app/sample_documents
+
+# Copy application code
+COPY . .
+
+# Expose port for Gradio
 EXPOSE 7860
+
+# Command to run the application
-CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
+CMD ["python", "app.py"]
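Reviewer note: the new apt layer maps directly onto the Python stack below — poppler-utils provides the pdftoppm binary that pdf2image shells out to, tesseract-ocr provides the CLI behind pytesseract, and libgl1-mesa-glx/libglib2.0-0 are the shared libraries opencv-python needs on slim base images. A minimal sanity check in that spirit (a sketch, not part of this commit; run inside the built image):

import shutil

# pdf2image drives poppler's pdftoppm; pytesseract drives the tesseract CLI
for binary, package in [("pdftoppm", "poppler-utils"), ("tesseract", "tesseract-ocr")]:
    if shutil.which(binary) is None:
        raise RuntimeError(f"{binary} not on PATH; install {package} in the image")

import cv2  # the import itself fails on slim bases without libgl1-mesa-glx / libglib2.0-0
print("system dependencies OK; OpenCV", cv2.__version__)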
app.py CHANGED

@@ -1,69 +1,297 @@
 import os
-import io
-import logging
-import requests
-from fastapi import FastAPI
+import cv2
+import gradio as gr
+import torch
+from llava.model.builder import load_pretrained_model
+from llava.mm_utils import get_model_name_from_path
+from llava.conversation import conv_templates
+from llava.utils import disable_torch_init
 from PIL import Image
-
-app = FastAPI()
-
-# Get your token from environment variable 'HF_TOKEN'
-API_TOKEN = os.getenv("HF_TOKEN")
-if not API_TOKEN:
-    raise RuntimeError("HF_TOKEN environment variable is not set!")
-
-# Use a Visual Question Answering (VQA) model
-API_URL = "https://api-inference.huggingface.co/models/Salesforce/blip-vqa-base"
-
-HEADERS = {
-    "Authorization": f"Bearer {API_TOKEN}"
-}
-
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger("app")
-
-@app.on_event("startup")
-async def startup_event():
-    logger.info("Warming up the Hugging Face API")
-
-@app.get("/")
-def home():
-    return {"message": "VQA API is running"}
-
-@app.get("/ask")
-def ask_question():
-    image_path = "/code/test_image.jpg"
-    question = "What does the picture show?"  # Example question
-
-    logger.info(f"Reading image: {image_path}")
-
-    try:
-        with open(image_path, "rb") as image_file:
-            image_bytes = image_file.read()
-
-        logger.info("Sending request to Hugging Face API")
-
-        response = requests.post(
-            API_URL,
-            headers=HEADERS,
-            files={"image": ("filename.jpg", image_bytes, "image/jpeg")},
-            data={"inputs": f'{{"question":"{question}"}}'}
-        )
-        response.raise_for_status()
-
-        result = response.json()
-
-        if "answer" in result:
-            answer = result["answer"]
-            return {"question": question, "answer": answer}
-        else:
-            logger.error(f"Unexpected response format: {result}")
-            return {"error": "Unexpected response format"}
-
-    except requests.exceptions.HTTPError as e:
-        logger.error(f"HTTP error occurred: {e}")
-        return {"error": str(e)}
-    except Exception as e:
-        logger.error(f"Other error occurred: {e}")
-        return {"error": str(e)}
+import pytesseract
+from pdf2image import convert_from_path
+import docx
+import openpyxl
+from pptx import Presentation
+import io
+import tempfile
+import re
+import shutil
+
+# Sample documents directory
+SAMPLE_DIR = "sample_documents"
+
+# Initialize LLaVA model
+disable_torch_init()
+
+# Model paths
+model_path = "liuhaotian/llava-v1.5-7b"
+model_name = get_model_name_from_path(model_path)
+tokenizer, model, processor, context_len = load_pretrained_model(
+    model_path=model_path,
+    model_base=None,
+    model_name=model_name,
+    device="cuda" if torch.cuda.is_available() else "cpu",
+    load_8bit=not torch.cuda.is_available(),
+    load_4bit=not torch.cuda.is_available()
+)
+
+# Document processing functions
+def process_image(image_path):
+    # Use Tesseract to extract text from image
+    img = Image.open(image_path)
+    text = pytesseract.image_to_string(img)
+    return text, img
+
+def process_pdf(pdf_path):
+    # Convert PDF to images and extract text
+    images = convert_from_path(pdf_path)
+    text = ""
+    for img in images:
+        text += pytesseract.image_to_string(img) + "\n\n"
+    return text, images[0] if images else None
+
+def process_docx(docx_path):
+    # Extract text from DOCX
+    doc = docx.Document(docx_path)
+    text = ""
+    for paragraph in doc.paragraphs:
+        text += paragraph.text + "\n"
+    return text, None
+
+def process_excel(excel_path):
+    # Extract data from Excel
+    workbook = openpyxl.load_workbook(excel_path)
+    text = ""
+    for sheet_name in workbook.sheetnames:
+        sheet = workbook[sheet_name]
+        text += f"Sheet: {sheet_name}\n"
+        for row in sheet.iter_rows(values_only=True):
+            text += " | ".join([str(cell) if cell is not None else "" for cell in row]) + "\n"
+        text += "\n"
+    return text, None
+
+def process_pptx(pptx_path):
+    # Extract text from PowerPoint
+    presentation = Presentation(pptx_path)
+    text = ""
+    for i, slide in enumerate(presentation.slides):
+        text += f"Slide {i+1}:\n"
+        for shape in slide.shapes:
+            if hasattr(shape, "text"):
+                text += shape.text + "\n"
+        text += "\n"
+    return text, None
+
+def process_document(file_path):
+    # Process document based on file extension
+    _, ext = os.path.splitext(file_path)
+    ext = ext.lower()
+
+    if ext in ['.jpg', '.jpeg', '.png', '.bmp']:
+        return process_image(file_path)
+    elif ext == '.pdf':
+        return process_pdf(file_path)
+    elif ext == '.docx':
+        return process_docx(file_path)
+    elif ext in ['.xlsx', '.xls']:
+        return process_excel(file_path)
+    elif ext == '.pptx':
+        return process_pptx(file_path)
+    else:
+        return "Unsupported file format", None
+
+def get_sample_documents():
+    """Get list of sample documents from the sample directory"""
+    if not os.path.exists(SAMPLE_DIR):
+        os.makedirs(SAMPLE_DIR)
+        # Create a sample text file if no samples exist
+        with open(os.path.join(SAMPLE_DIR, "sample.txt"), "w") as f:
+            f.write("This is a sample document for testing the document chatbot.\n\n")
+            f.write("It contains information about AI models and document processing.\n")
+            f.write("You can ask questions about this document to test the system.")
+
+    return [f for f in os.listdir(SAMPLE_DIR) if os.path.isfile(os.path.join(SAMPLE_DIR, f))]
+
+def chat_with_document(file, query, use_sample=False, sample_name=None):
+    if use_sample and sample_name:
+        file_path = os.path.join(SAMPLE_DIR, sample_name)
+    else:
+        # Handle uploaded file
+        temp_dir = tempfile.mkdtemp()
+        file_path = os.path.join(temp_dir, file.name)
+
+        with open(file_path, 'wb') as f:
+            f.write(file.read())
+
+    # Check if it's an image for visual analysis
+    _, ext = os.path.splitext(file_path)
+    ext = ext.lower()
+
+    if ext in ['.jpg', '.jpeg', '.png', '.bmp']:
+        # For images, we can use LLaVA's visual capabilities
+        image = Image.open(file_path).convert('RGB')
+
+        # Set up conversation
+        conv = conv_templates["llava_v1"].copy()
+        conv.append_message(conv.roles[0], query)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+        # Process image with model
+        image_tensor = processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda() if torch.cuda.is_available() else processor.preprocess(image, return_tensors='pt')['pixel_values'].float()
+
+        # Generate response
+        with torch.no_grad():
+            response = model.generate(
+                image_tensor,
+                tokenizer(prompt, return_tensors='pt').input_ids.to(model.device),
+                max_new_tokens=1024,
+                temperature=0.7,
+                do_sample=True,
+            )
+
+        response = tokenizer.decode(response[0], skip_special_tokens=True)
+        response = response.split(conv.sep2)[-1].strip()
+
+    else:
+        # For documents, extract text and send to model
+        document_text, _ = process_document(file_path)
+
+        # Limit text length to avoid exceeding context window
+        if len(document_text) > 4000:
+            document_text = document_text[:4000] + "...(truncated)"
+
+        # Set up conversation with extracted text
+        full_prompt = f"This is the content of the document:\n\n{document_text}\n\nNow, please answer this question about the document: {query}"
+
+        # Use LLaVA for text generation
+        conv = conv_templates["llava_v1"].copy()
+        conv.append_message(conv.roles[0], full_prompt)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+        # Generate response
+        with torch.no_grad():
+            response = model.generate(
+                None,
+                tokenizer(prompt, return_tensors='pt').input_ids.to(model.device),
+                max_new_tokens=1024,
+                temperature=0.7,
+                do_sample=True,
+            )
+
+        response = tokenizer.decode(response[0], skip_special_tokens=True)
+        response = response.split(conv.sep2)[-1].strip()
+
+    # Clean up if using uploaded file
+    if not use_sample:
+        os.remove(file_path)
+        os.rmdir(temp_dir)
+
+    return response
+
+# Create Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown("# Document and Image Chat Assistant")
+
+    with gr.Tab("Upload Your Document"):
+        with gr.Row():
+            with gr.Column():
+                file_input = gr.File(label="Upload Document or Image")
+                query_input = gr.Textbox(label="Ask a question about the document", lines=2)
+                submit_btn = gr.Button("Submit")
+
+            with gr.Column():
+                output = gr.Textbox(label="Response", lines=10)
+
+        submit_btn.click(
+            fn=lambda file, query: chat_with_document(file, query, use_sample=False),
+            inputs=[file_input, query_input],
+            outputs=output
+        )
+
+    with gr.Tab("Use Sample Documents"):
+        with gr.Row():
+            with gr.Column():
+                sample_dropdown = gr.Dropdown(choices=get_sample_documents(), label="Select Sample Document")
+                sample_query_input = gr.Textbox(label="Ask a question about the sample document", lines=2)
+                sample_submit_btn = gr.Button("Submit")
+
+            with gr.Column():
+                sample_output = gr.Textbox(label="Response", lines=10)
+
+        sample_submit_btn.click(
+            fn=lambda sample, query: chat_with_document(None, query, use_sample=True, sample_name=sample),
+            inputs=[sample_dropdown, sample_query_input],
+            outputs=sample_output
+        )
+
+# Create sample documents
+def create_sample_documents():
+    """Create sample documents for testing"""
+    if not os.path.exists(SAMPLE_DIR):
+        os.makedirs(SAMPLE_DIR)
+
+    # Sample text document
+    with open(os.path.join(SAMPLE_DIR, "ai_overview.txt"), "w") as f:
+        f.write("# Artificial Intelligence Overview\n\n")
+        f.write("Artificial Intelligence (AI) is the simulation of human intelligence processes by machines, especially computer systems.\n")
+        f.write("These processes include learning (the acquisition of information and rules for using the information),\n")
+        f.write("reasoning (using rules to reach approximate or definite conclusions) and self-correction.\n\n")
+        f.write("Major AI techniques include:\n")
+        f.write("- Machine Learning\n")
+        f.write("- Natural Language Processing\n")
+        f.write("- Computer Vision\n")
+        f.write("- Robotics\n\n")
+        f.write("AI is transforming many fields including healthcare, finance, transportation, and more.")
+
+    # Create sample DOCX if python-docx is available
+    try:
+        doc = docx.Document()
+        doc.add_heading('Project Schedule', 0)
+        doc.add_paragraph('This document outlines the schedule for the AI chatbot project.')
+
+        doc.add_heading('Phase 1: Planning', level=1)
+        doc.add_paragraph('Requirements gathering: Week 1-2')
+        doc.add_paragraph('Architecture design: Week 3')
+
+        doc.add_heading('Phase 2: Development', level=1)
+        doc.add_paragraph('Backend development: Week 4-6')
+        doc.add_paragraph('Frontend development: Week 5-7')
+        doc.add_paragraph('Integration: Week 8')
+
+        doc.add_heading('Phase 3: Testing', level=1)
+        doc.add_paragraph('Unit testing: Week 9')
+        doc.add_paragraph('Integration testing: Week 10')
+        doc.add_paragraph('User acceptance testing: Week 11')
+
+        doc.save(os.path.join(SAMPLE_DIR, 'project_schedule.docx'))
+    except:
+        pass
+
+    # Create a simple image with text
+    try:
+        img = Image.new('RGB', (800, 400), color=(255, 255, 255))
+        from PIL import ImageDraw, ImageFont
+        d = ImageDraw.Draw(img)
+        # Use default font
+        d.text((50, 50), "Document Chat Demo", fill=(0, 0, 0))
+        d.text((50, 100), "This is a sample image for testing the document chat system.", fill=(0, 0, 0))
+        d.text((50, 150), "The system should be able to answer questions about this text.", fill=(0, 0, 0))
+        d.text((50, 200), "It can also describe visual elements in the image.", fill=(0, 0, 0))
+
+        # Draw some shapes
+        d.rectangle([(50, 250), (200, 300)], outline=(255, 0, 0))
+        d.ellipse([(300, 250), (450, 300)], outline=(0, 255, 0))
+        d.polygon([(500, 250), (550, 300), (600, 250)], outline=(0, 0, 255))
+
+        img.save(os.path.join(SAMPLE_DIR, 'sample_image.png'))
+    except:
+        pass
+
+# Create sample documents when starting
+create_sample_documents()
+
+# Start Gradio server
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=7860)
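Reviewer note on the generate calls above: in upstream llava-torch, model.generate takes the token ids as its first argument and the image tensor through the images= keyword, and the prompt is expected to carry an image placeholder token; the positional (image_tensor, input_ids) call in the image branch likely needs adjusting. A hedged sketch of that branch under the API used by llava/eval/run_llava.py (names from llava.constants and llava.mm_utils; verify against the pinned version before relying on it):

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.mm_utils import tokenizer_image_token

conv = conv_templates["llava_v1"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + query)  # image placeholder first
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

# tokenizer_image_token splices IMAGE_TOKEN_INDEX where the placeholder sits
input_ids = tokenizer_image_token(
    prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
).unsqueeze(0).to(model.device)

with torch.no_grad():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,  # keyword argument, not positional
        do_sample=True,
        temperature=0.7,
        max_new_tokens=1024,
    )
answer = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()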
requirements.txt CHANGED

@@ -1,5 +1,14 @@
-fastapi==0.110.0
-uvicorn[standard]==0.29.0
-requests==2.31.0
-python-dotenv==1.0.1
-pillow==10.2.0
+torch
+transformers>=4.31.0
+accelerate>=0.21.0
+bitsandbytes>=0.41.0
+sentencepiece
+gradio
+Pillow
+opencv-python
+pytesseract
+pdf2image
+python-docx
+openpyxl
+python-pptx
+llava-torch
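With most version pins dropped, a quick import smoke test inside the built image catches resolver breakage early; a minimal sketch (note that several import names differ from the PyPI names: python-docx imports as docx, python-pptx as pptx, opencv-python as cv2, Pillow as PIL):

import importlib

# bitsandbytes and llava-torch are left out here: their imports have
# side effects on CPU-only machines and are exercised by app startup anyway
modules = ["torch", "transformers", "accelerate", "sentencepiece", "gradio",
           "PIL", "cv2", "pytesseract", "pdf2image", "docx", "openpyxl", "pptx"]
for name in modules:
    importlib.import_module(name)  # raises ImportError on a broken resolve
    print("ok:", name)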
sample_documents/test_image.jpg ADDED

Git LFS Details

  • SHA256: 8a56ccfc341865af4ec1c2d836e52e71dcd959e41a8522f60bfcc3ff4e99d388
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB