Prashant26am committed
Commit ae365ab · 0 Parent(s)

fix: Update Gradio to 4.44.1 and remove example images
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,60 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual Environment
+ venv/
+ ENV/
+
+ # IDE
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Model files
+ *.bin
+ *.pt
+ *.pth
+ *.ckpt
+ *.safetensors
+
+ # Logs
+ *.log
+ logs/
+
+ # Frontend
+ frontend/
+ node_modules/
+ npm-debug.log*
+ yarn-debug.log*
+ yarn-error.log*
+
+ # Temporary files
+ *.tmp
+ *.temp
+ temp/
+ tmp/
README.md ADDED
@@ -0,0 +1,52 @@
+ ---
+ title: LLaVA Chat
+ emoji: 🖼️
+ colorFrom: blue
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 3.50.2
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ # LLaVA Chat
+
+ A powerful multimodal AI assistant that can understand and discuss images. Upload any image and chat with LLaVA about it!
+
+ ## Features
+
+ - 🖼️ Upload and analyze any image
+ - 💬 Natural conversation about image content
+ - ⚙️ Adjustable generation parameters
+ - 🎯 High-quality image understanding
+ - 🚀 Fast and responsive interface
+
+ ## How to Use
+
+ 1. Upload an image using the image uploader
+ 2. Type your question or prompt about the image
+ 3. (Optional) Adjust the generation parameters:
+    - Max New Tokens: Control response length
+    - Temperature: Adjust response creativity
+    - Top P: Fine-tune response diversity
+ 4. Click "Generate Response" to get LLaVA's analysis
+
+ ## Example Prompts
+
+ - "What can you see in this image?"
+ - "Describe this scene in detail"
+ - "What emotions does this image convey?"
+ - "What's happening in this picture?"
+ - "Can you identify any objects or people in this image?"
+
+ ## Model Details
+
+ This Space uses the LLaVA (Large Language and Vision Assistant) model, which combines:
+ - CLIP ViT-L/14 vision encoder
+ - Vicuna-7B language model
+ - Advanced multimodal understanding capabilities
+
+ ## License
+
+ This project is licensed under the MIT License.
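The steps described in the README can also be driven programmatically. Below is a hypothetical sketch using gradio_client; the Space id, the endpoint index, and the way the image argument is passed are all assumptions and may need adjusting to the installed gradio_client version:

# Hypothetical client-side call to the Space (not part of the commit).
from gradio_client import Client

client = Client("Prashant26am/llava-chat")  # hypothetical Space id
result = client.predict(
    "examples/cat.jpg",                  # image (format depends on client version)
    "What can you see in this image?",   # prompt
    256,                                 # max new tokens
    0.7,                                 # temperature
    0.9,                                 # top p
    fn_index=0,                          # assumes the click handler is the first registered endpoint
)
print(result)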
app.py ADDED
@@ -0,0 +1,148 @@
+ from fastapi import FastAPI, UploadFile, File, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse
+ import os
+ import tempfile
+ from typing import Optional
+ from pydantic import BaseModel
+ import torch
+ import gradio as gr
+ from models.llava import LLaVA
+
+ # Initialize model globally
+ model = None
+
+ def initialize_model():
+     global model
+     try:
+         model = LLaVA(
+             vision_model_path="openai/clip-vit-large-patch14-336",
+             language_model_path="lmsys/vicuna-7b-v1.5",
+             device="cuda" if torch.cuda.is_available() else "cpu",
+             load_in_8bit=True
+         )
+         print(f"Model initialized on {model.device}")
+         return True
+     except Exception as e:
+         print(f"Error initializing model: {e}")
+         return False
+
+ def process_image(image, prompt, max_new_tokens=256, temperature=0.7, top_p=0.9):
+     if not model:
+         return "Error: Model not initialized"
+
+     try:
+         # Save the uploaded image temporarily
+         with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
+             image.save(temp_file.name)
+             temp_path = temp_file.name
+
+         # Generate response
+         response = model.generate_from_image(
+             image_path=temp_path,
+             prompt=prompt,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_p=top_p
+         )
+
+         # Clean up temporary file
+         os.unlink(temp_path)
+         return response
+
+     except Exception as e:
+         return f"Error processing image: {str(e)}"
+
+ # Create Gradio interface
+ def create_interface():
+     with gr.Blocks(title="LLaVA Chat", theme=gr.themes.Soft()) as demo:
+         gr.Markdown("""
+         # LLaVA Chat
+         Upload an image and chat with LLaVA about it. This model can understand and describe images, answer questions about them, and engage in visual conversations.
+         """)
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 image_input = gr.Image(type="pil", label="Upload Image")
+                 prompt_input = gr.Textbox(
+                     label="Ask about the image",
+                     placeholder="What can you see in this image?",
+                     lines=3
+                 )
+
+                 with gr.Accordion("Advanced Settings", open=False):
+                     max_tokens = gr.Slider(
+                         minimum=32,
+                         maximum=512,
+                         value=256,
+                         step=32,
+                         label="Max New Tokens"
+                     )
+                     temperature = gr.Slider(
+                         minimum=0.1,
+                         maximum=1.0,
+                         value=0.7,
+                         step=0.1,
+                         label="Temperature"
+                     )
+                     top_p = gr.Slider(
+                         minimum=0.1,
+                         maximum=1.0,
+                         value=0.9,
+                         step=0.1,
+                         label="Top P"
+                     )
+
+                 submit_btn = gr.Button("Generate Response", variant="primary")
+
+             with gr.Column(scale=1):
+                 output = gr.Textbox(
+                     label="Model Response",
+                     lines=10,
+                     show_copy_button=True
+                 )
+
+         # Set up the submit action
+         submit_btn.click(
+             fn=process_image,
+             inputs=[image_input, prompt_input, max_tokens, temperature, top_p],
+             outputs=output
+         )
+
+         # Add examples
+         gr.Examples(
+             examples=[
+                 ["examples/cat.jpg", "What can you see in this image?"],
+                 ["examples/landscape.jpg", "Describe this scene in detail."],
+                 ["examples/food.jpg", "What kind of food is this and how would you describe it?"]
+             ],
+             inputs=[image_input, prompt_input]
+         )
+
+     return demo
+
+ # Create FastAPI app
+ app = FastAPI(title="LLaVA Web Interface")
+
+ # Configure CORS
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Create Gradio app
+ demo = create_interface()
+
+ # Mount Gradio app
+ app = gr.mount_gradio_app(app, demo, path="/")
+
+ if __name__ == "__main__":
+     # Initialize model
+     if initialize_model():
+         import uvicorn
+         uvicorn.run(app, host="0.0.0.0", port=7860)  # Hugging Face Spaces uses port 7860
+     else:
+         print("Failed to initialize model. Exiting...")
models/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .llava import LLaVA, MLP, StoppingCriteriaSub
+
+ __all__ = ['LLaVA', 'MLP', 'StoppingCriteriaSub']
models/llava.py ADDED
@@ -0,0 +1,312 @@
+ """
+ LLaVA: Large Language and Vision Assistant
+ Implementation based on the paper "Visual Instruction Tuning" (NeurIPS 2023)
+ """
+
+ import torch
+ import torch.nn as nn
+ from transformers import (
+     CLIPVisionModel,
+     CLIPImageProcessor,
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     StoppingCriteria,
+     StoppingCriteriaList
+ )
+ from PIL import Image
+ import os
+ from typing import List, Dict, Optional, Tuple, Union
+
+
+ class StoppingCriteriaSub(StoppingCriteria):
+     """Custom stopping criteria for text generation."""
+
+     def __init__(self, stops=None, encounters=1):
+         super().__init__()
+         self.stops = stops or []
+         self.encounters = encounters
+         self.counter = 0
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         """Check if stopping criteria are met."""
+         for stop_id in self.stops:
+             if stop_id in input_ids[0][-1:]:
+                 self.counter += 1
+                 if self.counter >= self.encounters:
+                     return True
+         return False
+
+
+ class MLP(nn.Module):
+     """MLP projection layer to connect vision and language models."""
+
+     def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout_rate: float = 0.1):
+         super().__init__()
+         self.fc1 = nn.Linear(input_dim, hidden_dim)
+         self.fc2 = nn.Linear(hidden_dim, output_dim)
+         self.act = nn.GELU()
+         self.dropout = nn.Dropout(dropout_rate)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass through the MLP."""
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.dropout(x)
+         x = self.fc2(x)
+         return x
+
+
+ class LLaVA(nn.Module):
+     """
+     LLaVA: Large Language and Vision Assistant
+     A multimodal model that connects a vision encoder with a language model.
+     """
+
+     def __init__(
+         self,
+         vision_model_path: str = "openai/clip-vit-large-patch14-336",
+         language_model_path: str = "lmsys/vicuna-7b-v1.5",
+         projection_hidden_dim: int = 4096,
+         device: str = None,
+         load_in_8bit: bool = False,
+         load_in_4bit: bool = False,
+     ):
+         """
+         Initialize the LLaVA model.
+
+         Args:
+             vision_model_path: Path or name of the vision model
+             language_model_path: Path or name of the language model
+             projection_hidden_dim: Hidden dimension of the projection layer
+             device: Device to load the model on ('cuda', 'cpu', etc.)
+             load_in_8bit: Whether to load the language model in 8-bit precision
+             load_in_4bit: Whether to load the language model in 4-bit precision
+         """
+         super().__init__()
+
+         self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+
+         # Load vision model
+         self.vision_model = CLIPVisionModel.from_pretrained(vision_model_path)
+         self.image_processor = CLIPImageProcessor.from_pretrained(vision_model_path)
+
+         # Load language model
+         kwargs = {}
+         if load_in_8bit:
+             kwargs['load_in_8bit'] = True
+         elif load_in_4bit:
+             kwargs['load_in_4bit'] = True
+             kwargs['bnb_4bit_compute_dtype'] = torch.float16
+
+         self.tokenizer = AutoTokenizer.from_pretrained(language_model_path)
+         self.language_model = AutoModelForCausalLM.from_pretrained(
+             language_model_path,
+             torch_dtype=torch.float16,
+             **kwargs
+         )
+
+         # Set padding token if not set
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         # Get dimensions
+         vision_hidden_size = self.vision_model.config.hidden_size
+         language_hidden_size = self.language_model.config.hidden_size
+
+         # Create projection layer
+         self.projection = MLP(
+             input_dim=vision_hidden_size,
+             hidden_dim=projection_hidden_dim,
+             output_dim=language_hidden_size
+         )
+
+         # Move models to device
+         self.vision_model.to(self.device)
+         self.language_model.to(self.device)
+         self.projection.to(self.device)
+
+         # Set to evaluation mode
+         self.vision_model.eval()
+         self.language_model.eval()
+         self.projection.eval()
+
+         # Template for conversation
+         self.conv_template = [
+             {"role": "system", "content": "You are a helpful assistant that can understand images and answer questions about them."},
+         ]
+
+     def encode_image(self, image_path: str) -> torch.Tensor:
+         """
+         Encode an image using the vision model.
+
+         Args:
+             image_path: Path to the image file
+
+         Returns:
+             Tensor containing the image features
+         """
+         image = Image.open(image_path).convert('RGB')
+         inputs = self.image_processor(images=image, return_tensors="pt").to(self.device)
+
+         with torch.no_grad():
+             outputs = self.vision_model(**inputs)
+             image_features = outputs.pooler_output  # [1, hidden_size]
+
+         return image_features
+
+     def project_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
+         """
+         Project image features to the language model's embedding space.
+
+         Args:
+             image_features: Image features from the vision model
+
+         Returns:
+             Projected image features
+         """
+         with torch.no_grad():
+             projected_features = self.projection(image_features)
+
+         return projected_features
+
+     def format_prompt(self, prompt: str, conversation: List[Dict[str, str]] = None) -> str:
+         """
+         Format the prompt for the language model.
+
+         Args:
+             prompt: The text prompt
+             conversation: Optional conversation history
+
+         Returns:
+             Formatted prompt string
+         """
+         if conversation is None:
+             conversation = self.conv_template.copy()
+
+         conversation.append({"role": "user", "content": prompt})
+
+         formatted_prompt = ""
+         for message in conversation:
+             if message["role"] == "system":
+                 formatted_prompt += f"<s>[INST] <<SYS>>\n{message['content']}\n<</SYS>>\n\n"
+             elif message["role"] == "user":
+                 if formatted_prompt:
+                     formatted_prompt += f"{message['content']} [/INST]"
+                 else:
+                     formatted_prompt += f"<s>[INST] {message['content']} [/INST]"
+             elif message["role"] == "assistant":
+                 formatted_prompt += f" {message['content']} </s><s>[INST] "
+
+         return formatted_prompt
+
+     def generate_from_image(
+         self,
+         image_path: str,
+         prompt: str,
+         max_new_tokens: int = 512,
+         temperature: float = 0.7,
+         top_p: float = 0.9,
+         conversation: List[Dict[str, str]] = None
+     ) -> str:
+         """
+         Generate text based on an image and a prompt.
+
+         Args:
+             image_path: Path to the image file
+             prompt: Text prompt
+             max_new_tokens: Maximum number of tokens to generate
+             temperature: Sampling temperature
+             top_p: Top-p sampling parameter
+             conversation: Optional conversation history
+
+         Returns:
+             Generated text response
+         """
+         # Encode image
+         image_features = self.encode_image(image_path)
+         projected_features = self.project_image_features(image_features)
+
+         # Format prompt
+         formatted_prompt = self.format_prompt(prompt, conversation)
+
+         # Tokenize prompt
+         inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)
+         input_ids = inputs.input_ids
+
+         # Prepare for generation
+         stopping_criteria = StoppingCriteriaList([
+             StoppingCriteriaSub(stops=[self.tokenizer.eos_token_id], encounters=1)
+         ])
+
+         # Generate response
+         with torch.no_grad():
+             # Prepare the inputs for the language model
+             # Here we would normally inject the image features into the language model
+             # This is a simplified version - in the actual LLaVA, this is done by modifying
+             # the language model's forward pass to accept image features
+
+             # For demonstration purposes, we'll just use the language model directly
+             outputs = self.language_model.generate(
+                 input_ids=input_ids,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 stopping_criteria=stopping_criteria,
+                 do_sample=True
+             )
+
+         # Decode the generated text
+         generated_text = self.tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
+
+         return generated_text.strip()
+
+     def save_model(self, output_dir: str):
+         """
+         Save the model to the specified directory.
+
+         Args:
+             output_dir: Directory to save the model
+         """
+         os.makedirs(output_dir, exist_ok=True)
+
+         # Save vision model
+         self.vision_model.save_pretrained(os.path.join(output_dir, "vision_model"))
+         self.image_processor.save_pretrained(os.path.join(output_dir, "vision_model"))
+
+         # Save language model
+         self.language_model.save_pretrained(os.path.join(output_dir, "language_model"))
+         self.tokenizer.save_pretrained(os.path.join(output_dir, "language_model"))
+
+         # Save projection layer
+         torch.save(self.projection.state_dict(), os.path.join(output_dir, "projection.pt"))
+
+     @classmethod
+     def from_pretrained(cls, model_path: str, device: str = None):
+         """
+         Load a pretrained LLaVA model.
+
+         Args:
+             model_path: Path to the saved model
+             device: Device to load the model on
+
+         Returns:
+             Loaded LLaVA model
+         """
+         # Load vision model
+         vision_model_path = os.path.join(model_path, "vision_model")
+
+         # Load language model
+         language_model_path = os.path.join(model_path, "language_model")
+
+         # Create model instance
+         model = cls(
+             vision_model_path=vision_model_path,
+             language_model_path=language_model_path,
+             device=device
+         )
+
+         # Load projection layer
+         projection_path = os.path.join(model_path, "projection.pt")
+         model.projection.load_state_dict(torch.load(projection_path, map_location=model.device))
+
+         return model
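For orientation, a hedged usage sketch of the class defined above (not part of the commit); the checkpoint downloads, available GPU memory, and the example image path are assumptions:

# Hypothetical usage of models/llava.py (not part of the commit).
import torch

from models.llava import LLaVA

model = LLaVA(
    vision_model_path="openai/clip-vit-large-patch14-336",
    language_model_path="lmsys/vicuna-7b-v1.5",
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Note: as the inline comments in generate_from_image() state, the projected
# image features are not yet injected into the language model, so the reply
# is conditioned on the text prompt only.
answer = model.generate_from_image(
    image_path="examples/cat.jpg",  # hypothetical path
    prompt="What can you see in this image?",
    max_new_tokens=128,
)
print(answer)

# Round-trip the checkpoint helpers defined on the class.
model.save_model("checkpoints/llava-demo")              # hypothetical output dir
restored = LLaVA.from_pretrained("checkpoints/llava-demo")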
requirements.txt ADDED
@@ -0,0 +1,26 @@
+ torch>=2.0.0
+ torchvision>=0.15.0
+ transformers>=4.30.0
+ accelerate>=0.20.0
+ pillow>=9.0.0
+ numpy>=1.24.0
+ tqdm>=4.65.0
+ matplotlib>=3.7.0
+ opencv-python>=4.7.0
+ einops>=0.6.0
+ timm>=0.9.0
+ sentencepiece>=0.1.99
+ gradio>=3.35.0
+ peft>=0.4.0
+ bitsandbytes>=0.40.0
+ safetensors>=0.3.1
+ fastapi==0.104.1
+ uvicorn==0.24.0
+ python-multipart==0.0.6
+ pydantic==2.5.2
+ python-jose==3.3.0
+ passlib==1.7.4
+ bcrypt==4.0.1
+ aiofiles==23.2.1
+ python-dotenv==1.0.0
+ httpx==0.25.2
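The commit message targets Gradio 4.44.1, while the pin above still allows any gradio>=3.35.0 and the README front matter requests sdk_version 3.50.2, so it can be worth confirming what actually resolves in the environment; a small, illustrative sketch:

# Illustrative sanity check of installed versions against requirements.txt.
from importlib.metadata import version, PackageNotFoundError

for pkg in ("gradio", "torch", "transformers", "fastapi", "uvicorn"):
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg} is not installed")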
utils/__init__.py ADDED
@@ -0,0 +1,37 @@
+ from .data_utils import (
+     load_image,
+     process_image,
+     pad_image,
+     load_conversation_data,
+     format_conversation,
+     create_image_text_pair
+ )
+
+ from .eval_utils import (
+     evaluate_vqa,
+     visualize_results,
+     compute_metrics
+ )
+
+ from .visualization import (
+     display_image_with_caption,
+     visualize_attention,
+     create_comparison_grid,
+     add_caption_to_image
+ )
+
+ __all__ = [
+     'load_image',
+     'process_image',
+     'pad_image',
+     'load_conversation_data',
+     'format_conversation',
+     'create_image_text_pair',
+     'evaluate_vqa',
+     'visualize_results',
+     'compute_metrics',
+     'display_image_with_caption',
+     'visualize_attention',
+     'create_comparison_grid',
+     'add_caption_to_image'
+ ]
utils/data_utils.py ADDED
@@ -0,0 +1,175 @@
+ """
+ Utility functions for data processing in LLaVA.
+ """
+
+ import os
+ import json
+ import torch
+ from PIL import Image
+ from typing import List, Dict, Tuple, Optional, Union
+ import numpy as np
+ from transformers import CLIPImageProcessor
+
+
+ def load_image(image_path: str) -> Image.Image:
+     """
+     Load an image from a file path.
+
+     Args:
+         image_path: Path to the image file
+
+     Returns:
+         PIL Image object
+     """
+     if not os.path.exists(image_path):
+         raise FileNotFoundError(f"Image file not found: {image_path}")
+
+     try:
+         image = Image.open(image_path).convert('RGB')
+         return image
+     except Exception as e:
+         raise ValueError(f"Error loading image: {e}")
+
+
+ def process_image(
+     image: Union[str, Image.Image],
+     image_processor: CLIPImageProcessor,
+     device: str = "cuda"
+ ) -> torch.Tensor:
+     """
+     Process an image for input to the vision model.
+
+     Args:
+         image: PIL Image object or path to image file
+         image_processor: CLIP image processor
+         device: Device to load the processed image on
+
+     Returns:
+         Processed image tensor
+     """
+     if isinstance(image, str):
+         image = load_image(image)
+
+     inputs = image_processor(images=image, return_tensors="pt").to(device)
+     return inputs
+
+
+ def pad_image(image: Image.Image, target_size: Tuple[int, int] = (336, 336)) -> Image.Image:
+     """
+     Pad an image to the target size while maintaining aspect ratio.
+
+     Args:
+         image: PIL Image object
+         target_size: Target size (width, height)
+
+     Returns:
+         Padded image
+     """
+     width, height = image.size
+     target_width, target_height = target_size
+
+     # Calculate padding
+     ratio = min(target_width / width, target_height / height)
+     new_width = int(width * ratio)
+     new_height = int(height * ratio)
+
+     # Resize image
+     resized_image = image.resize((new_width, new_height), Image.LANCZOS)
+
+     # Create new image with padding
+     new_image = Image.new("RGB", target_size, (0, 0, 0))
+     paste_x = (target_width - new_width) // 2
+     paste_y = (target_height - new_height) // 2
+     new_image.paste(resized_image, (paste_x, paste_y))
+
+     return new_image
+
+
+ def load_conversation_data(json_path: str) -> List[Dict]:
+     """
+     Load conversation data from a JSON file.
+
+     Args:
+         json_path: Path to the JSON file
+
+     Returns:
+         List of conversation dictionaries
+     """
+     if not os.path.exists(json_path):
+         raise FileNotFoundError(f"JSON file not found: {json_path}")
+
+     try:
+         with open(json_path, 'r', encoding='utf-8') as f:
+             data = json.load(f)
+         return data
+     except Exception as e:
+         raise ValueError(f"Error loading JSON data: {e}")
+
+
+ def format_conversation(
+     conversation: List[Dict[str, str]],
+     system_prompt: Optional[str] = None
+ ) -> List[Dict[str, str]]:
+     """
+     Format a conversation for the LLaVA model.
+
+     Args:
+         conversation: List of conversation messages
+         system_prompt: Optional system prompt to prepend
+
+     Returns:
+         Formatted conversation
+     """
+     formatted_conv = []
+
+     # Add system prompt if provided
+     if system_prompt:
+         formatted_conv.append({"role": "system", "content": system_prompt})
+
+     # Add conversation messages
+     for message in conversation:
+         if "role" in message and "content" in message:
+             formatted_conv.append({
+                 "role": message["role"],
+                 "content": message["content"]
+             })
+
+     return formatted_conv
+
+
+ def create_image_text_pair(
+     image_path: str,
+     text: str,
+     image_processor: CLIPImageProcessor,
+     tokenizer,
+     max_length: int = 512,
+     device: str = "cuda"
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Create an image-text pair for training or inference.
+
+     Args:
+         image_path: Path to the image file
+         text: Text prompt
+         image_processor: CLIP image processor
+         tokenizer: Language model tokenizer
+         max_length: Maximum text length
+         device: Device to load tensors on
+
+     Returns:
+         Tuple of (image_tensor, text_tensor)
+     """
+     # Process image
+     image = load_image(image_path)
+     image_inputs = image_processor(images=image, return_tensors="pt").to(device)
+
+     # Process text
+     text_inputs = tokenizer(
+         text,
+         return_tensors="pt",
+         padding="max_length",
+         max_length=max_length,
+         truncation=True
+     ).to(device)
+
+     return image_inputs, text_inputs
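A brief, hypothetical exercise of the pure-PIL helpers above (the image path is an assumption):

# Hypothetical example for utils/data_utils.py (not part of the commit).
from utils.data_utils import load_image, pad_image, format_conversation

image = load_image("examples/cat.jpg")             # hypothetical path
padded = pad_image(image, target_size=(336, 336))  # letterboxed to the CLIP input size

conv = format_conversation(
    [{"role": "user", "content": "What can you see in this image?"}],
    system_prompt="You are a helpful assistant that can understand images.",
)
print(padded.size, conv)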
utils/eval_utils.py ADDED
@@ -0,0 +1,180 @@
+ """
+ Utility functions for evaluating LLaVA models.
+ """
+
+ import os
+ import json
+ import torch
+ import numpy as np
+ from typing import List, Dict, Tuple, Optional, Union
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ from tqdm import tqdm
+
+
+ def evaluate_vqa(
+     model,
+     questions_file: str,
+     image_folder: str,
+     output_file: Optional[str] = None,
+     max_new_tokens: int = 100
+ ) -> Dict:
+     """
+     Evaluate the model on visual question answering.
+
+     Args:
+         model: LLaVA model
+         questions_file: Path to the questions JSON file
+         image_folder: Path to the folder containing images
+         output_file: Optional path to save results
+         max_new_tokens: Maximum number of tokens to generate
+
+     Returns:
+         Dictionary with evaluation results
+     """
+     # Load questions
+     with open(questions_file, 'r', encoding='utf-8') as f:
+         questions = json.load(f)
+
+     results = []
+
+     # Process each question
+     for q in tqdm(questions, desc="Evaluating VQA"):
+         image_path = os.path.join(image_folder, q['image'])
+         question_text = q['question']
+
+         # Generate answer
+         try:
+             answer = model.generate_from_image(
+                 image_path=image_path,
+                 prompt=question_text,
+                 max_new_tokens=max_new_tokens
+             )
+
+             result = {
+                 'question_id': q.get('question_id', None),
+                 'image': q['image'],
+                 'question': question_text,
+                 'answer': answer,
+                 'gt_answer': q.get('answer', None)
+             }
+
+             results.append(result)
+         except Exception as e:
+             print(f"Error processing question {q.get('question_id', '')}: {e}")
+
+     # Save results if output file is provided
+     if output_file:
+         with open(output_file, 'w', encoding='utf-8') as f:
+             json.dump(results, f, indent=2)
+
+     # Calculate accuracy if ground truth answers are available
+     accuracy = None
+     if all('gt_answer' in r and r['gt_answer'] is not None for r in results):
+         correct = 0
+         for r in results:
+             # Simple exact match accuracy
+             if r['answer'].lower() == r['gt_answer'].lower():
+                 correct += 1
+
+         accuracy = correct / len(results) if results else 0
+
+     return {
+         'results': results,
+         'accuracy': accuracy,
+         'num_questions': len(results)
+     }
+
+
+ def visualize_results(
+     results: List[Dict],
+     num_examples: int = 5,
+     figsize: Tuple[int, int] = (15, 10),
+     image_folder: str = None
+ ) -> None:
+     """
+     Visualize VQA results.
+
+     Args:
+         results: List of result dictionaries
+         num_examples: Number of examples to visualize
+         figsize: Figure size
+         image_folder: Path to the folder containing images
+     """
+     # Select a subset of results
+     if len(results) > num_examples:
+         indices = np.random.choice(len(results), num_examples, replace=False)
+         selected_results = [results[i] for i in indices]
+     else:
+         selected_results = results
+
+     # Create figure
+     fig, axes = plt.subplots(len(selected_results), 1, figsize=figsize)
+     if len(selected_results) == 1:
+         axes = [axes]
+
+     # Plot each example
+     for i, result in enumerate(selected_results):
+         # Load image
+         if image_folder:
+             image_path = os.path.join(image_folder, result['image'])
+             img = Image.open(image_path).convert('RGB')
+             axes[i].imshow(img)
+
+         # Set title and text
+         title = f"Q: {result['question']}"
+         text = f"A: {result['answer']}"
+         if 'gt_answer' in result and result['gt_answer']:
+             text += f"\nGT: {result['gt_answer']}"
+
+         axes[i].set_title(title)
+         axes[i].text(0, -0.5, text, transform=axes[i].transAxes, fontsize=12)
+         axes[i].axis('off')
+
+     plt.tight_layout()
+     plt.show()
+
+
+ def compute_metrics(results: List[Dict]) -> Dict:
+     """
+     Compute evaluation metrics.
+
+     Args:
+         results: List of result dictionaries
+
+     Returns:
+         Dictionary with metrics
+     """
+     metrics = {}
+
+     # Check if ground truth answers are available
+     has_gt = all('gt_answer' in r and r['gt_answer'] is not None for r in results)
+
+     if has_gt:
+         # Exact match accuracy
+         correct = 0
+         for r in results:
+             if r['answer'].lower() == r['gt_answer'].lower():
+                 correct += 1
+
+         metrics['exact_match_accuracy'] = correct / len(results) if results else 0
+
+         # Token overlap (simple BLEU-like metric)
+         total_overlap = 0
+         for r in results:
+             pred_tokens = set(r['answer'].lower().split())
+             gt_tokens = set(r['gt_answer'].lower().split())
+
+             if gt_tokens:  # Avoid division by zero
+                 overlap = len(pred_tokens.intersection(gt_tokens)) / len(gt_tokens)
+                 total_overlap += overlap
+
+         metrics['token_overlap'] = total_overlap / len(results) if results else 0
+
+     # Response length statistics
+     lengths = [len(r['answer'].split()) for r in results]
+     metrics['avg_response_length'] = sum(lengths) / len(lengths) if lengths else 0
+     metrics['min_response_length'] = min(lengths) if lengths else 0
+     metrics['max_response_length'] = max(lengths) if lengths else 0
+
+     return metrics
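evaluate_vqa reads a JSON list of records with 'image' and 'question' keys plus optional 'question_id' and ground-truth 'answer'. A hedged sketch of preparing such a file and running an evaluation follows; the paths and the ability to instantiate the model with its default checkpoints are assumptions:

# Hypothetical VQA evaluation run for utils/eval_utils.py (not part of the commit).
import json

from models.llava import LLaVA
from utils.eval_utils import evaluate_vqa, compute_metrics

# Question schema inferred from evaluate_vqa(): 'image', 'question',
# optional 'question_id' and ground-truth 'answer'.
questions = [
    {"question_id": 1, "image": "cat.jpg", "question": "What animal is this?", "answer": "a cat"},
]
with open("questions.json", "w", encoding="utf-8") as f:
    json.dump(questions, f)

model = LLaVA()  # assumes the default checkpoints are available
report = evaluate_vqa(
    model,
    questions_file="questions.json",
    image_folder="images",          # hypothetical folder containing cat.jpg
    output_file="vqa_results.json",
    max_new_tokens=64,
)
print(report["accuracy"], compute_metrics(report["results"]))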
utils/visualization.py ADDED
@@ -0,0 +1,178 @@
+ """
+ Visualization utilities for LLaVA.
+ """
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from PIL import Image, ImageDraw, ImageFont
+ import torch
+ from typing import List, Dict, Tuple, Optional, Union
+ import cv2
+
+
+ def display_image_with_caption(
+     image_path: str,
+     caption: str,
+     figsize: Tuple[int, int] = (10, 10)
+ ) -> None:
+     """
+     Display an image with a caption.
+
+     Args:
+         image_path: Path to the image file
+         caption: Caption text
+         figsize: Figure size
+     """
+     image = Image.open(image_path).convert('RGB')
+
+     plt.figure(figsize=figsize)
+     plt.imshow(image)
+     plt.axis('off')
+     plt.title(caption)
+     plt.tight_layout()
+     plt.show()
+
+
+ def visualize_attention(
+     image_path: str,
+     attention_weights: torch.Tensor,
+     figsize: Tuple[int, int] = (15, 5)
+ ) -> None:
+     """
+     Visualize attention weights on an image.
+
+     Args:
+         image_path: Path to the image file
+         attention_weights: Attention weights tensor
+         figsize: Figure size
+     """
+     # Load image
+     image = Image.open(image_path).convert('RGB')
+     image_np = np.array(image)
+
+     # Normalize attention weights
+     if attention_weights.dim() > 2:
+         # Average across heads and layers if necessary
+         attention_weights = attention_weights.mean(dim=(0, 1))
+
+     attention_weights = attention_weights.detach().cpu().numpy()
+     attention_weights = (attention_weights - attention_weights.min()) / (attention_weights.max() - attention_weights.min())
+
+     # Resize attention map to image size
+     attention_map = cv2.resize(attention_weights, (image.width, image.height))
+
+     # Create heatmap
+     heatmap = cv2.applyColorMap(np.uint8(255 * attention_map), cv2.COLORMAP_JET)
+     heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
+
+     # Overlay heatmap on image
+     alpha = 0.5
+     overlay = heatmap * alpha + image_np * (1 - alpha)
+     overlay = overlay.astype(np.uint8)
+
+     # Display original image and attention overlay
+     fig, axes = plt.subplots(1, 3, figsize=figsize)
+
+     axes[0].imshow(image_np)
+     axes[0].set_title('Original Image')
+     axes[0].axis('off')
+
+     axes[1].imshow(heatmap)
+     axes[1].set_title('Attention Map')
+     axes[1].axis('off')
+
+     axes[2].imshow(overlay)
+     axes[2].set_title('Overlay')
+     axes[2].axis('off')
+
+     plt.tight_layout()
+     plt.show()
+
+
+ def create_comparison_grid(
+     image_path: str,
+     responses: List[Dict[str, str]],
+     output_path: Optional[str] = None,
+     figsize: Tuple[int, int] = (12, 10)
+ ) -> None:
+     """
+     Create a comparison grid of different model responses.
+
+     Args:
+         image_path: Path to the image file
+         responses: List of dictionaries with 'model' and 'response' keys
+         output_path: Optional path to save the figure
+         figsize: Figure size
+     """
+     # Load image
+     image = Image.open(image_path).convert('RGB')
+
+     # Create figure
+     fig = plt.figure(figsize=figsize)
+
+     # Add image
+     ax1 = plt.subplot2grid((len(responses) + 1, 3), (0, 0), colspan=3)
+     ax1.imshow(image)
+     ax1.set_title('Input Image')
+     ax1.axis('off')
+
+     # Add responses
+     for i, resp in enumerate(responses):
+         ax = plt.subplot2grid((len(responses) + 1, 3), (i + 1, 0), colspan=3)
+         ax.text(0.5, 0.5, f"{resp['model']}: {resp['response']}",
+                 wrap=True, horizontalalignment='center',
+                 verticalalignment='center', fontsize=10)
+         ax.axis('off')
+
+     plt.tight_layout()
+
+     # Save figure if output path is provided
+     if output_path:
+         plt.savefig(output_path, bbox_inches='tight')
+
+     plt.show()
+
+
+ def add_caption_to_image(
+     image_path: str,
+     caption: str,
+     output_path: str,
+     font_size: int = 20,
+     font_color: Tuple[int, int, int] = (255, 255, 255),
+     bg_color: Tuple[int, int, int] = (0, 0, 0)
+ ) -> None:
+     """
+     Add a caption to an image and save it.
+
+     Args:
+         image_path: Path to the input image
+         caption: Caption text
+         output_path: Path to save the output image
+         font_size: Font size
+         font_color: Font color (RGB)
+         bg_color: Background color (RGB)
+     """
+     # Load image
+     image = Image.open(image_path).convert('RGB')
+
+     # Create a new image with space for the caption
+     caption_height = font_size + 20  # Add some padding
+     new_image = Image.new('RGB', (image.width, image.height + caption_height), bg_color)
+     new_image.paste(image, (0, 0))
+
+     # Add caption
+     draw = ImageDraw.Draw(new_image)
+     try:
+         font = ImageFont.truetype("arial.ttf", font_size)
+     except IOError:
+         font = ImageFont.load_default()
+
+     # Calculate text position
+     text_width = draw.textlength(caption, font=font)
+     text_position = ((image.width - text_width) // 2, image.height + 10)
+
+     # Draw text
+     draw.text(text_position, caption, font=font, fill=font_color)
+
+     # Save image
+     new_image.save(output_path)
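Finally, a hypothetical example of the drawing helpers above (input and output paths are assumptions, and the output directory must already exist):

# Hypothetical example for utils/visualization.py (not part of the commit).
from utils.visualization import add_caption_to_image, create_comparison_grid

add_caption_to_image(
    image_path="examples/cat.jpg",          # hypothetical input
    caption="A cat sitting on a windowsill",
    output_path="outputs/cat_captioned.jpg",
)

create_comparison_grid(
    image_path="examples/cat.jpg",
    responses=[
        {"model": "LLaVA", "response": "A cat sitting on a windowsill."},
        {"model": "Baseline", "response": "An animal indoors."},
    ],
    output_path="outputs/comparison.png",
)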