import os
import time
from omegaconf import OmegaConf
import torch
from scripts.evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling
from utils.utils import instantiate_from_config
from huggingface_hub import hf_hub_download

class Text2Video():
    def __init__(self, result_dir='./tmp/', gpu_num=1) -> None:
        # fetch the VideoCrafter2 checkpoint if it is not present locally
        self.download_model()
        self.result_dir = result_dir
        if not os.path.exists(self.result_dir):
            os.mkdir(self.result_dir)
        ckpt_path = 'checkpoints/base_512_v2/model.ckpt'
        config_file = 'configs/inference_t2v_512_v2.0.yaml'
        config = OmegaConf.load(config_file)
        model_config = config.pop("model", OmegaConf.create())
        # disable gradient checkpointing; it is only needed for training
        model_config['params']['unet_config']['params']['use_checkpoint'] = False
        model_list = []
        for gpu_id in range(gpu_num):
            model = instantiate_from_config(model_config)
            # model = model.cuda(gpu_id)
            assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
            model = load_model_checkpoint(model, ckpt_path)
            model.eval()
            model_list.append(model)
        # one model instance per requested GPU; weights stay on CPU until sampling
        self.model_list = model_list
        self.save_fps = 8

    def get_prompt(self, prompt, steps=50, cfg_scale=12.0, eta=1.0, fps=16):
        torch.cuda.empty_cache()
        print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
        start = time.time()
        gpu_id = 0
        # cap the number of DDIM steps
        if steps > 60:
            steps = 60
        model = self.model_list[gpu_id]
        model = model.cuda()
        batch_size = 1
        channels = model.model.diffusion_model.in_channels
        frames = model.temporal_length
        # latent spatial size for a 320x512 video (downsampling factor of 8)
        h, w = 320 // 8, 512 // 8
        noise_shape = [batch_size, channels, frames, h, w]
        # prompts = batch_size * [""]
        prompt = prompt + ", professional, 4k, highly detailed"
        neg_prompts = "graphite, impressionist, noisy, blurry, soft, deformed, ugly"
        text_emb = model.get_learned_conditioning([prompt])
        cond = {"c_crossattn": [text_emb], "fps": fps}
        ## inference
        batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps,
                                            ddim_eta=eta, cfg_scale=cfg_scale, neg_prompts=neg_prompts)
        ## b,samples,c,t,h,w
        # build a filesystem-safe filename from the prompt
        prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
        prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
        prompt_str = prompt_str[:30]
        save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps)
        print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
        # move weights back to CPU to free GPU memory between requests
        model = model.cpu()
        return os.path.join(self.result_dir, f"{prompt_str}.mp4")

    def download_model(self):
        REPO_ID = 'VideoCrafter/VideoCrafter2'
        filename_list = ['model.ckpt']
        if not os.path.exists('./checkpoints/base_512_v2/'):
            os.makedirs('./checkpoints/base_512_v2/')
        for filename in filename_list:
            local_file = os.path.join('./checkpoints/base_512_v2/', filename)
            if not os.path.exists(local_file):
                # download the checkpoint from the Hugging Face Hub into the local checkpoints folder
                hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/base_512_v2/',
                                local_dir_use_symlinks=False)

if __name__ == '__main__':
    t2v = Text2Video()
    video_path = t2v.get_prompt('a black swan swims on the pond')
    print('done', video_path)
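
# Usage sketch (hypothetical prompt and parameter values; assumes the checkpoint and
# config paths above are in place). It mirrors the __main__ block, but with non-default
# sampling settings passed to get_prompt:
#
#   t2v = Text2Video(result_dir='./results/')
#   video_path = t2v.get_prompt(
#       'a corgi running on the beach',
#       steps=60,         # clamped to 60 inside get_prompt
#       cfg_scale=15.0,   # classifier-free guidance scale
#       eta=1.0,          # DDIM eta
#       fps=24,           # fps conditioning passed to the model
#   )
#   print(video_path)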