yakilee committed on
Commit 9351df3 · verified · 1 Parent(s): dba9150

Update app.py

Files changed (1)
  1. app.py +57 -117
app.py CHANGED
@@ -8,183 +8,123 @@ import os
  import torch
  import gc
 
- # Set PyTorch memory allocation configuration
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"
-
  DESCRIPTION = "[Sparrow Qwen2-VL-2B Backend](https://github.com/katanaml/sparrow)"
 
- def process_image(image_filepath, max_width=800, max_height=1000):
      if image_filepath is None:
-         raise ValueError("No image provided. Please upload an image before submitting.")
 
      img = Image.open(image_filepath)
      width, height = img.size
 
-     # Calculate new dimensions while maintaining aspect ratio
-     if width > max_width or height > max_height:
-         aspect_ratio = width / height
-         if width > max_width:
-             new_width = max_width
-             new_height = int(new_width / aspect_ratio)
-         if new_height > max_height:
-             new_height = max_height
-             new_width = int(new_height * aspect_ratio)
      else:
-         new_width, new_height = width, height
-
-     # Resize the image if needed
-     if new_width != width or new_height != height:
-         img = img.resize((new_width, new_height), Image.LANCZOS)
 
-     # Generate temporary filename - use /tmp folder for better space management
      timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-     filename = f"/tmp/image_{timestamp}.jpg"  # Use jpg for smaller file size
-
-     # Save with optimized compression
-     img.save(filename, format='JPEG', quality=85, optimize=True)
 
      return os.path.abspath(filename), new_width, new_height
 
- # Initialize model with memory optimizations but without 4-bit quantization
  model = None
  processor = None
 
  def load_model():
-     # Load model with memory optimizations
      model = Qwen2VLForConditionalGeneration.from_pretrained(
          "Qwen/Qwen2-VL-2B-Instruct",
-         torch_dtype=torch.float16,  # Use fp16 for memory efficiency
          device_map="auto",
-         attn_implementation="flash_attention_2"  # Use FlashAttention if available
      )
-
      processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
-     return model, processor
 
  @spaces.GPU
  def run_inference(input_imgs, text_input):
      global model, processor
-
-     # Lazy load model
-     if model is None or processor is None:
-         model, processor = load_model()
 
      results = []
 
-     # Process images one at a time to avoid OOM issues
      for image in input_imgs:
-         # Clear cache before processing each image
          torch.cuda.empty_cache()
          gc.collect()
 
-         # Process image with reduced dimensions
          image_path, width, height = process_image(image)
 
          try:
-             # Create messages with optimized image
-             messages = [
-                 {
-                     "role": "user",
-                     "content": [
-                         {
-                             "type": "image",
-                             "image": image_path,
-                             "resized_height": height,
-                             "resized_width": width
-                         },
-                         {
-                             "type": "text",
-                             "text": text_input
-                         }
-                     ]
-                 }
-             ]
 
-             # Prepare inputs with memory optimization
              text = processor.apply_chat_template(
                  messages, tokenize=False, add_generation_prompt=True
             )
 
-             image_inputs, video_inputs = process_vision_info(messages)
-
-             # Clear unused memory
-             del messages
-             torch.cuda.empty_cache()
-
-             # Process inputs with truncation to control memory usage
             inputs = processor(
                  text=[text],
-                 images=image_inputs,
-                 videos=video_inputs,
                  padding=True,
-                 truncation=True,  # Add truncation
-                 max_length=768,  # Limit context length
                  return_tensors="pt",
-             )
 
-             # Move to GPU efficiently
-             inputs = {k: v.to("cuda") for k, v in inputs.items()}
-
-             # Clean up variables to free memory
-             del text, image_inputs, video_inputs
-             torch.cuda.empty_cache()
-
-             # Generate with optimized parameters
-             with torch.inference_mode():  # More efficient than no_grad
                  generated_ids = model.generate(
-                     **inputs,
-                     max_new_tokens=1024,  # Reduced from 4096
-                     do_sample=False,  # Deterministic generation uses less memory
-                     use_cache=True,  # Use KV cache
-                     num_beams=1  # Disable beam search to save memory
                  )
-
-             # Process output efficiently
-             generated_ids_trimmed = [
-                 out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
-             ]
 
-             raw_output = processor.batch_decode(
-                 generated_ids_trimmed, skip_special_tokens=True
-             )
 
-             results.append(raw_output[0])
-             print(f"Processed: {image_path}")
 
-             # Clear tensors from GPU memory
-             del inputs, generated_ids, generated_ids_trimmed
              torch.cuda.empty_cache()
              gc.collect()
 
          finally:
-             # Clean up temporary files
              if os.path.exists(image_path):
                  os.remove(image_path)
 
      return results
 
- # Gradio interface
- css = """
- #output {
-     height: 500px;
-     overflow: auto;
-     border: 1px solid #ccc;
- }
- """
-
- with gr.Blocks(css=css) as demo:
      gr.Markdown(DESCRIPTION)
-     with gr.Tab(label="Qwen2-VL-2B Input"):
-         with gr.Row():
-             with gr.Column():
-                 input_imgs = gr.Files(file_types=["image"], label="Upload Document Images")
-                 text_input = gr.Textbox(label="Query")
-                 submit_btn = gr.Button(value="Submit", variant="primary")
-             with gr.Column():
-                 output_text = gr.Textbox(label="Response")
 
-     submit_btn.click(run_inference, [input_imgs, text_input], [output_text])
 
- # Use smaller queue size to manage memory
- demo.queue(api_open=True, max_size=3)
- demo.launch(debug=True)
  import torch
  import gc
 
+ # Configure memory settings
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"
 
  DESCRIPTION = "[Sparrow Qwen2-VL-2B Backend](https://github.com/katanaml/sparrow)"
 
+ def process_image(image_filepath, max_width=640, max_height=800):
      if image_filepath is None:
+         raise ValueError("No image provided")
 
      img = Image.open(image_filepath)
      width, height = img.size
 
+     # Enhanced resizing with aspect ratio preservation
+     aspect_ratio = width / height
+     if aspect_ratio > (max_width / max_height):
+         new_width = max_width
+         new_height = int(max_width / aspect_ratio)
      else:
+         new_height = max_height
+         new_width = int(max_height * aspect_ratio)
 
+     img = img.resize((new_width, new_height), Image.LANCZOS)
      timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     filename = f"/tmp/image_{timestamp}.jpg"
+     img.save(filename, format='JPEG', quality=75, optimize=True)
 
      return os.path.abspath(filename), new_width, new_height
 
+ # Model initialization with memory optimizations
  model = None
  processor = None
 
  def load_model():
+     global model, processor
      model = Qwen2VLForConditionalGeneration.from_pretrained(
          "Qwen/Qwen2-VL-2B-Instruct",
+         torch_dtype=torch.float16,
          device_map="auto",
+         low_cpu_mem_usage=True
      )
      processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
 
  @spaces.GPU
  def run_inference(input_imgs, text_input):
      global model, processor
+     if model is None:
+         load_model()
 
      results = []
 
      for image in input_imgs:
          torch.cuda.empty_cache()
          gc.collect()
 
          image_path, width, height = process_image(image)
 
          try:
+             messages = [{
+                 "role": "user",
+                 "content": [
+                     {"type": "image", "image": image_path},
+                     {"type": "text", "text": text_input}
+                 ]
+             }]
 
              text = processor.apply_chat_template(
                  messages, tokenize=False, add_generation_prompt=True
              )
 
+             # Process inputs in chunks
              inputs = processor(
                  text=[text],
+                 images=[Image.open(image_path)],
                  padding=True,
+                 truncation=True,
+                 max_length=512,
                  return_tensors="pt",
+             ).to("cuda")
 
+             # Memory-efficient generation
+             with torch.inference_mode():
                  generated_ids = model.generate(
+                     **inputs,
+                     max_new_tokens=512,
+                     do_sample=False,
+                     num_beams=1,
+                     early_stopping=True
                  )
 
+             # Clean output processing
+             output = processor.batch_decode(
+                 generated_ids[:, inputs.input_ids.shape[1]:],
+                 skip_special_tokens=True
+             )[0]
 
+             results.append(output)
 
+             # Force memory cleanup
+             del inputs, generated_ids
              torch.cuda.empty_cache()
              gc.collect()
 
          finally:
              if os.path.exists(image_path):
                  os.remove(image_path)
 
      return results
 
+ # Streamlined interface
+ with gr.Blocks() as demo:
      gr.Markdown(DESCRIPTION)
+     with gr.Row():
+         input_imgs = gr.Files(file_types=["image"], label="Upload Images")
+         text_input = gr.Textbox(label="Query")
+         submit_btn = gr.Button("Submit", variant="primary")
+         output_text = gr.Textbox(label="Response", elem_id="output")
 
+     submit_btn.click(run_inference, [input_imgs, text_input], output_text)
 
+ demo.queue(max_size=1)
+ demo.launch()
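
For reference, a minimal sketch of the resize arithmetic introduced in the updated `process_image`. The 1600×1200 input below is a made-up example value; the new defaults (`max_width=640`, `max_height=800`) and the formula are taken from the diff above. Note that the new code always resizes to fit within the bounds while preserving aspect ratio.

```python
# Hypothetical input dimensions; formula mirrors the updated process_image.
width, height = 1600, 1200
max_width, max_height = 640, 800

aspect_ratio = width / height                    # 1.333...
if aspect_ratio > (max_width / max_height):      # 1.333 > 0.8: width is the binding constraint
    new_width = max_width                        # 640
    new_height = int(max_width / aspect_ratio)   # 480
else:                                            # taller-than-bound images are height-limited
    new_height = max_height
    new_width = int(max_height * aspect_ratio)

print(new_width, new_height)  # 640 480
```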