"""FastVLM video captioning demo (CPU only, with Hugging Face upload).

Extracts frames from an uploaded video, captions each frame with
apple/FastVLM-7B, and uploads the video plus a text summary to a
Hugging Face repo.
"""

import logging
import os
import uuid
from datetime import datetime

import cv2
import gradio as gr
import numpy as np
import torch
from huggingface_hub import upload_file
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler()],
)

MID = "apple/FastVLM-7B"
# Placeholder token id marking where the image features are spliced into the
# prompt (LLaVA-style convention handled by the model's remote code).
IMAGE_TOKEN_INDEX = -200

# Target repo for uploads; override with the HF_UPLOAD_REPO env var.
HF_MODEL = os.environ.get("HF_UPLOAD_REPO", "rahul7star/ImageExplain")

# Lazily-loaded tokenizer and model (see load_model()).
tok = None
model = None


def load_model():
    """Load the FastVLM tokenizer and model once and cache them globally."""
    global tok, model
    if tok is None or model is None:
        logging.info("Loading FastVLM model (CPU only)...")
        tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            MID,
            torch_dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=True,
        )
        logging.info("✅ Model loaded successfully on CPU")
    return tok, model

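# Note: loading a 7B-parameter model in float32 on CPU needs on the order of
# 28 GB of RAM (7B params x 4 bytes) plus activation overhead, and generation
# is slow. float32 is used here because float16 matmuls are often poorly
# supported on CPU. A lighter variant (an assumption, not part of the original
# app) would be bfloat16 where the CPU supports it:
#
#   model = AutoModelForCausalLM.from_pretrained(
#       MID, torch_dtype=torch.bfloat16, device_map="cpu", trust_remote_code=True
#   )
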
def extract_frames(video_path: str, num_frames: int = 8, sampling_method: str = "uniform"):
    """Extract up to `num_frames` frames from the video as PIL images."""
    logging.info(f"Extracting up to {num_frames} frames using '{sampling_method}' sampling")
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    logging.info(f"Total frames in video: {total_frames}")

    if total_frames == 0:
        cap.release()
        logging.warning("⚠️ No frames found in video")
        return []

    frames = []
    if sampling_method == "uniform":
        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    elif sampling_method == "first":
        indices = list(range(min(num_frames, total_frames)))
    elif sampling_method == "last":
        start = max(0, total_frames - num_frames)
        indices = list(range(start, total_frames))
    else:  # any other value falls back to a middle window
        start = max(0, (total_frames - num_frames) // 2)
        indices = list(range(start, min(start + num_frames, total_frames)))

    logging.info(f"Selected frame indices: {list(indices)}")

    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ret, frame = cap.read()
        if ret:
            # OpenCV returns BGR; convert to RGB before handing off to PIL.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame_rgb))
            logging.info(f"✅ Extracted frame {idx}")
        else:
            logging.warning(f"⚠️ Failed to extract frame {idx}")

    cap.release()
    return frames

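# Worked example of the sampling modes for a 100-frame video and num_frames=4
# (indices follow directly from the branches above):
#   uniform -> [0, 33, 66, 99]
#   first   -> [0, 1, 2, 3]
#   last    -> [96, 97, 98, 99]
#   middle  -> [48, 49, 50, 51]   (the fallback branch)
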
def caption_frame(image: Image.Image, prompt: str) -> str:
    """Caption a single frame with FastVLM and return the cleaned text."""
    tok, model = load_model()
    logging.info(f"Captioning frame with prompt: {prompt!r}")

    # Render the chat template, then splice the image placeholder token
    # between the text that precedes and follows "<image>".
    messages = [{"role": "user", "content": f"<image>\n{prompt}"}]
    rendered = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    pre, post = rendered.split("<image>", 1)

    pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
    post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
    img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1)

    attention_mask = torch.ones_like(input_ids)
    px = model.get_vision_tower().image_processor(images=image, return_tensors="pt")["pixel_values"]

    with torch.no_grad():
        out = model.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            images=px,
            max_new_tokens=15,
            temperature=0.7,
            do_sample=True,
        )

    raw_output = tok.decode(out[0], skip_special_tokens=True)
    caption = raw_output
    # The decoded string echoes the prompt; keep only the text after it.
    if prompt in caption:
        caption = caption.split(prompt)[-1].strip()

    logging.info(f"✅ Final cleaned caption: {caption!r}")
    return caption

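# The sequence handed to generate() therefore looks like (a sketch of the
# LLaVA-style convention the remote code is assumed to follow):
#   [ ...tokens before <image>... , -200 , ...tokens after <image>... ]
# where the single -200 slot is replaced by the projected vision features
# computed from `px` inside the model.
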
def upload_to_hf(video_path, summary_text):
    """Upload the video and its caption summary to a dated folder in HF_MODEL."""
    today_str = datetime.now().strftime("%Y-%m-%d")
    date_folder = f"{today_str}-APPLE-Video_FOLDER"

    unique_subfolder = f"upload_{uuid.uuid4().hex[:8]}"
    hf_folder = f"{date_folder}/{unique_subfolder}"
    logging.info(f"Uploading files to HF folder: {hf_folder} in repo {HF_MODEL}")

    # Upload the original video.
    video_filename = os.path.basename(video_path)
    video_hf_path = f"{hf_folder}/{video_filename}"
    upload_file(
        path_or_fileobj=video_path,
        path_in_repo=video_hf_path,
        repo_id=HF_MODEL,
        repo_type="model",
        token=os.environ.get("HUGGINGFACE_HUB_TOKEN"),
    )
    logging.info(f"✅ Uploaded video to HF: {video_hf_path}")

    # Write the summary to a temporary file and upload it alongside the video.
    summary_file = "/tmp/summary.txt"
    with open(summary_file, "w", encoding="utf-8") as f:
        f.write(summary_text)

    summary_hf_path = f"{hf_folder}/summary.txt"
    upload_file(
        path_or_fileobj=summary_file,
        path_in_repo=summary_hf_path,
        repo_id=HF_MODEL,
        repo_type="model",
        token=os.environ.get("HUGGINGFACE_HUB_TOKEN"),
    )
    logging.info(f"✅ Uploaded summary to HF: {summary_hf_path}")

    return hf_folder

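# Resulting layout in the HF_MODEL repo (folder names below are illustrative;
# the date and the 8-hex suffix are generated at upload time):
#   2025-01-01-APPLE-Video_FOLDER/
#       upload_ab12cd34/
#           <original video filename>
#           summary.txt
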
def process_video(video_path, num_frames, sampling_method, chat_history, progress=gr.Progress()):
    """Extract frames, caption them one by one, and upload the results to HF."""
    if not video_path:
        chat_history.append(["Assistant", "Please upload a video first."])
        logging.warning("No video uploaded")
        return chat_history, None

    logging.info(f"Starting analysis of video: {video_path}")
    progress(0, desc="Extracting frames...")
    frames = extract_frames(video_path, num_frames, sampling_method)

    if not frames:
        chat_history.append(["Assistant", "Failed to extract frames."])
        logging.error("No frames extracted")
        return chat_history, None

    prompt = "Provide a brief one-sentence description of what's happening in this image."
    captions = []

    # Stream partial results into the last chat message as each frame finishes.
    chat_history.append(["Assistant", "Analyzing frames..."])
    for i, frame in enumerate(frames):
        caption = caption_frame(frame, prompt)
        captions.append(f"Frame {i + 1}: {caption}")
        chat_history[-1] = ["Assistant", "\n".join(captions)]
        progress((i + 1) / len(frames))
        logging.info(f"Progress: frame {i + 1}/{len(frames)} analyzed")

    final_summary = "\n".join(captions)
    logging.info("✅ Video analysis complete")
    logging.info(f"Final summary:\n{final_summary}")

    hf_folder = upload_to_hf(video_path, final_summary)
    chat_history.append(["Assistant", f"✅ Video and summary uploaded to HF folder: {hf_folder}"])

    progress(1.0, desc="Analysis complete!")
    return chat_history, frames

class AppleTheme(gr.themes.Base):
    """Minimal Apple-style theme: blue primary with gray accents."""

    def __init__(self):
        super().__init__(
            primary_hue=gr.themes.colors.blue,
            secondary_hue=gr.themes.colors.gray,
            neutral_hue=gr.themes.colors.gray,
        )

with gr.Blocks(theme=AppleTheme()) as demo:
    gr.Markdown("# 🎬 FastVLM Video Captioning (CPU Only, with HF Upload)")

    with gr.Row():
        with gr.Column(scale=7):
            video_display = gr.Video(label="Video Input", autoplay=True, loop=True)

    with gr.Sidebar(width=400):
        chatbot = gr.Chatbot(
            value=[["Assistant", "Upload a video and I'll analyze it for you!"]],
            height=400,
        )
        process_btn = gr.Button("🎯 Analyze Video", variant="primary")

        with gr.Accordion("🖼️ Analyzed Frames", open=False):
            frame_gallery = gr.Gallery(columns=2, rows=4, height="auto")

    # Fixed analysis settings (not exposed as UI controls).
    num_frames = gr.State(value=4)
    sampling_method = gr.State(value="uniform")

    process_btn.click(
        fn=process_video,
        inputs=[video_display, num_frames, sampling_method, chatbot],
        outputs=[chatbot, frame_gallery],
        show_progress=True,
    )

demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)
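# How this script is typically run (a sketch; the filename "app.py" is an
# assumption, the port and env var names come from the code above):
#
#   export HUGGINGFACE_HUB_TOKEN=hf_...          # write access to the upload repo
#   export HF_UPLOAD_REPO=your-user/your-repo    # optional, overrides the default
#   python app.py                                # UI served at http://localhost:7860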