Spaces:
Build error
Build error
Haoxin Chen
commited on
Commit
·
15190a9
1
Parent(s):
a8f3a29
fix ckpt path
Browse files- i2v_test.py +1 -1
- t2v_test.py +8 -8
i2v_test.py
CHANGED
@@ -68,7 +68,7 @@ class Image2Video():
|
|
68 |
return os.path.join(self.result_dir, f"{prompt_str}.mp4")
|
69 |
|
70 |
def download_model(self):
|
71 |
-
REPO_ID = 'VideoCrafter/Image2Video-512
|
72 |
filename_list = ['model.ckpt']
|
73 |
if not os.path.exists('./checkpoints/i2v_512_v1/'):
|
74 |
os.makedirs('./checkpoints/i2v_512_v1/')
|
|
|
68 |
return os.path.join(self.result_dir, f"{prompt_str}.mp4")
|
69 |
|
70 |
def download_model(self):
|
71 |
+
REPO_ID = 'VideoCrafter/Image2Video-512'
|
72 |
filename_list = ['model.ckpt']
|
73 |
if not os.path.exists('./checkpoints/i2v_512_v1/'):
|
74 |
os.makedirs('./checkpoints/i2v_512_v1/')
|
t2v_test.py
CHANGED
@@ -12,8 +12,8 @@ class Text2Video():
|
|
12 |
self.result_dir = result_dir
|
13 |
if not os.path.exists(self.result_dir):
|
14 |
os.mkdir(self.result_dir)
|
15 |
-
ckpt_path='checkpoints/
|
16 |
-
config_file='configs/
|
17 |
config = OmegaConf.load(config_file)
|
18 |
model_config = config.pop("model", OmegaConf.create())
|
19 |
model_config['params']['unet_config']['params']['use_checkpoint']=False
|
@@ -39,7 +39,7 @@ class Text2Video():
|
|
39 |
batch_size=1
|
40 |
channels = model.model.diffusion_model.in_channels
|
41 |
frames = model.temporal_length
|
42 |
-
h, w =
|
43 |
noise_shape = [batch_size, channels, frames, h, w]
|
44 |
|
45 |
#prompts = batch_size * [""]
|
@@ -59,15 +59,15 @@ class Text2Video():
|
|
59 |
return os.path.join(self.result_dir, f"{prompt_str}.mp4")
|
60 |
|
61 |
def download_model(self):
|
62 |
-
REPO_ID = 'VideoCrafter/Text2Video-
|
63 |
filename_list = ['model.ckpt']
|
64 |
-
if not os.path.exists('./checkpoints/
|
65 |
-
os.makedirs('./checkpoints/
|
66 |
for filename in filename_list:
|
67 |
-
local_file = os.path.join('./checkpoints/
|
68 |
|
69 |
if not os.path.exists(local_file):
|
70 |
-
hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/
|
71 |
|
72 |
|
73 |
if __name__ == '__main__':
|
|
|
12 |
self.result_dir = result_dir
|
13 |
if not os.path.exists(self.result_dir):
|
14 |
os.mkdir(self.result_dir)
|
15 |
+
ckpt_path='checkpoints/base_1024_v1/model.ckpt'
|
16 |
+
config_file='configs/inference_t2v_1024_v1.0.yaml'
|
17 |
config = OmegaConf.load(config_file)
|
18 |
model_config = config.pop("model", OmegaConf.create())
|
19 |
model_config['params']['unet_config']['params']['use_checkpoint']=False
|
|
|
39 |
batch_size=1
|
40 |
channels = model.model.diffusion_model.in_channels
|
41 |
frames = model.temporal_length
|
42 |
+
h, w = 576 // 8, 1024 // 8
|
43 |
noise_shape = [batch_size, channels, frames, h, w]
|
44 |
|
45 |
#prompts = batch_size * [""]
|
|
|
59 |
return os.path.join(self.result_dir, f"{prompt_str}.mp4")
|
60 |
|
61 |
def download_model(self):
|
62 |
+
REPO_ID = 'VideoCrafter/Text2Video-1024'
|
63 |
filename_list = ['model.ckpt']
|
64 |
+
if not os.path.exists('./checkpoints/base_1024_v1/'):
|
65 |
+
os.makedirs('./checkpoints/base_1024_v1/')
|
66 |
for filename in filename_list:
|
67 |
+
local_file = os.path.join('./checkpoints/base_1024_v1/', filename)
|
68 |
|
69 |
if not os.path.exists(local_file):
|
70 |
+
hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/base_1024_v1/', local_dir_use_symlinks=False)
|
71 |
|
72 |
|
73 |
if __name__ == '__main__':
|