import os
import re
import random
from dataclasses import dataclass
from functools import partial

import torch
import gradio as gr
import spaces
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from PIL import Image, ImageDraw


# --- Configuration ---
@dataclass
class Configuration:
    dataset_id: str = "ariG23498/license-detection-paligemma"
    model_id: str = "google/gemma-3-4b-pt"
    checkpoint_id: str = "ariG23498/gemma-3-4b-pt-object-detection"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    dtype: torch.dtype = torch.bfloat16
    batch_size: int = 4
    learning_rate: float = 2e-05
    epochs: int = 1


# --- Utils ---
def parse_paligemma_label(label, width, height):
    # Extract the four <locXXXX> location codes (each a 4-digit value
    # normalized to a 1024-step grid, per the PaliGemma detection format)
    loc_pattern = r"<loc(\d{4})>"
    locations = [int(loc) for loc in re.findall(loc_pattern, label)]

    # Extract category (everything after the last location code)
    category = label.split(">")[-1].strip()

    # Order in PaliGemma format is: y1, x1, y2, x2
    y1_norm, x1_norm, y2_norm, x2_norm = locations

    # Convert normalized coordinates to image coordinates
    x1 = (x1_norm / 1024) * width
    y1 = (y1_norm / 1024) * height
    x2 = (x2_norm / 1024) * width
    y2 = (y2_norm / 1024) * height

    return category, [x1, y1, x2, y2]


def visualize_bounding_boxes(image, label, width, height):
    # Copy the image so the original stays untouched
    draw_image = image.copy()
    draw = ImageDraw.Draw(draw_image)

    category, bbox = parse_paligemma_label(label, width, height)
    draw.rectangle(bbox, outline="red", width=2)
    draw.text((bbox[0], max(0, bbox[1] - 10)), category, fill="red")

    return draw_image


def test_collate_function(batch_of_samples, processor, dtype):
    images = []
    prompts = []
    for sample in batch_of_samples:
        images.append([sample["image"]])
        prompts.append(f"{processor.tokenizer.boi_token} detect \n\n")

    batch = processor(images=images, text=prompts, return_tensors="pt", padding=True)
    batch["pixel_values"] = batch["pixel_values"].to(dtype)
    return batch, images


# --- Initialize ---
cfg = Configuration()

processor = AutoProcessor.from_pretrained(cfg.checkpoint_id)
model = Gemma3ForConditionalGeneration.from_pretrained(
    cfg.checkpoint_id,
    torch_dtype=cfg.dtype,
    device_map="cpu",
)
model.eval()

test_dataset = load_dataset(cfg.dataset_id, split="test")


def get_sample():
    sample = random.choice(test_dataset)
    images = [[sample["image"]]]
    prompts = [f"{processor.tokenizer.boi_token} detect \n\n"]
    batch = processor(images=images, text=prompts, return_tensors="pt", padding=True)
    batch["pixel_values"] = batch["pixel_values"].to(cfg.dtype)
    return batch, sample["image"]


# --- Prediction Logic ---
@spaces.GPU
def run_prediction():
    # Move the model onto the GPU allocated by the @spaces.GPU decorator
    model.to(cfg.device)
    batch, raw_image = get_sample()
    batch = {
        k: v.to(cfg.device) if isinstance(v, torch.Tensor) else v
        for k, v in batch.items()
    }

    with torch.no_grad():
        generation = model.generate(**batch, max_new_tokens=100)
    decoded = processor.batch_decode(generation, skip_special_tokens=True)[0]

    image = raw_image  # raw_image is already a PIL.Image
    width, height = image.size
    result_image = visualize_bounding_boxes(image, decoded, width, height)
    return result_image


# --- Gradio Interface ---
demo = gr.Interface(
    fn=run_prediction,
    inputs=[],
    outputs=gr.Image(type="pil", label="Detected Bounding Box"),
    title="Gemma3 Object Detector",
    description="Click 'Generate' to visualize a prediction on a randomly sampled test image.",
)

if __name__ == "__main__":
    demo.launch()
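
# --- Example (hypothetical, for illustration only) ---
# A minimal sketch of what parse_paligemma_label expects. The label string
# below is made up; real predictions come from model.generate above.
#
#   label = "<loc0100><loc0200><loc0500><loc0800> license plate"
#   category, bbox = parse_paligemma_label(label, width=640, height=480)
#   # category == "license plate"
#   # bbox == [125.0, 46.875, 500.0, 234.375]
#   #   x1 = 200/1024 * 640, y1 = 100/1024 * 480,
#   #   x2 = 800/1024 * 640, y2 = 500/1024 * 480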