Darius Morawiec committed on
Commit
73b837f
·
1 Parent(s): 251f917

Add image resizing option and refactor GPU duration handling

Browse files
Files changed (1) hide show
  1. app.py +25 -7
app.py CHANGED
@@ -33,6 +33,7 @@ else:
33
 
34
 
35
  # Define constants
 
36
  EXAMPLES_DIR = Path(__file__).parent / "examples"
37
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
38
  MODEL_IDS = [
@@ -49,7 +50,7 @@ MODEL_IDS = [
49
  ]
50
 
51
 
52
- def scale_image(image, target_size=1000):
53
  width, height = image.size
54
  if max(width, height) <= target_size:
55
  return image
@@ -112,6 +113,15 @@ with gr.Blocks() as demo:
112
  step=32,
113
  interactive=True,
114
  )
 
 
 
 
 
 
 
 
 
115
  image_target_size = gr.Slider(
116
  label="Image Target Size",
117
  minimum=256,
@@ -119,6 +129,7 @@ with gr.Blocks() as demo:
119
  value=1024,
120
  step=1,
121
  interactive=True,
 
122
  )
123
 
124
  with gr.Column():
@@ -192,12 +203,15 @@ with gr.Blocks() as demo:
192
  system_prompt: str,
193
  user_prompt: str,
194
  max_new_tokens: int = 1024,
 
195
  image_target_size: int | None = None,
196
  ):
197
  model, processor = load_model(model_id)
198
 
199
  base64_image = image_to_base64(
200
- scale_image(image, image_target_size) if image_target_size else image
 
 
201
  )
202
  messages = [
203
  {
@@ -228,11 +242,8 @@ with gr.Blocks() as demo:
228
  )
229
  inputs = inputs.to(DEVICE)
230
 
231
- @spaces.GPU(duration=300)
232
- def _generate(**kwargs):
233
- return model.generate(**kwargs)
234
-
235
- generated_ids = _generate(**inputs, max_new_tokens=max_new_tokens)
236
  generated_ids_trimmed = [
237
  out_ids[len(in_ids) :]
238
  for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
@@ -288,6 +299,7 @@ with gr.Blocks() as demo:
288
  default_system_prompt,
289
  "detect sailboat, rowboat, person",
290
  512,
 
291
  1920,
292
  ],
293
  [
@@ -296,6 +308,7 @@ with gr.Blocks() as demo:
296
  default_system_prompt,
297
  "detect shirt, jeans, jacket, skirt, sunglasses, earring, drink",
298
  1024,
 
299
  1920,
300
  ],
301
  [
@@ -304,6 +317,7 @@ with gr.Blocks() as demo:
304
  default_system_prompt,
305
  "detect basketball, player with white jersey, player with black jersey",
306
  512,
 
307
  1920,
308
  ],
309
  [
@@ -312,6 +326,7 @@ with gr.Blocks() as demo:
312
  default_system_prompt,
313
  "detect app to find great places, app to take beautiful photos, app to listen music",
314
  512,
 
315
  1920,
316
  ],
317
  [
@@ -320,6 +335,7 @@ with gr.Blocks() as demo:
320
  default_system_prompt,
321
  "detect person, bicycle, netherlands flag",
322
  1920,
 
323
  1920,
324
  ],
325
  ],
@@ -329,6 +345,7 @@ with gr.Blocks() as demo:
329
  system_prompt,
330
  user_prompt,
331
  max_new_tokens,
 
332
  image_target_size,
333
  ],
334
  outputs=[
@@ -351,6 +368,7 @@ with gr.Blocks() as demo:
351
  system_prompt,
352
  user_prompt,
353
  max_new_tokens,
 
354
  image_target_size,
355
  ],
356
  outputs=[
 
33
 
34
 
35
  # Define constants
36
+ GPU_DURATION = 300
37
  EXAMPLES_DIR = Path(__file__).parent / "examples"
38
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
39
  MODEL_IDS = [
 
50
  ]
51
 
52
 
53
+ def resize_image(image, target_size=1000):
54
  width, height = image.size
55
  if max(width, height) <= target_size:
56
  return image
 
113
  step=32,
114
  interactive=True,
115
  )
116
+
117
+ image_resize = gr.Radio(
118
+ label="Resize Image",
119
+ choices=["Yes", "No"],
120
+ value="Yes",
121
+ interactive=True,
122
+ scale=2,
123
+ )
124
+
125
  image_target_size = gr.Slider(
126
  label="Image Target Size",
127
  minimum=256,
 
129
  value=1024,
130
  step=1,
131
  interactive=True,
132
+ scale=2,
133
  )
134
 
135
  with gr.Column():
 
203
  system_prompt: str,
204
  user_prompt: str,
205
  max_new_tokens: int = 1024,
206
+ image_resize: str = "Yes",
207
  image_target_size: int | None = None,
208
  ):
209
  model, processor = load_model(model_id)
210
 
211
  base64_image = image_to_base64(
212
+ resize_image(image, image_target_size)
213
+ if image_resize == "Yes" and image_target_size
214
+ else image
215
  )
216
  messages = [
217
  {
 
242
  )
243
  inputs = inputs.to(DEVICE)
244
 
245
+ generate = spaces.GPU(model.generate, duration=GPU_DURATION)
246
+ generated_ids = generate(**inputs, max_new_tokens=max_new_tokens)
 
 
 
247
  generated_ids_trimmed = [
248
  out_ids[len(in_ids) :]
249
  for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
 
299
  default_system_prompt,
300
  "detect sailboat, rowboat, person",
301
  512,
302
+ "Yes",
303
  1920,
304
  ],
305
  [
 
308
  default_system_prompt,
309
  "detect shirt, jeans, jacket, skirt, sunglasses, earring, drink",
310
  1024,
311
+ "Yes",
312
  1920,
313
  ],
314
  [
 
317
  default_system_prompt,
318
  "detect basketball, player with white jersey, player with black jersey",
319
  512,
320
+ "Yes",
321
  1920,
322
  ],
323
  [
 
326
  default_system_prompt,
327
  "detect app to find great places, app to take beautiful photos, app to listen music",
328
  512,
329
+ "Yes",
330
  1920,
331
  ],
332
  [
 
335
  default_system_prompt,
336
  "detect person, bicycle, netherlands flag",
337
  1920,
338
+ "Yes",
339
  1920,
340
  ],
341
  ],
 
345
  system_prompt,
346
  user_prompt,
347
  max_new_tokens,
348
+ image_resize,
349
  image_target_size,
350
  ],
351
  outputs=[
 
368
  system_prompt,
369
  user_prompt,
370
  max_new_tokens,
371
+ image_resize,
372
  image_target_size,
373
  ],
374
  outputs=[