# Florence-2-demo / app.py
# Author: Itanutiwari527
# Commit: 981c23a ("Uploading app.py", verified)
# Size: 3.9 kB
import streamlit as st
import torch
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw
from transformers import AutoProcessor, AutoModelForCausalLM
# Device settings: run on the first CUDA GPU in half precision when one is
# available, otherwise fall back to the CPU in full precision.
if torch.cuda.is_available():
    device, torch_dtype = "cuda:0", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32
# Load model with caching
@st.cache_resource
def load_model():
    """Load the Florence-2 checkpoint and its processor.

    Wrapped in ``st.cache_resource`` so the heavy download/initialisation is
    not repeated on every Streamlit rerun.

    Returns:
        A ``(model, processor)`` tuple. The model has already been moved to the
        module-level ``device`` and cast to ``torch_dtype``.
    """
    checkpoint = "microsoft/Florence-2-base-ft"
    # NOTE(review): trust_remote_code=True executes model code shipped in the
    # Hub repository; consider pinning a revision for reproducibility/safety.
    florence = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)
    florence = florence.to(device, dtype=torch_dtype)
    florence_processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
    return florence, florence_processor
# Load (or fetch from the resource cache) the model and processor. Nothing in
# the app can work without them, so report the failure and end this run.
try:
    model, processor = load_model()
except Exception as load_error:
    st.error(f"Model loading failed: {load_error}")
    st.stop()
# Page header and task controls.
st.title("Florence-2 Multi-Modal Model Playground")

task = st.selectbox(
    "Select Task",
    ["Object Detection (OD)", "Phrase Grounding (PG)", "Image Captioning (IC)"],
)

# Only Phrase Grounding takes free text; every other task uses an empty suffix.
phrase = (
    st.text_input("Enter phrase for grounding (e.g., 'A red car')", "")
    if task == "Phrase Grounding (PG)"
    else ""
)
# Image uploader
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

# Florence-2 task token for each UI task. The same token is the key under which
# post_process_generation returns that task's result.
TASK_PROMPTS = {
    "Object Detection (OD)": "<OD>",
    "Phrase Grounding (PG)": "<CAPTION_TO_PHRASE_GROUNDING>",
    "Image Captioning (IC)": "<CAPTION>",
}
# Tasks whose result is a {"bboxes": [...], "labels": [...]} mapping.
DETECTION_PROMPTS = ("<OD>", "<CAPTION_TO_PHRASE_GROUNDING>")


def _run_florence(image, task_prompt, text_input=""):
    """Run Florence-2 on *image* for *task_prompt* and return the parsed answer.

    Args:
        image: RGB PIL image to analyse.
        task_prompt: Florence-2 task token, e.g. ``"<OD>"``.
        text_input: Text appended to the task token (the phrase for grounding).

    Returns:
        The dict produced by ``processor.post_process_generation``.

    Any failure is reported with ``st.error`` and ends the script run via
    ``st.stop()``, one distinct message per stage.
    """
    try:
        # Move inputs to the model's device; pixel_values must match the model
        # dtype (float16 on CUDA, see torch_dtype).
        inputs = processor(
            text=task_prompt + text_input, images=image, return_tensors="pt"
        ).to(device, torch_dtype)
    except Exception as e:
        st.error(f"Error during preprocessing: {e}")
        st.stop()

    with torch.no_grad():
        try:
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=512,
                num_beams=3,
                do_sample=False,
            )
        except Exception as e:
            st.error(f"Error during generation: {e}")
            st.stop()

    try:
        # Special tokens are kept, presumably because post_process_generation
        # parses task/location tokens from the raw text (as in the model card).
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        return processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height),
        )
    except Exception as e:
        st.error(f"Post-processing failed: {e}")
        st.stop()


def _draw_detections(image, bboxes, labels):
    """Return an annotated copy of *image* plus a table of the detections.

    Each bbox is unpacked as ``(x_min, y_min, x_max, y_max)``. The table stores
    the top-left corner plus width/height. *image* itself is left untouched.
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    rows = []
    for bbox, label in zip(bboxes, labels):
        x_min, y_min, x_max, y_max = map(int, bbox)
        draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
        # Put the label just above the box, clamped to the top edge.
        draw.text((x_min, max(0, y_min - 10)), label, fill="red")
        rows.append([x_min, y_min, x_max - x_min, y_max - y_min, label])
    return annotated, pd.DataFrame(rows, columns=["x", "y", "w", "h", "object"])


# Run the selected task once an image has been uploaded.
if uploaded_file:
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except Exception as e:
        st.error(f"Error loading image: {e}")
        st.stop()
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Unknown task names fall back to captioning, as the original if/elif did.
    task_prompt = TASK_PROMPTS.get(task, "<CAPTION>")
    text_input = phrase.strip()

    # Grounding needs a phrase. Without one, skip the expensive beam-search
    # generation and ask the user for input instead.
    if task_prompt == "<CAPTION_TO_PHRASE_GROUNDING>" and not text_input:
        st.warning("Enter a phrase to ground before running Phrase Grounding.")
        st.stop()

    parsed_answer = _run_florence(image, task_prompt, text_input)

    if task_prompt in DETECTION_PROMPTS:
        detections = parsed_answer.get(task_prompt, {"bboxes": [], "labels": []})
        annotated, df = _draw_detections(
            image, detections.get("bboxes", []), detections.get("labels", [])
        )
        st.image(annotated, caption="Detected Objects", use_container_width=True)
        st.dataframe(df)
    else:
        caption = parsed_answer.get("<CAPTION>", "No caption generated.")
        st.subheader("Generated Caption:")
        st.success(caption)