import re
import random
from dataclasses import dataclass

import torch
import gradio as gr
import spaces
from datasets import load_dataset
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from PIL import ImageDraw


# --- Configuration ---
@dataclass
class Configuration:
    dataset_id: str = "ariG23498/license-detection-paligemma"
    model_id: str = "google/gemma-3-4b-pt"
    checkpoint_id: str = "ariG23498/gemma-3-4b-pt-object-detection"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    dtype: torch.dtype = torch.bfloat16
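    # Training-time hyperparameters, presumably kept for parity with the
    # fine-tuning script; they are unused by this inference demo.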
    batch_size: int = 4
    learning_rate: float = 2e-05
    epochs: int = 1


# --- Utils ---
def parse_paligemma_label(label, width, height):
    """Parse a PaliGemma-style detection string such as
    '<loc0123><loc0456><loc0789><loc1000> plate' into a category name and an
    absolute-pixel [x1, y1, x2, y2] bounding box."""
    # Extract the four-digit location codes (each normalized to a 1024-bin grid)
    loc_pattern = r"<loc(\d{4})>"
    locations = [int(loc) for loc in re.findall(loc_pattern, label)]
    if len(locations) < 4:
        raise ValueError(f"Expected 4 location tokens, found {len(locations)} in {label!r}")

    # Extract the category (everything after the last location code)
    category = label.split(">")[-1].strip()

    # PaliGemma orders coordinates as y1, x1, y2, x2; use the first four codes
    y1_norm, x1_norm, y2_norm, x2_norm = locations[:4]

    # Convert normalized coordinates to image coordinates
    x1 = (x1_norm / 1024) * width
    y1 = (y1_norm / 1024) * height
    x2 = (x2_norm / 1024) * width
    y2 = (y2_norm / 1024) * height

    return category, [x1, y1, x2, y2]


def visualize_bounding_boxes(image, label, width, height):
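    """Draw the predicted box and its category label on a copy of the image."""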
    # Copy image for drawing
    draw_image = image.copy()
    draw = ImageDraw.Draw(draw_image)

    category, bbox = parse_paligemma_label(label, width, height)

    draw.rectangle(bbox, outline="red", width=2)
    draw.text((bbox[0], max(0, bbox[1] - 10)), category, fill="red")

    return draw_image


def test_collate_function(batch_of_samples, processor, dtype):
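    """Pair each image with the '<boi> detect' prompt used during fine-tuning
    and tokenize the batch; returns the processor batch and the raw images."""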
    images = []
    prompts = []
    for sample in batch_of_samples:
        images.append([sample["image"]])
        prompts.append(f"{processor.tokenizer.boi_token} detect \n\n")

    batch = processor(images=images, text=prompts, return_tensors="pt", padding=True)
    batch["pixel_values"] = batch["pixel_values"].to(dtype)
    return batch, images


# --- Initialize ---
cfg = Configuration()

processor = AutoProcessor.from_pretrained(cfg.checkpoint_id)
model = Gemma3ForConditionalGeneration.from_pretrained(
    cfg.checkpoint_id,
    torch_dtype=cfg.dtype,
    device_map="cpu",
)
model.eval()
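
# The model stays on CPU until run_prediction acquires a GPU via @spaces.GPU
# (ZeroGPU attaches a device only for the duration of the decorated call).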

test_dataset = load_dataset(cfg.dataset_id, split="test")


def get_sample():
    # Pick a random test image and reuse the collate function to build a
    # single-sample batch.
    sample = random.choice(test_dataset)
    batch, _ = test_collate_function([sample], processor, cfg.dtype)
    return batch, sample["image"]


# --- Prediction Logic ---
@spaces.GPU
def run_prediction():
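    # ZeroGPU grants a GPU only for the duration of this call, so move the
    # model to the device here rather than at load time.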
    model.to(cfg.device)

    batch, raw_image = get_sample()
    batch = {k: v.to(cfg.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

    with torch.no_grad():
        generation = model.generate(**batch, max_new_tokens=100)
    decoded = processor.batch_decode(generation, skip_special_tokens=True)[0]

    width, height = raw_image.size
    return visualize_bounding_boxes(raw_image, decoded, width, height)


# --- Gradio Interface ---
demo = gr.Interface(
    fn=run_prediction,
    inputs=[],
    outputs=gr.Image(type="pil", label="Detected Bounding Box"),
    title="Gemma3 Object Detector",
    description="Click 'Generate' to visualize a prediction from a randomly sampled test image.",
)

if __name__ == "__main__":
    demo.launch()