# Hugging Face Space: Gemma3 object-detection demo (app.py)
# Author: ariG23498 (HF Staff) — commit 035f746
import os
import re
import random
from dataclasses import dataclass
from functools import partial
import torch
import gradio as gr
import spaces
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from PIL import Image, ImageDraw
# --- Configuration ---
@dataclass
class Configuration:
    """Runtime configuration for the license-plate detection demo."""

    dataset_id: str = "ariG23498/license-detection-paligemma"  # eval dataset on the Hub
    model_id: str = "google/gemma-3-4b-pt"  # base model id — not referenced in this script
    checkpoint_id: str = "ariG23498/gemma-3-4b-pt-object-detection"  # fine-tuned weights to load
    device: str = "cuda" if torch.cuda.is_available() else "cpu"  # target device for generation
    dtype: torch.dtype = torch.bfloat16  # weight/pixel dtype
    # Training hyperparameters — unused in this inference-only app.
    batch_size: int = 4
    learning_rate: float = 2e-05
    epochs: int = 1
# --- Utils ---
def parse_paligemma_label(label, width, height):
    """Parse a PaliGemma-style detection string into a category and pixel box.

    Args:
        label: Decoded model output containing ``<locDDDD>`` tokens
            (order: y1, x1, y2, x2, each normalized to a 0-1023 grid),
            followed by the category name.
        width: Image width in pixels used to scale x coordinates.
        height: Image height in pixels used to scale y coordinates.

    Returns:
        Tuple of ``(category, [x1, y1, x2, y2])`` in pixel coordinates.

    Raises:
        ValueError: If the label contains fewer than four location tokens
            (the original code crashed with an opaque unpack error here).
    """
    loc_pattern = r"<loc(\d{4})>"
    locations = [int(loc) for loc in re.findall(loc_pattern, label)]
    if len(locations) < 4:
        raise ValueError(
            f"Expected 4 location tokens, found {len(locations)} in label: {label!r}"
        )
    # Keep only the first box if the model emitted extra tokens/detections;
    # previously any count != 4 raised on the tuple unpack.
    y1_norm, x1_norm, y2_norm, x2_norm = locations[:4]
    # Category is everything after the last '>' of the location codes.
    category = label.split(">")[-1].strip()
    # Scale from the 1024-cell normalized grid to pixel coordinates.
    x1 = (x1_norm / 1024) * width
    y1 = (y1_norm / 1024) * height
    x2 = (x2_norm / 1024) * width
    y2 = (y2_norm / 1024) * height
    return category, [x1, y1, x2, y2]
def visualize_bounding_boxes(image, label, width, height):
    """Draw the predicted box and category label onto a copy of *image*.

    Args:
        image: Source PIL image (left untouched).
        label: Raw decoded model output with ``<loc>`` tokens and category.
        width: Image width in pixels, for coordinate scaling.
        height: Image height in pixels, for coordinate scaling.

    Returns:
        A new PIL image with the detection rendered in red.
    """
    annotated = image.copy()
    category, bbox = parse_paligemma_label(label, width, height)
    canvas = ImageDraw.Draw(annotated)
    canvas.rectangle(bbox, outline="red", width=2)
    # Place the caption just above the box, clamped to the top edge.
    text_anchor = (bbox[0], max(0, bbox[1] - 10))
    canvas.text(text_anchor, category, fill="red")
    return annotated
def test_collate_function(batch_of_samples, processor, dtype):
    """Collate raw dataset rows into a padded multimodal inference batch.

    Args:
        batch_of_samples: Iterable of dataset rows, each with an "image" key.
        processor: Multimodal processor handling tokenization and image prep.
        dtype: Torch dtype the pixel tensor is cast to.

    Returns:
        Tuple of (processor batch, list of single-image lists fed to it).
    """
    images = [[sample["image"]] for sample in batch_of_samples]
    prompts = [f"{processor.tokenizer.boi_token} detect \n\n" for _ in batch_of_samples]
    batch = processor(images=images, text=prompts, return_tensors="pt", padding=True)
    batch["pixel_values"] = batch["pixel_values"].to(dtype)
    return batch, images
# --- Initialize ---
# Shared configuration (model/dataset ids, device, dtype).
cfg = Configuration()
# Processor (tokenizer + image preprocessing) from the fine-tuned checkpoint.
processor = AutoProcessor.from_pretrained(cfg.checkpoint_id)
# Load the fine-tuned weights on CPU; the model is moved to cfg.device inside
# run_prediction (the GPU is only available under the @spaces.GPU decorator).
model = Gemma3ForConditionalGeneration.from_pretrained(
    cfg.checkpoint_id,
    torch_dtype=cfg.dtype,
    device_map="cpu",
)
model.eval()  # inference only
# Held-out split used to sample random demo images.
test_dataset = load_dataset(cfg.dataset_id, split="test")
def get_sample():
    """Pick a random test image and build a single-sample generation batch.

    Returns:
        Tuple of (processor batch ready for ``model.generate``, raw PIL image).
    """
    sample = random.choice(test_dataset)
    raw_image = sample["image"]
    prompt = f"{processor.tokenizer.boi_token} detect \n\n"
    batch = processor(images=[[raw_image]], text=[prompt], return_tensors="pt", padding=True)
    # Match the pixel tensor to the model's weight dtype (bfloat16 by default).
    batch["pixel_values"] = batch["pixel_values"].to(cfg.dtype)
    return batch, raw_image
# --- Prediction Logic ---
@spaces.GPU
def run_prediction():
    """Sample a random test image, run detection, and render the result.

    Returns:
        A PIL image with the predicted bounding box drawn on it.
    """
    # Weights are loaded on CPU at startup; move to the GPU inside the
    # @spaces.GPU-decorated call where it is available.
    model.to(cfg.device)
    batch, raw_image = get_sample()
    on_device = {
        key: value.to(cfg.device) if isinstance(value, torch.Tensor) else value
        for key, value in batch.items()
    }
    with torch.no_grad():
        output_ids = model.generate(**on_device, max_new_tokens=100)
    decoded = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
    width, height = raw_image.size
    return visualize_bounding_boxes(raw_image, decoded, width, height)
# --- Gradio Interface ---
# Zero-input interface: each click samples a random test image and shows the
# model's predicted bounding box drawn on it.
demo = gr.Interface(
    fn=run_prediction,
    inputs=[],
    outputs=gr.Image(type="pil", label="Detected Bounding Box"),
    title="Gemma3 Object Detector",
    description="Click 'Generate' to visualize a prediction from a randomly sampled test image.",
)
if __name__ == "__main__":
    demo.launch()