|
import gradio as gr |
|
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration |
|
from PIL import Image |
|
import cv2 |
|
import torch |
|
|
|
|
|
# Hugging Face model id for the instruction-tuned ("mix") PaliGemma checkpoint
# at 224x224 input resolution.
mix_model_id = "google/paligemma-3b-mix-224"

# Load the processor (tokenizer + image preprocessor) and the model weights
# once at module import, so every request reuses them.
mix_processor = AutoProcessor.from_pretrained(mix_model_id)
mix_model = PaliGemmaForConditionalGeneration.from_pretrained(mix_model_id)
|
|
|
|
|
def extract_frames(video_path, frame_interval=1):
    """Extract every `frame_interval`-th frame from a video file.

    Parameters
    ----------
    video_path : str
        Path to a video file readable by OpenCV.
    frame_interval : int, optional
        Keep one frame out of every `frame_interval` frames (default 1,
        i.e. keep all frames).

    Returns
    -------
    list
        Frames as OpenCV BGR arrays, in order of appearance. Empty if the
        video cannot be opened (best-effort, matching the original
        behavior of returning an empty list on read failure).

    Raises
    ------
    ValueError
        If `frame_interval` is not a positive integer (the original code
        would raise ZeroDivisionError for 0).
    """
    if frame_interval < 1:
        raise ValueError("frame_interval must be >= 1")

    vidcap = cv2.VideoCapture(video_path)
    frames = []
    try:
        # Distinguish "file could not be opened" from "video has no frames";
        # either way we return an empty list rather than crashing the UI.
        if not vidcap.isOpened():
            return frames
        count = 0
        success, image = vidcap.read()
        while success:
            if count % frame_interval == 0:
                frames.append(image)
            success, image = vidcap.read()
            count += 1
    finally:
        # Release the capture handle even if a read raises mid-loop.
        vidcap.release()
    return frames
|
|
|
|
|
def process_video(video, prompt):
    """Caption a video by running PaliGemma on every 10th frame.

    Parameters
    ----------
    video : str
        Path to the uploaded video file (as provided by gr.Video).
    prompt : str
        Text prompt/question to condition the caption on.

    Returns
    -------
    str
        Per-frame generations joined with spaces; frames that fail
        produce the placeholder "Error processing frame".
    """
    frames = extract_frames(video, frame_interval=10)

    captions = []

    for frame in frames:
        # OpenCV decodes frames as BGR; PIL and the model expect RGB.
        image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        # Use keyword arguments: relying on positional order is fragile
        # because processor signatures put text/images in different
        # positions across versions.
        inputs = mix_processor(text=prompt, images=image, return_tensors="pt")
        # Length of the tokenized prompt, used to strip it from the output.
        input_len = inputs["input_ids"].shape[-1]

        try:
            # Inference only: no_grad avoids autograd bookkeeping/memory.
            with torch.no_grad():
                output = mix_model.generate(**inputs, max_new_tokens=20)

            # Decode only the newly generated tokens. Slicing the decoded
            # string by len(prompt) (the old approach) breaks when the
            # processor adds special tokens or whitespace around the prompt.
            decoded_output = mix_processor.decode(
                output[0][input_len:], skip_special_tokens=True
            )
            captions.append(decoded_output)
        except IndexError as e:
            # Best-effort: keep going on a bad frame, note it in the output.
            print(f"IndexError: {e}")
            captions.append("Error processing frame")

    return " ".join(captions)
|
|
|
|
|
# ---- Gradio UI wiring ----
# Two inputs (video upload + free-text prompt) feed process_video; its
# string result is displayed in a single textbox.
video_input = gr.Video(label="Upload Video")
prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your question")
caption_box = gr.Textbox(label="Generated Caption")

demo = gr.Interface(
    fn=process_video,
    inputs=[video_input, prompt_input],
    outputs=caption_box,
    title="Video Captioning with Mix PaliGemma Model",
    description="Upload a video and get captions based on your prompt.",
)

# debug=True surfaces server-side tracebacks in the console while testing.
demo.launch(debug=True)