# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
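
# Gradio demo for EdgeTAM: upload a video, click "include"/"exclude" points on
# the first frame to select an object, and propagate the mask through the
# whole video with the EdgeTAM video predictor.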

import copy
import os
from datetime import datetime

import gradio as gr
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0,1,2,3,4,5,6,7" | |

import tempfile

import cv2
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
from moviepy.editor import ImageSequenceClip
from PIL import Image

from sam2.build_sam import build_sam2_video_predictor

# Description
title = "<center><strong><font size='8'>EdgeTAM</font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a> </center>"
description_p = """# Instructions | |
<ol> | |
<li> Upload one video or click one example video</li> | |
<li> Click 'include' point type, select the object to segment and track</li> | |
<li> Click 'exclude' point type (optional), select the area you want to avoid segmenting and tracking</li> | |
<li> Click the 'Track' button to obtain the masked video </li> | |
</ol> | |
""" | |

# examples
examples = [
    ["examples/01_dog.mp4"],
    ["examples/02_cups.mp4"],
    ["examples/03_blocks.mp4"],
    ["examples/04_coffee.mp4"],
    ["examples/05_default_juggle.mp4"],
    ["examples/01_breakdancer.mp4"],
    ["examples/02_hummingbird.mp4"],
    ["examples/03_skateboarder.mp4"],
    ["examples/04_octopus.mp4"],
    ["examples/05_landing_dog_soccer.mp4"],
    ["examples/06_pingpong.mp4"],
    ["examples/07_snowboarder.mp4"],
    ["examples/08_driving.mp4"],
    ["examples/09_birdcartoon.mp4"],
    ["examples/10_cloth_magic.mp4"],
    ["examples/11_polevault.mp4"],
    ["examples/12_hideandseek.mp4"],
    ["examples/13_butterfly.mp4"],
    ["examples/14_social_dog_training.mp4"],
    ["examples/15_cricket.mp4"],
    ["examples/16_robotarm.mp4"],
    ["examples/17_childrendancing.mp4"],
    ["examples/18_threedogs.mp4"],
    ["examples/19_cyclist.mp4"],
    ["examples/20_doughkneading.mp4"],
    ["examples/21_biker.mp4"],
    ["examples/22_dogskateboarder.mp4"],
    ["examples/23_racecar.mp4"],
    ["examples/24_clownfish.mp4"],
]

OBJ_ID = 0  # the demo tracks a single object, always registered under this id
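

# The predictor is built lazily, once per user session, and cached in
# session_state so repeated callbacks reuse the same model instance.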
def get_predictor(session_state):
    if "predictor" not in session_state:
        sam2_checkpoint = "checkpoints/edgetam.pt"
        model_cfg = "edgetam.yaml"
        predictor = build_sam2_video_predictor(
            model_cfg, sam2_checkpoint, device="cuda"
        )
        print("predictor loaded")

        # use bfloat16 for the entire demo
        torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
        if torch.cuda.get_device_properties(0).major >= 8:
            # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        session_state["predictor"] = predictor
    return session_state["predictor"]


def get_video_fps(video_path):
    # Open the video file
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return None
    # Get the FPS of the video
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    return fps
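

# Reset everything for the session: clear clicks, drop cached frames, and
# reset the predictor's inference state so a new video can be processed.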
def reset(session_state):
    predictor = get_predictor(session_state)
    predictor.to("cuda")
    session_state["input_points"] = []
    session_state["input_labels"] = []
    if session_state["inference_state"] is not None:
        predictor.reset_state(session_state["inference_state"])
    session_state["first_frame"] = None
    session_state["all_frames"] = None
    session_state["inference_state"] = None
    return (
        None,
        gr.update(open=True),
        None,
        None,
        gr.update(value=None, visible=False),
        session_state,
    )
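

# Clear only the point prompts; the uploaded video and its cached frames are
# kept so the user can re-annotate the same first frame.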
def clear_points(session_state):
    predictor = get_predictor(session_state)
    predictor.to("cuda")
    session_state["input_points"] = []
    session_state["input_labels"] = []
    if session_state["inference_state"]["tracking_has_started"]:
        predictor.reset_state(session_state["inference_state"])
    return (
        session_state["first_frame"],
        None,
        gr.update(value=None, visible=False),
        session_state,
    )
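

# Decode the uploaded video, cache its RGB frames for later visualization,
# and initialize the predictor's inference state on the video file.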
def preprocess_video_in(video_path, session_state):
    predictor = get_predictor(session_state)
    predictor.to("cuda")
    if video_path is None:
        return (
            gr.update(open=True),  # video_in_drawer
            None,  # points_map
            None,  # output_image
            gr.update(value=None, visible=False),  # output_video
            session_state,
        )

    # Read the first frame
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return (
            gr.update(open=True),  # video_in_drawer
            None,  # points_map
            None,  # output_image
            gr.update(value=None, visible=False),  # output_video
            session_state,
        )

    frame_number = 0
    first_frame = None
    all_frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = np.array(frame)

        # Store the first frame
        if frame_number == 0:
            first_frame = frame
        all_frames.append(frame)
        frame_number += 1
    cap.release()

    session_state["first_frame"] = copy.deepcopy(first_frame)
    session_state["all_frames"] = all_frames
    session_state["inference_state"] = predictor.init_state(video_path=video_path)
    session_state["input_points"] = []
    session_state["input_labels"] = []
    return [
        gr.update(open=False),  # video_in_drawer
        first_frame,  # points_map
        None,  # output_image
        gr.update(value=None, visible=False),  # output_video
        session_state,
    ]
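

# Handle a click on the first frame: record the point with its include (1) /
# exclude (0) label, re-run the predictor on frame 0 with all points so far,
# and return the point overlay plus the predicted mask preview.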
def segment_with_points(
    point_type,
    session_state,
    evt: gr.SelectData,
):
    predictor = get_predictor(session_state)
    predictor.to("cuda")
    session_state["input_points"].append(evt.index)
    print(f"TRACKING INPUT POINT: {session_state['input_points']}")

    if point_type == "include":
        session_state["input_labels"].append(1)
    elif point_type == "exclude":
        session_state["input_labels"].append(0)
    print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")

    # Open the image and get its dimensions
    transparent_background = Image.fromarray(session_state["first_frame"]).convert(
        "RGBA"
    )
    w, h = transparent_background.size

    # Define the circle radius as a fraction of the smaller dimension
    fraction = 0.01  # You can adjust this value as needed
    radius = int(fraction * min(w, h))

    # Create a transparent layer to draw on
    transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
    for index, track in enumerate(session_state["input_points"]):
        if session_state["input_labels"][index] == 1:
            cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
        else:
            cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)

    # Convert the transparent layer back to an image
    transparent_layer = Image.fromarray(transparent_layer, "RGBA")
    selected_point_map = Image.alpha_composite(
        transparent_background, transparent_layer
    )

    # Feed all clicked points to the predictor on the first frame;
    # for labels, `1` means positive click and `0` means negative click
    points = np.array(session_state["input_points"], dtype=np.float32)
    labels = np.array(session_state["input_labels"], np.int32)
    _, _, out_mask_logits = predictor.add_new_points(
        inference_state=session_state["inference_state"],
        frame_idx=0,
        obj_id=OBJ_ID,
        points=points,
        labels=labels,
    )

    mask_image = show_mask((out_mask_logits[0] > 0.0).cpu().numpy())
    first_frame_output = Image.alpha_composite(transparent_background, mask_image)

    torch.cuda.empty_cache()
    return selected_point_map, first_frame_output, session_state


# Render a binary mask as a semi-transparent RGBA overlay (tab10 color per object id).
def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    mask = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    mask = (mask * 255).astype(np.uint8)
    if convert_to_image:
        mask = Image.fromarray(mask, "RGBA")
    return mask
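

# Propagate the first-frame mask through the whole video, composite it over
# every frame, and write the result to a temporary MP4 at the original FPS.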
def propagate_to_all(
    video_in,
    session_state,
):
    predictor = get_predictor(session_state)
    predictor.to("cuda")
    if (
        len(session_state["input_points"]) == 0
        or video_in is None
        or session_state["inference_state"] is None
    ):
        return (
            None,
            session_state,
        )

    # run propagation throughout the video and collect the results in a dict
    video_segments = {}  # video_segments contains the per-frame segmentation results
    print("starting propagate_in_video")
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
        session_state["inference_state"]
    ):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }

    # render the segmentation results frame by frame
    vis_frame_stride = 1
    output_frames = []
    for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
        transparent_background = Image.fromarray(
            session_state["all_frames"][out_frame_idx]
        ).convert("RGBA")
        out_mask = video_segments[out_frame_idx][OBJ_ID]
        mask_image = show_mask(out_mask)
        output_frame = Image.alpha_composite(transparent_background, mask_image)
        output_frame = np.array(output_frame)
        output_frames.append(output_frame)

    torch.cuda.empty_cache()

    # Create a video clip from the image sequence at the input video's FPS
    original_fps = get_video_fps(video_in)
    fps = original_fps  # Frames per second
    clip = ImageSequenceClip(output_frames, fps=fps)

    # Write the result to a temporary file
    unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
    final_vid_output_path = f"output_video_{unique_id}.mp4"
    final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)
    clip.write_videofile(final_vid_output_path, codec="libx264")

    return (
        gr.update(value=final_vid_output_path),
        session_state,
    )


# Reveal the output video component (hidden until the first tracking run).
def update_ui():
    return gr.update(visible=True)
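

# UI layout: instructions, input video, and point prompts on the left; example
# videos and the mask preview on the right. A per-session gr.State dict holds
# the cached frames, clicks, and the predictor's inference state.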
with gr.Blocks() as demo:
    session_state = gr.State(
        {
            "first_frame": None,
            "all_frames": None,
            "input_points": [],
            "input_labels": [],
            "inference_state": None,
        }
    )

    with gr.Column():
        # Title
        gr.Markdown(title)
        with gr.Row():
            with gr.Column():
                # Instructions
                gr.Markdown(description_p)

                with gr.Accordion("Input Video", open=True) as video_in_drawer:
                    video_in = gr.Video(label="Input Video", format="mp4")

                with gr.Row():
                    point_type = gr.Radio(
                        label="point type",
                        choices=["include", "exclude"],
                        value="include",
                        scale=2,
                    )
                    propagate_btn = gr.Button("Track", scale=1, variant="primary")
                    clear_points_btn = gr.Button("Clear Points", scale=1)
                    reset_btn = gr.Button("Reset", scale=1)

                points_map = gr.Image(
                    label="Frame with Point Prompt", type="numpy", interactive=False
                )

            with gr.Column():
                gr.Markdown("# Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[
                        video_in,
                    ],
                    examples_per_page=8,
                )
                # vertical spacers so the outputs line up with the left column
                gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
                gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
                gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")

                output_image = gr.Image(label="Reference Mask")

    output_video = gr.Video(visible=False)

    # When new video is uploaded
    video_in.upload(
        fn=preprocess_video_in,
        inputs=[
            video_in,
            session_state,
        ],
        outputs=[
            video_in_drawer,  # Accordion to hide uploaded video player
            points_map,  # Image component where we add new tracking points
            output_image,
            output_video,
            session_state,
        ],
        queue=False,
    )

    video_in.change(
        fn=preprocess_video_in,
        inputs=[
            video_in,
            session_state,
        ],
        outputs=[
            video_in_drawer,  # Accordion to hide uploaded video player
            points_map,  # Image component where we add new tracking points
            output_image,
            output_video,
            session_state,
        ],
        queue=False,
    )

    # triggered when we click on image to add new points
    points_map.select(
        fn=segment_with_points,
        inputs=[
            point_type,  # "include" or "exclude"
            session_state,
        ],
        outputs=[
            points_map,  # updated image with points
            output_image,
            session_state,
        ],
        queue=False,
    )

    # Clear every point clicked and added to the map
    clear_points_btn.click(
        fn=clear_points,
        inputs=session_state,
        outputs=[
            points_map,
            output_image,
            output_video,
            session_state,
        ],
        queue=False,
    )

    reset_btn.click(
        fn=reset,
        inputs=session_state,
        outputs=[
            video_in,
            video_in_drawer,
            points_map,
            output_image,
            output_video,
            session_state,
        ],
        queue=False,
    )

    # Make the output video visible, then run mask propagation over the video
    propagate_btn.click(
        fn=update_ui,
        inputs=[],
        outputs=output_video,
        queue=False,
    ).then(
        fn=propagate_to_all,
        inputs=[
            video_in,
            session_state,
        ],
        outputs=[
            output_video,
            session_state,
        ],
        concurrency_limit=10,
        queue=False,
    )


demo.queue()
demo.launch()
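
# To try the demo locally (assuming this file is saved as app.py, a CUDA GPU is
# available, and the EdgeTAM checkpoint exists at checkpoints/edgetam.pt):
#   python app.py
# then open the local URL that Gradio prints.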