import copy
import os
import tempfile
from datetime import datetime

# Set before importing torch so the flag is already in the environment when torch loads.
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0,1,2,3,4,5,6,7"

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
from moviepy.editor import ImageSequenceClip
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor

title = "<center><strong><font size='8'>EdgeTAM</font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a></center>"

description_p = """# Instructions
<ol>
<li>Upload a video or click one of the examples</li>
<li>With the 'include' point type selected, click the object to segment and track</li>
<li>Optionally, switch to the 'exclude' point type and click areas to keep out of the mask</li>
<li>Click the 'Track' button to generate the masked video</li>
</ol>
"""

examples = [
    ["examples/01_dog.mp4"],
    ["examples/02_cups.mp4"],
    ["examples/03_blocks.mp4"],
    ["examples/04_coffee.mp4"],
    ["examples/05_default_juggle.mp4"],
    ["examples/01_breakdancer.mp4"],
    ["examples/02_hummingbird.mp4"],
    ["examples/03_skateboarder.mp4"],
    ["examples/04_octopus.mp4"],
    ["examples/05_landing_dog_soccer.mp4"],
    ["examples/06_pingpong.mp4"],
    ["examples/07_snowboarder.mp4"],
    ["examples/08_driving.mp4"],
    ["examples/09_birdcartoon.mp4"],
    ["examples/10_cloth_magic.mp4"],
    ["examples/11_polevault.mp4"],
    ["examples/12_hideandseek.mp4"],
    ["examples/13_butterfly.mp4"],
    ["examples/14_social_dog_training.mp4"],
    ["examples/15_cricket.mp4"],
    ["examples/16_robotarm.mp4"],
    ["examples/17_childrendancing.mp4"],
    ["examples/18_threedogs.mp4"],
    ["examples/19_cyclist.mp4"],
    ["examples/20_doughkneading.mp4"],
    ["examples/21_biker.mp4"],
    ["examples/22_dogskateboarder.mp4"],
    ["examples/23_racecar.mp4"],
    ["examples/24_clownfish.mp4"],
]

OBJ_ID = 0

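# Build the EdgeTAM video predictor on CPU, then move it to CUDA. The
# @spaces.GPU-decorated handlers below run with the GPU attached (the usual
# Hugging Face ZeroGPU pattern).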
sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
predictor.to("cuda")
print("predictor loaded")

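# Keep the whole process under bfloat16 autocast, and enable TF32 matmul and
# cuDNN kernels on Ampere or newer GPUs (compute capability >= 8).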
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


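# Probe the input video's frame rate so the output video can match it.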
def get_video_fps(video_path):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return None
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()  # release the capture handle once the FPS has been read
    return fps


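# Clear all prompts and cached frames, reset the predictor state, and restore
# the initial UI.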
def reset(session_state):
    session_state["input_points"] = []
    session_state["input_labels"] = []
    if session_state["inference_state"] is not None:
        predictor.reset_state(session_state["inference_state"])
    session_state["first_frame"] = None
    session_state["all_frames"] = None
    session_state["inference_state"] = None
    return (
        None,
        gr.update(open=True),
        None,
        None,
        gr.update(value=None, visible=False),
        session_state,
    )


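# Drop the point prompts (and any tracking already started) while keeping the
# loaded video.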
def clear_points(session_state):
    session_state["input_points"] = []
    session_state["input_labels"] = []
    # Guard against clicks on "Clear Points" before a video has been loaded.
    if (
        session_state["inference_state"] is not None
        and session_state["inference_state"]["tracking_has_started"]
    ):
        predictor.reset_state(session_state["inference_state"])
    return (
        session_state["first_frame"],
        None,
        gr.update(value=None, visible=False),
        session_state,
    )


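# Decode every frame of the uploaded video for later visualization and
# initialize the predictor's inference state on it.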
@spaces.GPU
def preprocess_video_in(video_path, session_state):
    if video_path is None:
        return (
            gr.update(open=True),
            None,
            None,
            gr.update(value=None, visible=False),
            session_state,
        )

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return (
            gr.update(open=True),
            None,
            None,
            gr.update(value=None, visible=False),
            session_state,
        )

    frame_number = 0
    first_frame = None
    all_frames = []

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # cvtColor already returns an ndarray, so no extra conversion is needed.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if frame_number == 0:
            first_frame = frame
        all_frames.append(frame)
        frame_number += 1

    cap.release()
    session_state["first_frame"] = copy.deepcopy(first_frame)
    session_state["all_frames"] = all_frames
    session_state["inference_state"] = predictor.init_state(video_path=video_path)
    session_state["input_points"] = []
    session_state["input_labels"] = []

    return [
        gr.update(open=False),
        first_frame,
        None,
        gr.update(value=None, visible=False),
        session_state,
    ]


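# Handle a click on the first frame: record the point prompt, draw the clicked
# points, and run the predictor on frame 0 to preview the resulting mask.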
@spaces.GPU
def segment_with_points(
    point_type,
    session_state,
    evt: gr.SelectData,
):
    session_state["input_points"].append(evt.index)
    print(f"TRACKING INPUT POINT: {session_state['input_points']}")

    if point_type == "include":
        session_state["input_labels"].append(1)
    elif point_type == "exclude":
        session_state["input_labels"].append(0)
    print(f"TRACKING INPUT LABEL: {session_state['input_labels']}")

    transparent_background = Image.fromarray(session_state["first_frame"]).convert(
        "RGBA"
    )
    w, h = transparent_background.size

    # Scale the point markers relative to the frame size.
    fraction = 0.01
    radius = int(fraction * min(w, h))

    transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
    for index, track in enumerate(session_state["input_points"]):
        if session_state["input_labels"][index] == 1:
            # include points in green
            cv2.circle(transparent_layer, tuple(track), radius, (0, 255, 0, 255), -1)
        else:
            # exclude points in red
            cv2.circle(transparent_layer, tuple(track), radius, (255, 0, 0, 255), -1)

    transparent_layer = Image.fromarray(transparent_layer, "RGBA")
    selected_point_map = Image.alpha_composite(
        transparent_background, transparent_layer
    )

    points = np.array(session_state["input_points"], dtype=np.float32)
    labels = np.array(session_state["input_labels"], np.int32)
    _, _, out_mask_logits = predictor.add_new_points(
        inference_state=session_state["inference_state"],
        frame_idx=0,
        obj_id=OBJ_ID,
        points=points,
        labels=labels,
    )

    mask_image = show_mask((out_mask_logits[0] > 0.0).cpu().numpy())
    first_frame_output = Image.alpha_composite(transparent_background, mask_image)

    torch.cuda.empty_cache()
    return selected_point_map, first_frame_output, session_state


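# Render a binary mask as a translucent RGBA overlay (tab10 color per object).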
def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    mask = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    mask = (mask * 255).astype(np.uint8)
    if convert_to_image:
        mask = Image.fromarray(mask, "RGBA")
    return mask


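# Propagate the first-frame prompts through the whole video and encode the
# masked frames into an MP4.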
@spaces.GPU
def propagate_to_all(
    video_in,
    session_state,
):
    if (
        len(session_state["input_points"]) == 0
        or video_in is None
        or session_state["inference_state"] is None
    ):
        return (
            None,
            session_state,
        )

    video_segments = {}
    print("starting propagate_in_video")
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
        session_state["inference_state"]
    ):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }

    vis_frame_stride = 1
    output_frames = []
    for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
        transparent_background = Image.fromarray(
            session_state["all_frames"][out_frame_idx]
        ).convert("RGBA")
        out_mask = video_segments[out_frame_idx][OBJ_ID]
        mask_image = show_mask(out_mask)
        output_frame = Image.alpha_composite(transparent_background, mask_image)
        output_frame = np.array(output_frame)
        output_frames.append(output_frame)

    torch.cuda.empty_cache()

    # Match the source frame rate; fall back to 30 fps if it cannot be read.
    original_fps = get_video_fps(video_in)
    fps = original_fps if original_fps else 30
    clip = ImageSequenceClip(output_frames, fps=fps)

    unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
    final_vid_output_path = f"output_video_{unique_id}.mp4"
    final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)
    clip.write_videofile(final_vid_output_path, codec="libx264")

    return (
        gr.update(value=final_vid_output_path),
        session_state,
    )


def update_ui():
    return gr.update(visible=True)


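# Gradio UI. Both .upload and .change on video_in are wired to
# preprocess_video_in: .upload fires for user uploads, while .change also
# fires when an example video is selected.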
with gr.Blocks() as demo:
    session_state = gr.State(
        {
            "first_frame": None,
            "all_frames": None,
            "input_points": [],
            "input_labels": [],
            "inference_state": None,
        }
    )

    with gr.Column():
        gr.Markdown(title)
        with gr.Row():
            with gr.Column():
                gr.Markdown(description_p)

                with gr.Accordion("Input Video", open=True) as video_in_drawer:
                    video_in = gr.Video(label="Input Video", format="mp4")

                with gr.Row():
                    point_type = gr.Radio(
                        label="point type",
                        choices=["include", "exclude"],
                        value="include",
                        scale=2,
                    )
                    propagate_btn = gr.Button("Track", scale=1, variant="primary")
                    clear_points_btn = gr.Button("Clear Points", scale=1)
                    reset_btn = gr.Button("Reset", scale=1)

                points_map = gr.Image(
                    label="Frame with Point Prompt", type="numpy", interactive=False
                )

            with gr.Column():
                gr.Markdown("# Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[
                        video_in,
                    ],
                    examples_per_page=8,
                )
                # Empty markdown blocks used purely as vertical spacers.
                gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
                gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
                gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
                output_image = gr.Image(label="Reference Mask")

    output_video = gr.Video(visible=False)

    video_in.upload(
        fn=preprocess_video_in,
        inputs=[
            video_in,
            session_state,
        ],
        outputs=[
            video_in_drawer,
            points_map,
            output_image,
            output_video,
            session_state,
        ],
        queue=False,
    )

    video_in.change(
        fn=preprocess_video_in,
        inputs=[
            video_in,
            session_state,
        ],
        outputs=[
            video_in_drawer,
            points_map,
            output_image,
            output_video,
            session_state,
        ],
        queue=False,
    )

    points_map.select(
        fn=segment_with_points,
        inputs=[
            point_type,
            session_state,
        ],
        outputs=[
            points_map,
            output_image,
            session_state,
        ],
        queue=False,
    )

    clear_points_btn.click(
        fn=clear_points,
        inputs=session_state,
        outputs=[
            points_map,
            output_image,
            output_video,
            session_state,
        ],
        queue=False,
    )

    reset_btn.click(
        fn=reset,
        inputs=session_state,
        outputs=[
            video_in,
            video_in_drawer,
            points_map,
            output_image,
            output_video,
            session_state,
        ],
        queue=False,
    )

    # Reveal the output video component first, then run tracking.
    propagate_btn.click(
        fn=update_ui,
        inputs=[],
        outputs=output_video,
        queue=False,
    ).then(
        fn=propagate_to_all,
        inputs=[
            video_in,
            session_state,
        ],
        outputs=[
            output_video,
            session_state,
        ],
        concurrency_limit=10,
        queue=False,
    )


demo.queue()
demo.launch()