Spaces:
Runtime error
Runtime error
import sys | |
sys.path.insert(0, './') | |
import decord | |
import numpy as np | |
import torch | |
import os | |
from lavila.data.video_transforms import Permute | |
from lavila.data.datasets import get_frame_ids, video_loader_by_frames | |
from lavila.models.models import VCLM_OPENAI_TIMESFORMER_BASE_GPT2 | |
from lavila.models.tokenizer import MyGPT2Tokenizer | |
from collections import OrderedDict | |
import torch | |
import torchvision.transforms as transforms | |
import torchvision.transforms._transforms_video as transforms_video | |
import gradio as gr | |
def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
    """Pick one frame index from each of ``num_segments`` equal segments.

    The range ``[start_frame, end_frame - 1]`` is cut into ``num_segments``
    segments of equal (fractional) length, and one index is chosen per
    segment.

    Args:
        start_frame: Index of the first frame of the range.
        end_frame: End of the range; segment ends are clamped to this value.
        num_segments: Number of indices to return.
        jitter: If true, draw each index uniformly at random (via
            ``np.random.randint``) from its segment, both ends included.
            Otherwise take the integer midpoint of the segment.

    Returns:
        A list of ``num_segments`` frame indices, in segment order.
    """
    seg_size = float(end_frame - start_frame - 1) / num_segments
    # Segment i spans edges[i] .. edges[i + 1]; adjacent segments share an edge.
    edges = [
        int(np.round(seg_size * k) + start_frame)
        for k in range(num_segments + 1)
    ]
    picked = []
    for lower, upper in zip(edges, edges[1:]):
        upper = min(upper, end_frame)
        if not jitter:
            picked.append((lower + upper) // 2)
        else:
            # randint's `high` is exclusive, hence the + 1.
            picked.append(np.random.randint(low=lower, high=upper + 1))
    return picked
def video_loader_by_frames(root, vid, frame_ids):
    """Decode the requested frames of a video into one float32 tensor.

    Args:
        root: Directory that is joined with ``vid`` to form the video path.
        vid: Video file name/path relative to ``root``.
        frame_ids: Frame indices to decode.

    Returns:
        The decoded frames stacked along a new leading dimension. If decoding
        raises ``IndexError`` or ``decord.DECORDError``, the error is printed
        and an all-zero tensor of shape ``(len(frame_ids), 240, 320, 3)`` is
        returned instead.
    """
    reader = decord.VideoReader(os.path.join(root, vid))
    try:
        batch = reader.get_batch(frame_ids).asnumpy()
        per_frame = [torch.tensor(image, dtype=torch.float32) for image in batch]
    except (IndexError, decord.DECORDError) as error:
        # Best-effort fallback: report the problem and substitute blank frames
        # so callers still get a tensor with one entry per requested frame.
        print(error)
        print("Erroneous video: ", vid)
        per_frame = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))]
    return torch.stack(per_frame, dim=0)
def iter_clips(video_path, num_segments=4, stride_size=16):
    """Yield consecutive, non-overlapping clips sampled from a video.

    The video is walked in windows of ``num_segments * stride_size`` frames.
    Each window is represented by ``num_segments`` frames chosen by
    ``get_frame_ids(..., jitter=False)``.

    The first window is always emitted, even when the video is shorter than
    one window (it is then truncated to the video length). Later windows are
    emitted only when they fit completely inside the video, so a trailing
    partial window is dropped.

    Args:
        video_path: Path of the video file to read.
        num_segments: Number of frames sampled per clip.
        stride_size: Number of source frames covered by each sampled frame.

    Yields:
        Tuples ``(start_sec, stop_sec, frames)`` where the two times are the
        window boundaries in seconds (frame index divided by the average FPS)
        and ``frames`` is the tensor returned by ``video_loader_by_frames``.
    """
    vr = decord.VideoReader(video_path)
    total_frames = len(vr)
    fps = vr.get_avg_fps()
    frame_sample_size = num_segments * stride_size
    # Largest start index at which a full window still fits in the video.
    max_start_frame = total_frames - frame_sample_size
    curr_frame = 0
    # `<=` (not `<`): a window starting exactly at `max_start_frame` ends on
    # the last frame and is a complete clip, so it must not be skipped.
    while curr_frame == 0 or curr_frame <= max_start_frame:
        stop_frame = min(curr_frame + frame_sample_size, total_frames)
        curr_sec, stop_sec = curr_frame / fps, stop_frame / fps
        frame_ids = get_frame_ids(
            curr_frame, stop_frame, num_segments=num_segments, jitter=False
        )
        frames = video_loader_by_frames('./', video_path, frame_ids)
        yield curr_sec, stop_sec, frames
        curr_frame += frame_sample_size
class Pipeline:
    """Callable wrapper that narrates a video with the LaViLa VCLM model.

    Construction loads the checkpoint, builds the model and tokenizer and
    prepares the frame preprocessing transform. Calling the instance with a
    video path returns the generated captions as one text string.
    """

    def __init__(self, path=""):
        """Load model weights and set up tokenizer and preprocessing.

        Args:
            path: Directory that contains the checkpoint file.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        ckpt_path = os.path.join(path, 'vclm_openai_timesformer_base_gpt2_base.pt_ego4d.jobid_319630.ep_0002.md5sum_68a71f.pth')
        # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        # Remove every 'module.' occurrence from the parameter names.
        # presumably a DataParallel/DDP prefix — TODO confirm
        weights = OrderedDict(
            (name.replace('module.', ''), tensor)
            for name, tensor in checkpoint['state_dict'].items()
        )

        self.model = VCLM_OPENAI_TIMESFORMER_BASE_GPT2(
            text_use_cls_token=False,
            project_embed_dim=256,
            gated_xattn=True,
            timesformer_gated_xattn=False,
            freeze_lm_vclm=False,
            freeze_visual_vclm=False,
            freeze_visual_vclm_temporal=False,
            num_frames=4,
            drop_path_rate=0.
        )
        self.model.load_state_dict(weights, strict=True)
        self.model.to(self.device)
        self.model.eval()

        self.tokenizer = MyGPT2Tokenizer('gpt2', add_bos=True)

        crop_size = 224
        # NOTE(review): mean/std magnitudes suggest a 0-255 pixel scale, and
        # Permute presumably moves channels first — confirm against lavila.
        preprocessing_steps = [
            Permute([3, 0, 1, 2]),
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305]),
        ]
        self.val_transform = transforms.Compose(preprocessing_steps)

    def decode_one(self, generated_ids, tokenizer):
        """Turn one generated id sequence into text.

        The first id is always skipped. Decoding stops before the first EOS
        id; when BOS and EOS share an id, the search for EOS starts at
        position 1 so the leading token is not mistaken for the end. If no
        EOS is found, the last id is left out as well.

        Args:
            generated_ids: 1-D sequence of token ids exposing ``tolist()``.
            tokenizer: Object with ``eos_token_id``, ``bos_token_id`` and a
                ``tokenizer.decode(list_of_ids)`` method.

        Returns:
            The decoded string.
        """
        ids = generated_ids.tolist()
        eos = tokenizer.eos_token_id
        search_start = 1 if eos == tokenizer.bos_token_id else 0
        try:
            eos_id = ids.index(eos, search_start)
        except ValueError:
            # No EOS present: fall back to the index of the final token.
            eos_id = len(ids) - 1
        return tokenizer.tokenizer.decode(ids[1:eos_id])

    def __call__(self, video_path, temperature=0.7, top_p=0.95, max_text_length=77, num_return_sequences=10):
        """Generate candidate narrations for the clips of a video.

        Args:
            video_path: Path of the video to narrate.
            temperature: Sampling temperature forwarded to ``model.generate``.
            top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
            max_text_length: Maximum generated length forwarded to
                ``model.generate``.
            num_return_sequences: Number of candidates generated per clip.

        Returns:
            A string with, for each processed clip, a header line carrying the
            clip's time range followed by one line per candidate. At most
            ``MAX_ITERATIONS`` clips are processed. Every piece is also
            printed as it is produced.
        """
        MAX_ITERATIONS = 5
        rule = '-' * 30
        pieces = []
        with torch.autocast(self.device):
            for clip_idx, (start, stop, frames) in enumerate(iter_clips(video_path)):
                header = f"{rule} Predictions From: {start:2.3f}-{stop:2.3f} seconds {rule}\n"
                print(header)
                pieces.append(header)

                # Preprocess and add a leading batch dimension of size 1.
                frames = self.val_transform(frames).unsqueeze(0)
                if self.device == 'cuda':
                    frames = frames.to(self.device).half()

                with torch.no_grad():
                    image_features = self.model.encode_image(frames)
                    generated_text_ids, ppls = self.model.generate(
                        image_features,
                        self.tokenizer,
                        target=None,  # free-form generation
                        max_text_length=max_text_length,
                        top_k=None,
                        top_p=top_p,  # nucleus sampling
                        num_return_sequences=num_return_sequences,  # number of candidates: 10
                        temperature=temperature,
                        early_stopping=True,
                    )

                for i in range(num_return_sequences):
                    caption = self.decode_one(generated_text_ids[i], self.tokenizer)
                    line = f'\t{i}: {caption}\n'
                    print(line)
                    pieces.append(line)

                # Cap the amount of work done per request.
                if clip_idx + 1 >= MAX_ITERATIONS:
                    break
        return ''.join(pieces)
# ---- Gradio UI wiring -------------------------------------------------------
# Text shown on the demo page: `description` is rendered above the inputs and
# `article` (raw HTML) below them.
title = "LaViLa"
description = """LaViLa (**L**anguage **a**ugmented **Vi**deo **La**nguage Pretraining) is a new approach to learning video representations from Large Language Models (LLMs). We repurpose LLMs to be visually conditioned "Narrators", and use them to automatically generate video-language paired data. We use this data to then learn a video-language representation, outperforming prior work by large margins. \nGradio Demo for LaViLa. To use it, simply upload your video, or click one of the examples to load them. Read more at the links below."""
article = "<p style='text-align: center'><a href='https://github.com/facebookresearch/LaViLa' target='_blank'>Github Repo</a> | <a href='https://arxiv.org/abs/2212.04501' target='_blank'>Paper on arxiv</a></p><center><img src='https://visitor-badge.glitch.me/badge?page_id=nateraw_lavila' alt='visitor badge'></center>"

# The three inputs map positionally onto Pipeline.__call__(video_path,
# temperature, top_p); its remaining parameters keep their defaults.
# NOTE(review): the temperature slider allows 0.0 — confirm that
# `model.generate` tolerates a zero temperature.
# `interface` is bound to whatever `launch()` returns, not to the
# `gr.Interface` object itself.
interface = gr.Interface(
    Pipeline(),
    inputs=[
        gr.Video(label='video_path'),
        gr.Slider(0.0, 1.0, 0.7, label='temperature'),
        gr.Slider(0.0, 1.0, 0.95, label='top_p'),
    ],
    outputs='text',
    examples=[['eating_spaghetti.mp4', 0.7, 0.95], ['assets/3c0dffd0-e38e-4643-bc48-d513943dc20b_012_014.mp4', 0.7, 0.95]],
    title=title,
    description=description,
    article=article,
).launch(debug=True)