ariG23498 HF Staff committed on
Commit
f87f007
·
verified ·
1 Parent(s): c96e867

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -8
app.py CHANGED
@@ -1,14 +1,129 @@
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import spaces
3
- import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- zero = torch.Tensor([0]).cuda()
6
- print(zero.device) # <-- 'cpu' 🤔
 
 
 
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  @spaces.GPU
9
- def greet(n):
10
- print(zero.device) # <-- 'cuda:0' 🤗
11
- return f"Hello {zero + n} Tensor"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
14
- demo.launch()
 
1
+ import os
2
+ import re
3
+ import random
4
+ from dataclasses import dataclass
5
+ from functools import partial
6
+
7
+ import torch
8
  import gradio as gr
9
  import spaces
10
+ from datasets import load_dataset
11
+ from torch.utils.data import DataLoader
12
+ from transformers import AutoProcessor, Gemma3ForConditionalGeneration
13
+ from PIL import Image, ImageDraw
14
+
15
+
16
+ # --- Configuration ---
17
@dataclass
class Configuration:
    """Settings bundle for the Gemma 3 object-detection demo.

    Every field carries a default, so ``Configuration()`` can be built
    without arguments.
    """

    # Hub identifiers. Within this file only dataset_id and checkpoint_id are
    # read; model_id is not referenced by the visible code.
    dataset_id: str = "ariG23498/license-detection-paligemma"
    model_id: str = "google/gemma-3-4b-pt"
    checkpoint_id: str = "ariG23498/gemma-3-4b-pt-object-detection"

    # The device default is evaluated once, when this class body executes,
    # not each time an instance is created.
    device: str = "cpu" if not torch.cuda.is_available() else "cuda"
    dtype: torch.dtype = torch.bfloat16

    # Not read by the visible code; names suggest training hyperparameters
    # (TODO confirm against the training script).
    batch_size: int = 4
    learning_rate: float = 2e-5
    epochs: int = 1
27
+
28
+
29
+ # --- Utils ---
30
def parse_paligemma_label(label, width, height):
    """Convert a PaliGemma-style detection string into a pixel-space box.

    The label must contain exactly four ``<locNNNN>`` tokens (four digits
    each) in the order y1, x1, y2, x2, followed by the category text, for
    example ``"<loc0256><loc0512><loc0768><loc1024> plate"``. Each location
    value is divided by 1024 and scaled by the matching image dimension.

    Args:
        label: Text holding the location tokens and the category.
        width: Image width used to scale the x coordinates.
        height: Image height used to scale the y coordinates.

    Returns:
        A tuple ``(category, [x1, y1, x2, y2])`` where the box values are
        floats in image coordinates.

    Raises:
        ValueError: If ``label`` does not contain exactly four location
            tokens (for example when the model produced no detection).
    """
    # Extract location codes
    loc_pattern = r"<loc(\d{4})>"
    locations = [int(loc) for loc in re.findall(loc_pattern, label)]

    # Fail with a message that shows the offending text instead of the opaque
    # "not enough/too many values to unpack" raised by a bare unpack.
    if len(locations) != 4:
        raise ValueError(
            f"Expected exactly 4 <locNNNN> tokens in label, "
            f"found {len(locations)}: {label!r}"
        )

    # Extract category (everything after the last location code)
    category = label.split(">")[-1].strip()

    # Order in PaliGemma format is: y1, x1, y2, x2
    y1_norm, x1_norm, y2_norm, x2_norm = locations

    # Convert normalized coordinates to image coordinates
    x1 = (x1_norm / 1024) * width
    y1 = (y1_norm / 1024) * height
    x2 = (x2_norm / 1024) * width
    y2 = (y2_norm / 1024) * height

    return category, [x1, y1, x2, y2]
48
+
49
+
50
def visualize_bounding_boxes(image, label, width, height):
    """Return a copy of ``image`` annotated with the box described by ``label``.

    The input image is not modified. ``label``, ``width`` and ``height`` are
    forwarded to ``parse_paligemma_label`` to obtain the category and the
    ``[x1, y1, x2, y2]`` box.
    """
    # Draw on a duplicate so the caller's image stays untouched.
    annotated = image.copy()
    canvas = ImageDraw.Draw(annotated)

    category, bbox = parse_paligemma_label(label, width, height)
    left, top = bbox[0], bbox[1]

    canvas.rectangle(bbox, outline="red", width=2)
    # Caption goes 10 units above the box, clamped so it never leaves the top edge.
    caption_anchor = (left, max(0, top - 10))
    canvas.text(caption_anchor, category, fill="red")

    return annotated
61
+
62
 
63
def test_collate_function(batch_of_samples, processor, dtype):
    """Collate samples into processor inputs for detection prompts.

    Args:
        batch_of_samples: Iterable of mappings that each provide an
            ``"image"`` entry.
        processor: Callable processor exposing ``tokenizer.boi_token``.
        dtype: Target dtype for the ``pixel_values`` entry of the batch.

    Returns:
        ``(batch, images)`` where ``images`` is the list of single-image
        lists that was handed to the processor.
    """
    # One single-element image list per sample, paired with one prompt each.
    images = [[sample["image"]] for sample in batch_of_samples]
    prompts = [f"{processor.tokenizer.boi_token} detect \n\n" for _ in images]

    batch = processor(
        images=images,
        text=prompts,
        return_tensors="pt",
        padding=True,
    )
    pixel_values = batch["pixel_values"]
    batch["pixel_values"] = pixel_values.to(dtype)
    return batch, images
73
+
74
+
75
+ # --- Initialize ---
76
cfg = Configuration()

# Processor and fine-tuned weights are both fetched from the checkpoint repo.
processor = AutoProcessor.from_pretrained(cfg.checkpoint_id)

# Weights are placed on CPU here; run_prediction() moves the model to
# cfg.device when it is called. eval() returns the module itself.
model = Gemma3ForConditionalGeneration.from_pretrained(
    cfg.checkpoint_id, torch_dtype=cfg.dtype, device_map="cpu"
).eval()

# Only the "test" split is loaded; get_sample() draws random examples from it.
test_dataset = load_dataset(cfg.dataset_id, split="test")
87
+
88
+
89
def get_sample():
    """Pick one random test example and preprocess it for the model.

    Returns:
        ``(batch, image)``: the processor output with ``pixel_values`` cast
        to ``cfg.dtype``, and the sample's ``"image"`` entry itself (not
        wrapped in a list).
    """
    picked = random.choice(test_dataset)
    picture = picked["image"]
    prompt = f"{processor.tokenizer.boi_token} detect \n\n"

    # The processor receives a batch of one: one image list and one prompt.
    model_inputs = processor(
        images=[[picture]],
        text=[prompt],
        return_tensors="pt",
        padding=True,
    )
    model_inputs["pixel_values"] = model_inputs["pixel_values"].to(cfg.dtype)

    return model_inputs, picture
98
+
99
+
100
+ # --- Prediction Logic ---
101
  @spaces.GPU
102
def run_prediction():
    """Run the detector on one random test image and return it annotated.

    Moves the model to ``cfg.device``, fetches a preprocessed sample via
    ``get_sample()``, generates up to 100 new tokens without gradient
    tracking, decodes the first sequence, and draws the decoded box on the
    image.

    Returns:
        The image returned by ``visualize_bounding_boxes`` for the sample.
    """
    model.to(cfg.device)

    batch, raw_image = get_sample()
    # Move tensors to the target device; leave any non-tensor entries as-is.
    batch = {
        key: value.to(cfg.device) if isinstance(value, torch.Tensor) else value
        for key, value in batch.items()
    }

    with torch.no_grad():
        generation = model.generate(**batch, max_new_tokens=100)
    decoded = processor.batch_decode(generation, skip_special_tokens=True)[0]

    # get_sample() returns the image itself, and indexing a PIL image raises
    # TypeError, so unwrap only when a list/tuple of images is handed back.
    image = raw_image[0] if isinstance(raw_image, (list, tuple)) else raw_image
    width, height = image.size

    result_image = visualize_bounding_boxes(image, decoded, width, height)
    return result_image
117
+
118
+
119
+ # --- Gradio Interface ---
120
# The interface takes no inputs; each run calls run_prediction() and shows
# the image it returns.
demo = gr.Interface(
    title="Gemma3 Object Detector",
    description="Click 'Run' to visualize a prediction from a randomly sampled test image.",
    fn=run_prediction,
    inputs=[],
    outputs=gr.Image(type="pil", label="Detected Bounding Box"),
)

# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()