Upload 4 files
- .gitattributes +1 -0
- Dockerfile +26 -4
- app.py +279 -51
- requirements.txt +14 -5
- sample_documents/test_image.jpg +3 -0
.gitattributes
CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 test_image.jpg filter=lfs diff=lfs merge=lfs -text
+sample_documents/test_image.jpg filter=lfs diff=lfs merge=lfs -text
Dockerfile
CHANGED
@@ -1,8 +1,30 @@
 FROM python:3.10-slim
-
+
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    git \
+    build-essential \
+    python3-dev \
+    libgl1-mesa-glx \
+    libglib2.0-0 \
+    poppler-utils \
+    tesseract-ocr \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install Python dependencies
 COPY requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt
-
-
+
+# Create sample documents directory
+RUN mkdir -p /app/sample_documents
+
+# Copy application code
+COPY . .
+
+# Expose port for Gradio
 EXPOSE 7860
-
+
+# Command to run the application
+CMD ["python", "app.py"]
app.py
CHANGED
@@ -1,69 +1,297 @@
 import os
-import
-import
-import
-from
+import cv2
+import gradio as gr
+import torch
+from llava.model.builder import load_pretrained_model
+from llava.mm_utils import get_model_name_from_path, tokenizer_image_token
+from llava.conversation import conv_templates
+from llava.utils import disable_torch_init
+# image-token constants used by the LLaVA generate calls further down
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
 from PIL import Image
+import pytesseract
+from pdf2image import convert_from_path
+import docx
+import openpyxl
+from pptx import Presentation
+import io
+import tempfile
+import re
+import shutil
 
-
-
-#
-
-if not API_TOKEN:
-    raise RuntimeError("HF_TOKEN environment variable is not set!")
-
-#
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-)
-    response.raise_for_status()
-
-
-
-
-
-        return {"question": question, "answer": answer}
-    else:
-        logger.error(f"Unexpected response format: {result}")
-        return {"error": "Unexpected response format"}
-
-
-
-
-    except Exception as e:
-        logger.error(f"Other error occurred: {e}")
-        return {"error": str(e)}
+# Sample documents directory
+SAMPLE_DIR = "sample_documents"
+
+# Initialize LLaVA model
+disable_torch_init()
+
+# Model paths
+model_path = "liuhaotian/llava-v1.5-7b"
+model_name = get_model_name_from_path(model_path)
+tokenizer, model, processor, context_len = load_pretrained_model(
+    model_path=model_path,
+    model_base=None,
+    model_name=model_name,
+    device="cuda" if torch.cuda.is_available() else "cpu",
+    # bitsandbytes quantization needs a CUDA GPU, so only quantize when one is present
+    load_8bit=False,
+    load_4bit=torch.cuda.is_available()
+)
+
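Note on the conversation API used below: `conv_templates["llava_v1"]` builds a plain-text chat prompt with `USER:`/`ASSISTANT:` roles. A quick sketch of what the model actually receives (assuming the llava-torch conversation API):

```python
conv = conv_templates["llava_v1"].copy()
conv.append_message(conv.roles[0], "What is shown in this image?")
conv.append_message(conv.roles[1], None)  # None leaves the assistant slot open
print(conv.get_prompt())
# system preamble, then: "USER: What is shown in this image? ASSISTANT:"
```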
+# Document processing functions
+def process_image(image_path):
+    # Use Tesseract to extract text from image
+    img = Image.open(image_path)
+    text = pytesseract.image_to_string(img)
+    return text, img
+
+def process_pdf(pdf_path):
+    # Convert PDF to images and extract text
+    images = convert_from_path(pdf_path)
+    text = ""
+    for img in images:
+        text += pytesseract.image_to_string(img) + "\n\n"
+    return text, images[0] if images else None
+
+def process_docx(docx_path):
+    # Extract text from DOCX
+    doc = docx.Document(docx_path)
+    text = ""
+    for paragraph in doc.paragraphs:
+        text += paragraph.text + "\n"
+    return text, None
+
+def process_excel(excel_path):
+    # Extract data from Excel
+    workbook = openpyxl.load_workbook(excel_path)
+    text = ""
+    for sheet_name in workbook.sheetnames:
+        sheet = workbook[sheet_name]
+        text += f"Sheet: {sheet_name}\n"
+        for row in sheet.iter_rows(values_only=True):
+            text += " | ".join([str(cell) if cell is not None else "" for cell in row]) + "\n"
+        text += "\n"
+    return text, None
+
+def process_pptx(pptx_path):
+    # Extract text from PowerPoint
+    presentation = Presentation(pptx_path)
+    text = ""
+    for i, slide in enumerate(presentation.slides):
+        text += f"Slide {i+1}:\n"
+        for shape in slide.shapes:
+            if hasattr(shape, "text"):
+                text += shape.text + "\n"
+        text += "\n"
+    return text, None
+
+def process_document(file_path):
+    # Process document based on file extension
+    _, ext = os.path.splitext(file_path)
+    ext = ext.lower()
+
+    if ext in ['.jpg', '.jpeg', '.png', '.bmp']:
+        return process_image(file_path)
+    elif ext == '.pdf':
+        return process_pdf(file_path)
+    elif ext == '.docx':
+        return process_docx(file_path)
+    elif ext in ['.xlsx', '.xls']:
+        return process_excel(file_path)
+    elif ext == '.pptx':
+        return process_pptx(file_path)
+    else:
+        return "Unsupported file format", None
+
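Note: the `.txt` samples created further down (`sample.txt`, `ai_overview.txt`) fall through to the "Unsupported file format" branch above. A minimal plain-text branch, as a sketch (assumes UTF-8 with lenient decoding), inserted before the final `else`:

```python
    elif ext == '.txt':
        # Plain-text fallback for the bundled sample files
        with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
            return f.read(), None
```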
+def get_sample_documents():
+    """Get list of sample documents from the sample directory"""
+    if not os.path.exists(SAMPLE_DIR):
+        os.makedirs(SAMPLE_DIR)
+        # Create a sample text file if no samples exist
+        with open(os.path.join(SAMPLE_DIR, "sample.txt"), "w") as f:
+            f.write("This is a sample document for testing the document chatbot.\n\n")
+            f.write("It contains information about AI models and document processing.\n")
+            f.write("You can ask questions about this document to test the system.")
+
+    return [f for f in os.listdir(SAMPLE_DIR) if os.path.isfile(os.path.join(SAMPLE_DIR, f))]
+
+def chat_with_document(file, query, use_sample=False, sample_name=None):
+    if use_sample and sample_name:
+        file_path = os.path.join(SAMPLE_DIR, sample_name)
+    else:
+        # Handle uploaded file (basename, since gradio file names can be absolute paths)
+        temp_dir = tempfile.mkdtemp()
+        file_path = os.path.join(temp_dir, os.path.basename(file.name))
+
+        with open(file_path, 'wb') as f:
+            f.write(file.read())
+
+    # Check if it's an image for visual analysis
+    _, ext = os.path.splitext(file_path)
+    ext = ext.lower()
+
+    if ext in ['.jpg', '.jpeg', '.png', '.bmp']:
+        # For images, we can use LLaVA's visual capabilities
+        image = Image.open(file_path).convert('RGB')
+
+        # Set up conversation; the <image> token marks where the image is spliced in
+        conv = conv_templates["llava_v1"].copy()
+        conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + query)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+        # Process image with the vision processor
+        image_tensor = processor.preprocess(image, return_tensors='pt')['pixel_values']
+        image_tensor = image_tensor.half().cuda() if torch.cuda.is_available() else image_tensor.float()
+
+        # Tokenize with the image token, then generate
+        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+        with torch.no_grad():
+            output_ids = model.generate(
+                input_ids,
+                images=image_tensor,
+                max_new_tokens=1024,
+                temperature=0.7,
+                do_sample=True,
+            )
+
+        response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+        response = response.split(conv.sep2)[-1].strip()
+
+    else:
+        # For documents, extract text and send to model
+        document_text, _ = process_document(file_path)
+
+        # Limit text length to avoid exceeding context window
+        if len(document_text) > 4000:
+            document_text = document_text[:4000] + "...(truncated)"
+
+        # Set up conversation with extracted text
+        full_prompt = f"This is the content of the document:\n\n{document_text}\n\nNow, please answer this question about the document: {query}"
+
+        # Use LLaVA for text-only generation (no image tensor)
+        conv = conv_templates["llava_v1"].copy()
+        conv.append_message(conv.roles[0], full_prompt)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+        # Generate response
+        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
+        with torch.no_grad():
+            output_ids = model.generate(
+                input_ids,
+                images=None,
+                max_new_tokens=1024,
+                temperature=0.7,
+                do_sample=True,
+            )
+
+        response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+        response = response.split(conv.sep2)[-1].strip()
+
+    # Clean up if using uploaded file
+    if not use_sample:
+        os.remove(file_path)
+        os.rmdir(temp_dir)
+
+    return response
+
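Note: the 4,000-character cap in `chat_with_document` is a rough proxy for the model's real limit, which is measured in tokens (`context_len` returned by `load_pretrained_model`). A token-aware variant, sketched with the already-loaded tokenizer:

```python
def truncate_to_tokens(text, max_tokens=3000):
    # Truncate on token boundaries; 3000 is an assumed budget that leaves
    # headroom under context_len for the prompt template and the answer
    ids = tokenizer(text).input_ids
    if len(ids) <= max_tokens:
        return text
    return tokenizer.decode(ids[:max_tokens], skip_special_tokens=True) + "...(truncated)"
```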
+# Create Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown("# Document and Image Chat Assistant")
+
+    with gr.Tab("Upload Your Document"):
+        with gr.Row():
+            with gr.Column():
+                file_input = gr.File(label="Upload Document or Image")
+                query_input = gr.Textbox(label="Ask a question about the document", lines=2)
+                submit_btn = gr.Button("Submit")
+
+            with gr.Column():
+                output = gr.Textbox(label="Response", lines=10)
+
+        submit_btn.click(
+            fn=lambda file, query: chat_with_document(file, query, use_sample=False),
+            inputs=[file_input, query_input],
+            outputs=output
+        )
+
+    with gr.Tab("Use Sample Documents"):
+        with gr.Row():
+            with gr.Column():
+                sample_dropdown = gr.Dropdown(choices=get_sample_documents(), label="Select Sample Document")
+                sample_query_input = gr.Textbox(label="Ask a question about the sample document", lines=2)
+                sample_submit_btn = gr.Button("Submit")
+
+            with gr.Column():
+                sample_output = gr.Textbox(label="Response", lines=10)
+
+        sample_submit_btn.click(
+            fn=lambda sample, query: chat_with_document(None, query, use_sample=True, sample_name=sample),
+            inputs=[sample_dropdown, sample_query_input],
+            outputs=sample_output
+        )
+
+# Create sample documents
+def create_sample_documents():
+    """Create sample documents for testing"""
+    if not os.path.exists(SAMPLE_DIR):
+        os.makedirs(SAMPLE_DIR)
+
+    # Sample text document
+    with open(os.path.join(SAMPLE_DIR, "ai_overview.txt"), "w") as f:
+        f.write("# Artificial Intelligence Overview\n\n")
+        f.write("Artificial Intelligence (AI) is the simulation of human intelligence processes by machines, especially computer systems.\n")
+        f.write("These processes include learning (the acquisition of information and rules for using the information),\n")
+        f.write("reasoning (using rules to reach approximate or definite conclusions) and self-correction.\n\n")
+        f.write("Major AI techniques include:\n")
+        f.write("- Machine Learning\n")
+        f.write("- Natural Language Processing\n")
+        f.write("- Computer Vision\n")
+        f.write("- Robotics\n\n")
+        f.write("AI is transforming many fields including healthcare, finance, transportation, and more.")
+
+    # Create sample DOCX if python-docx is available
+    try:
+        doc = docx.Document()
+        doc.add_heading('Project Schedule', 0)
+        doc.add_paragraph('This document outlines the schedule for the AI chatbot project.')
+
+        doc.add_heading('Phase 1: Planning', level=1)
+        doc.add_paragraph('Requirements gathering: Week 1-2')
+        doc.add_paragraph('Architecture design: Week 3')
+
+        doc.add_heading('Phase 2: Development', level=1)
+        doc.add_paragraph('Backend development: Week 4-6')
+        doc.add_paragraph('Frontend development: Week 5-7')
+        doc.add_paragraph('Integration: Week 8')
+
+        doc.add_heading('Phase 3: Testing', level=1)
+        doc.add_paragraph('Unit testing: Week 9')
+        doc.add_paragraph('Integration testing: Week 10')
+        doc.add_paragraph('User acceptance testing: Week 11')
+
+        doc.save(os.path.join(SAMPLE_DIR, 'project_schedule.docx'))
+    except Exception:
+        pass
+
+    # Create a simple image with text
+    try:
+        from PIL import ImageDraw
+
+        img = Image.new('RGB', (800, 400), color=(255, 255, 255))
+        d = ImageDraw.Draw(img)
+        # Use default font
+        d.text((50, 50), "Document Chat Demo", fill=(0, 0, 0))
+        d.text((50, 100), "This is a sample image for testing the document chat system.", fill=(0, 0, 0))
+        d.text((50, 150), "The system should be able to answer questions about this text.", fill=(0, 0, 0))
+        d.text((50, 200), "It can also describe visual elements in the image.", fill=(0, 0, 0))
+
+        # Draw some shapes
+        d.rectangle([(50, 250), (200, 300)], outline=(255, 0, 0))
+        d.ellipse([(300, 250), (450, 300)], outline=(0, 255, 0))
+        d.polygon([(500, 250), (550, 300), (600, 250)], outline=(0, 0, 255))
+
+        img.save(os.path.join(SAMPLE_DIR, 'sample_image.png'))
+    except Exception:
+        pass
+
+# Create sample documents when starting
+create_sample_documents()
+
+# Start Gradio server
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=7860)
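Note: `gr.Dropdown(choices=get_sample_documents())` captures its choices while the Blocks UI is being built, which happens before `create_sample_documents()` runs, so `ai_overview.txt` and the other generated samples never appear in the dropdown. A sketch of the ordering fix:

```python
# Create the sample files before the UI is defined, so the dropdown's
# choices include everything in sample_documents/
create_sample_documents()

with gr.Blocks() as demo:
    sample_dropdown = gr.Dropdown(
        choices=get_sample_documents(),
        label="Select Sample Document",
    )
```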
requirements.txt
CHANGED
@@ -1,5 +1,14 @@
-
-
-
-
-
+torch
+transformers>=4.31.0
+accelerate>=0.21.0
+bitsandbytes>=0.41.0
+sentencepiece
+gradio
+Pillow
+opencv-python
+pytesseract
+pdf2image
+python-docx
+openpyxl
+python-pptx
+llava-torch
sample_documents/test_image.jpg
ADDED
(binary image, stored with Git LFS)