File size: 4,403 Bytes
2aba93c
 
 
 
 
 
 
 
 
 
 
 
d6d3990
2aba93c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6d3990
 
 
 
 
2aba93c
d6d3990
2aba93c
 
d6d3990
2aba93c
 
 
 
d6d3990
 
 
 
 
 
 
2aba93c
 
 
 
 
d6d3990
2aba93c
 
 
 
 
 
 
d6d3990
2aba93c
 
 
 
 
 
 
 
 
d6d3990
2aba93c
d6d3990
2aba93c
d6d3990
2aba93c
 
 
 
d6d3990
2aba93c
 
d6d3990
 
 
2aba93c
 
 
d6d3990
 
2aba93c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6d3990
 
 
 
 
2aba93c
737c008
2aba93c
 
737c008
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import cv2
import imutils
import torch
import timm
import einops
import tqdm
import numpy as np
import gradio as gr

from cotracker.utils.visualizer import Visualizer


def parse_video(video_file):
    """Decode a video file into a stack of RGB frames.

    Parameters
    ----------
    video_file : str
        Path to a video readable by OpenCV.

    Returns
    -------
    np.ndarray
        Array of shape (T, H, W, 3), RGB channel order.

    Raises
    ------
    ValueError
        If no frames could be decoded from the file.
    """
    vs = cv2.VideoCapture(video_file)
    frames = []
    try:
        while True:
            (gotit, frame) = vs.read()
            if frame is not None:
                # OpenCV decodes BGR; downstream code expects RGB.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(frame)
            if not gotit:
                break
    finally:
        # Release the decoder handle even if decoding raises.
        vs.release()

    if not frames:
        # np.stack([]) would raise an opaque error; fail with context instead.
        raise ValueError(f"Could not read any frames from {video_file}")
    return np.stack(frames)


def cotracker_demo(
    input_video,
    grid_size: int = 10,
    grid_query_frame: int = 0,
    tracks_leave_trace: bool = False,
):
    """Track a regular grid of points through a video and render the result.

    Parameters
    ----------
    input_video : str
        Path to the uploaded video file (from the Gradio Video input).
    grid_size : int
        Points are sampled on a grid_size x grid_size grid.
    grid_query_frame : int
        Frame index at which tracking starts; clamped to the video length.
    tracks_leave_trace : bool
        If True, visualize the full history of each track.

    Returns
    -------
    str
        Path to the rendered .mp4 with predicted tracks overlaid.
    """
    load_video = parse_video(input_video)
    grid_query_frame = min(len(load_video) - 1, grid_query_frame)
    # (T, H, W, C) uint8 -> (1, T, C, H, W) float batch expected by the model.
    load_video = torch.from_numpy(load_video).permute(0, 3, 1, 2)[None].float()

    model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online")

    if torch.cuda.is_available():
        model = model.cuda()
        load_video = load_video.cuda()

    # Online model: the first call initializes the tracker on the grid,
    # subsequent calls consume overlapping chunks of 2 * model.step frames.
    model(video_chunk=load_video, is_first_step=True, grid_size=grid_size)
    for ind in range(0, load_video.shape[1] - model.step, model.step):
        pred_tracks, pred_visibility = model(
            video_chunk=load_video[:, ind : ind + model.step * 2]
        )  # B T N 2,  B T N 1

    # Thinner lines for denser grids so tracks stay distinguishable.
    linewidth = 2
    if grid_size < 10:
        linewidth = 4
    elif grid_size < 20:
        linewidth = 3

    vis = Visualizer(
        save_dir=os.path.join(os.path.dirname(__file__), "results"),
        grayscale=False,
        pad_value=100,
        fps=10,
        linewidth=linewidth,
        show_first_frame=5,
        tracks_leave_trace=-1 if tracks_leave_trace else 0,
    )
    import time

    # Millisecond timestamp makes each rendered file unique, so concurrent
    # requests do not overwrite each other's output.
    filename = str(round(time.time() * 1000))
    vis.visualize(
        load_video.cpu(),
        tracks=pred_tracks.cpu(),
        visibility=pred_visibility.cpu(),
        filename=f"{filename}_pred_track",
        query_frame=grid_query_frame,
    )
    return os.path.join(
        os.path.dirname(__file__), "results", f"{filename}_pred_track.mp4"
    )


# Bundled example clips shipped alongside the demo script.
_VIDEO_DIR = os.path.join(os.path.dirname(__file__), "videos")
apple = os.path.join(_VIDEO_DIR, "apple.mp4")
bear = os.path.join(_VIDEO_DIR, "bear.mp4")
paragliding_launch = os.path.join(_VIDEO_DIR, "paragliding-launch.mp4")
paragliding = os.path.join(_VIDEO_DIR, "paragliding.mp4")

# Gradio UI: four inputs (video, grid size, query frame, trace toggle) mapped
# positionally onto cotracker_demo's parameters.
app = gr.Interface(
    title="🎨 CoTracker: It is Better to Track Together",
    description="<div style='text-align: left;'> \
    <p>Welcome to <a href='http://co-tracker.github.io' target='_blank'>CoTracker</a>! This space demonstrates point (pixel) tracking in videos. \
    Points are sampled on a regular grid and are tracked jointly. </p> \
    <p> To get started, simply upload your <b>.mp4</b> video in landscape orientation or click on one of the example videos to load them. The shorter the video, the faster the processing. We recommend submitting short videos of length <b>2-7 seconds</b>.</p> \
    <ul style='display: inline-block; text-align: left;'> \
        <li>The total number of grid points is the square of <b>Grid Size</b>.</li> \
        <li>To specify the starting frame for tracking, adjust <b>Grid Query Frame</b>. Tracks will be visualized only after the selected frame.</li> \
        <li>Check <b>Visualize Track Traces</b> to visualize traces of all the tracked points. </li> \
    </ul> \
    <p style='text-align: left'>For more details, check out our <a href='https://github.com/facebookresearch/co-tracker' target='_blank'>GitHub Repo</a> ⭐</p> \
    </div>",
    fn=cotracker_demo,
    inputs=[
        gr.Video(type="file", label="Input video", interactive=True),
        gr.Slider(minimum=1, maximum=30, step=1, value=10, label="Grid Size"),
        # value= (not the legacy default=) for consistency with the slider above.
        gr.Slider(minimum=0, maximum=30, step=1, value=0, label="Grid Query Frame"),
        gr.Checkbox(label="Visualize Track Traces"),
    ],
    outputs=gr.Video(label="Video with predicted tracks"),
    # Each example supplies exactly one value per input component (4), matching
    # [video, grid_size, grid_query_frame, tracks_leave_trace].
    examples=[
        [apple, 10, 0, False],
        [apple, 20, 30, True],
        [bear, 10, 0, False],
        [paragliding, 10, 0, False],
        [paragliding_launch, 10, 0, False],
    ],
    cache_examples=True,
    allow_flagging=False,
)
app.queue(max_size=20, concurrency_count=2).launch(debug=True)