# FastVLM Video Captioning (Hugging Face Space, running on ZeroGPU)
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
import cv2
import numpy as np
from typing import Optional
import tempfile
import os
import spaces
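# Space dependencies inferred from the imports above (not an official requirements list):
#   gradio, spaces, torch, transformers, accelerate (for device_map), opencv-python, pillow, numpy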
MID = "apple/FastVLM-7B"
IMAGE_TOKEN_INDEX = -200
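# IMAGE_TOKEN_INDEX is the placeholder token id that caption_frame() splices into
# input_ids to mark where the model should insert the image features.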
# Initialize model variables
tok = None
model = None


def load_model():
    global tok, model
    if tok is None or model is None:
        print("Loading FastVLM model...")
        tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            MID,
            torch_dtype=torch.float16,
            device_map="cuda",
            trust_remote_code=True,
        )
        print("Model loaded successfully!")
    return tok, model
def extract_frames(video_path: str, num_frames: int = 8, sampling_method: str = "uniform"):
    """Extract frames from video"""
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    if total_frames == 0:
        cap.release()
        return []

    frames = []

    if sampling_method == "uniform":
        # Uniform sampling
        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    elif sampling_method == "first":
        # Take first N frames
        indices = list(range(min(num_frames, total_frames)))
    elif sampling_method == "last":
        # Take last N frames
        start = max(0, total_frames - num_frames)
        indices = list(range(start, total_frames))
    else:  # middle
        # Take frames from the middle
        start = max(0, (total_frames - num_frames) // 2)
        indices = list(range(start, min(start + num_frames, total_frames)))

    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            # Convert BGR to RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame_rgb))

    cap.release()
    return frames
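# Example (hypothetical path): extract_frames("clip.mp4", num_frames=4, sampling_method="middle")
# would return up to 4 PIL images sampled from around the middle of the clip.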
@spaces.GPU  # allocate ZeroGPU hardware for this call (the Space runs on Zero)
def caption_frame(image: Image.Image, prompt: str) -> str:
    """Generate caption for a single frame"""
    # Load model on GPU
    tok, model = load_model()

    # Build chat with custom prompt
    messages = [
        {"role": "user", "content": f"<image>\n{prompt}"}
    ]
    rendered = tok.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    pre, post = rendered.split("<image>", 1)

    # Tokenize the text around the image token
    pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
    post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids

    # Splice in the IMAGE token id
    img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
    attention_mask = torch.ones_like(input_ids, device=model.device)

    # Preprocess image
    px = model.get_vision_tower().image_processor(images=image, return_tensors="pt")["pixel_values"]
    px = px.to(model.device, dtype=model.dtype)

    # Generate
    with torch.no_grad():
        out = model.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            images=px,
            max_new_tokens=15,
            temperature=0.7,
            do_sample=True,
        )
    caption = tok.decode(out[0], skip_special_tokens=True)

    # Extract only the generated part
    if prompt in caption:
        caption = caption.split(prompt)[-1].strip()

    return caption
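# Illustrative call (assumes a hypothetical local file "example.jpg"):
#   img = Image.open("example.jpg")
#   print(caption_frame(img, "Describe this image in one sentence."))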
def process_video(
    video_path: str,
    num_frames: int,
    sampling_method: str,
    caption_mode: str,
    custom_prompt: str,
    progress=gr.Progress()
) -> tuple:
    """Process video and generate captions"""
    if not video_path:
        return "Please upload a video first.", None

    progress(0, desc="Extracting frames...")
    frames = extract_frames(video_path, num_frames, sampling_method)

    if not frames:
        return "Failed to extract frames from video.", None

    # Use brief one-sentence prompt for faster processing
    prompt = "Provide a brief one-sentence description of what's happening in this image."

    captions = []
    frame_previews = []

    for i, frame in enumerate(frames):
        progress((i + 1) / (len(frames) + 1), desc=f"Analyzing frame {i + 1}/{len(frames)}...")
        caption = caption_frame(frame, prompt)
        captions.append(f"Frame {i + 1}: {caption}")
        frame_previews.append(frame)

    progress(1.0, desc="Generating summary...")

    # Combine captions into a simple narrative
    full_caption = "\n".join(captions)

    # Generate overall summary if multiple frames
    if len(frames) > 1:
        video_summary = f"Analyzed {len(frames)} frames:\n\n{full_caption}"
    else:
        video_summary = f"Video Analysis:\n\n{full_caption}"

    return video_summary, frame_previews
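# Note: process_video() returns the full summary in one shot; the Blocks UI below
# instead wires the button to process_video_with_chat(), which streams captions
# into the chat as each frame is analyzed.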
# Custom Apple-inspired theme for the Gradio interface
class AppleTheme(gr.themes.Base):
    def __init__(self):
        super().__init__(
            primary_hue=gr.themes.colors.blue,
            secondary_hue=gr.themes.colors.gray,
            neutral_hue=gr.themes.colors.gray,
            spacing_size=gr.themes.sizes.spacing_md,
            radius_size=gr.themes.sizes.radius_md,
            text_size=gr.themes.sizes.text_md,
            font=[
                gr.themes.GoogleFont("Inter"),
                "-apple-system",
                "BlinkMacSystemFont",
                "SF Pro Display",
                "SF Pro Text",
                "Helvetica Neue",
                "Helvetica",
                "Arial",
                "sans-serif",
            ],
            font_mono=[
                gr.themes.GoogleFont("SF Mono"),
                "ui-monospace",
                "Consolas",
                "monospace",
            ],
        )
        super().set(
            # Core colors
            body_background_fill="*neutral_50",
            body_background_fill_dark="*neutral_950",
            button_primary_background_fill="*primary_500",
            button_primary_background_fill_hover="*primary_600",
            button_primary_text_color="white",
            button_primary_border_color="*primary_500",
            # Shadows
            block_shadow="0 4px 12px rgba(0, 0, 0, 0.08)",
            # Borders
            block_border_width="1px",
            block_border_color="*neutral_200",
            input_border_width="1px",
            input_border_color="*neutral_300",
            input_border_color_focus="*primary_500",
            # Text
            block_title_text_weight="600",
            block_label_text_weight="500",
            block_label_text_size="13px",
            block_label_text_color="*neutral_600",
            body_text_color="*neutral_900",
            # Spacing
            layout_gap="16px",
            block_padding="20px",
            # Specific components
            slider_color="*primary_500",
        )
# Create the Gradio interface with the custom theme
with gr.Blocks(theme=AppleTheme()) as demo:
    gr.Markdown("# 🎬 FastVLM Video Captioning")

    with gr.Row():
        # Main video display
        with gr.Column(scale=7):
            video_display = gr.Video(
                label="Video Input",
                autoplay=True,
                loop=True
            )

    # Sidebar with chat interface
    with gr.Sidebar(width=400):
        gr.Markdown("## 💬 Video Analysis Chat")

        chatbot = gr.Chatbot(
            value=[["Assistant", "Upload a video and I'll analyze it for you!"]],
            height=400,
            elem_classes=["chatbot"]
        )

        process_btn = gr.Button("🎯 Analyze Video", variant="primary", size="lg")

        with gr.Accordion("🖼️ Analyzed Frames", open=False):
            frame_gallery = gr.Gallery(
                label="Extracted Frames",
                show_label=False,
                columns=2,
                rows=4,
                object_fit="contain",
                height="auto"
            )

    # Hidden parameters with default values
    num_frames = gr.State(value=8)
    sampling_method = gr.State(value="uniform")
    caption_mode = gr.State(value="Brief Summary")
    custom_prompt = gr.State(value="")

    # Upload handler
    def handle_upload(video, chat_history):
        if video:
            chat_history.append(["User", "Video uploaded"])
            chat_history.append(["Assistant", "Video loaded! Click 'Analyze Video' to generate captions."])
            return video, chat_history
        return None, chat_history

    video_display.upload(
        handle_upload,
        inputs=[video_display, chatbot],
        outputs=[video_display, chatbot]
    )

    # Modified process function to update chatbot with streaming
    def process_video_with_chat(video_path, num_frames, sampling_method, caption_mode, custom_prompt, chat_history, progress=gr.Progress()):
        if not video_path:
            chat_history.append(["Assistant", "Please upload a video first."])
            yield chat_history, None
            return

        chat_history.append(["User", "Analyzing video..."])
        yield chat_history, None

        # Extract frames
        progress(0, desc="Extracting frames...")
        frames = extract_frames(video_path, num_frames, sampling_method)

        if not frames:
            chat_history.append(["Assistant", "Failed to extract frames from video."])
            yield chat_history, None
            return

        # Start streaming response
        chat_history.append(["Assistant", ""])
        prompt = "Provide a brief one-sentence description of what's happening in this image."

        captions = []
        for i, frame in enumerate(frames):
            progress((i + 1) / (len(frames) + 1), desc=f"Analyzing frame {i + 1}/{len(frames)}...")
            caption = caption_frame(frame, prompt)
            frame_caption = f"Frame {i + 1}: {caption}\n"
            captions.append(frame_caption)

            # Update the last message with accumulated captions
            current_text = "".join(captions)
            chat_history[-1] = ["Assistant", f"Analyzing {len(frames)} frames:\n\n{current_text}"]
            yield chat_history, frames[:i + 1]  # Also update frame gallery progressively

        progress(1.0, desc="Analysis complete!")

        # Final update with complete message
        full_caption = "".join(captions)
        final_message = f"Analyzed {len(frames)} frames:\n\n{full_caption}"
        chat_history[-1] = ["Assistant", final_message]
        yield chat_history, frames

    # Process button with streaming
    process_btn.click(
        process_video_with_chat,
        inputs=[video_display, num_frames, sampling_method, caption_mode, custom_prompt, chatbot],
        outputs=[chatbot, frame_gallery],
        show_progress=True
    )

demo.launch()