LeoXing1996 committed
Commit a001281
0 Parent(s):

init repo for fg

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitignore +44 -0
  2. README.md +10 -0
  3. __assets__/image_animation/magnitude/1.mp4 +0 -0
  4. __assets__/image_animation/magnitude/2.mp4 +0 -0
  5. __assets__/image_animation/magnitude/3.mp4 +0 -0
  6. __assets__/image_animation/magnitude/genshin/3.mp4 +0 -0
  7. __assets__/image_animation/rcnz/1.mp4 +0 -0
  8. __assets__/image_animation/rcnz/2.mp4 +0 -0
  9. __assets__/image_animation/rcnz/3.mp4 +0 -0
  10. __assets__/image_animation/real/1.mp4 +0 -0
  11. __assets__/image_animation/real/2.mp4 +0 -0
  12. __assets__/image_animation/real/3.mp4 +0 -0
  13. __assets__/image_animation/style_transfer/anya/2.mp4 +0 -0
  14. __assets__/image_animation/yanhong/yanhong.mp4 +0 -0
  15. __assets__/image_animation/yanhong/yanhong.png +0 -0
  16. __assets__/image_animation/yiming/yiming.jpeg +0 -0
  17. __assets__/image_animation/yiming/yiming.mp4 +0 -0
  18. animatediff/data/dataset.py +128 -0
  19. animatediff/data/dataset_web.py +205 -0
  20. animatediff/data/video_transformer.py +407 -0
  21. animatediff/models/__init__.py +0 -0
  22. animatediff/models/attention.py +559 -0
  23. animatediff/models/motion_module.py +555 -0
  24. animatediff/models/resnet.py +197 -0
  25. animatediff/models/unet.py +572 -0
  26. animatediff/models/unet_blocks.py +733 -0
  27. animatediff/pipelines/__init__.py +5 -0
  28. animatediff/pipelines/i2v_pipeline.py +775 -0
  29. animatediff/pipelines/pipeline_animation.py +446 -0
  30. animatediff/pipelines/validation_pipeline.py +504 -0
  31. animatediff/utils/convert_from_ckpt.py +964 -0
  32. animatediff/utils/convert_lora_safetensor_to_diffusers.py +208 -0
  33. animatediff/utils/util.py +255 -0
  34. app-counterfeit-only.py +441 -0
  35. app-huggingface.py +525 -0
  36. app.py +567 -0
  37. benchmark.py +47 -0
  38. configs/indomain/base.yaml +14 -0
  39. configs/indomain/real.yaml +45 -0
  40. configs/inference/inference.yaml +26 -0
  41. configs/prompts/1-ToonYou.yaml +22 -0
  42. configs/prompts/1.yaml +20 -0
  43. configs/prompts/2-Lyriel.yaml +22 -0
  44. configs/prompts/3-RcnzCartoon.yaml +22 -0
  45. configs/prompts/4-MajicMix.yaml +22 -0
  46. configs/prompts/5-RealisticVision.yaml +22 -0
  47. configs/prompts/6-Tusun.yaml +20 -0
  48. configs/prompts/7-FilmVelvia.yaml +23 -0
  49. configs/prompts/8-GhibliBackground.yaml +20 -0
  50. configs/training/image_finetune.yaml +48 -0
.gitignore ADDED
@@ -0,0 +1,44 @@
+ *.pkl
+ *.pt
+ *.mov
+ *.pth
+ *.json
+ *.mov
+ *.npz
+ *.npy
+ *.boj
+ *.onnx
+ *.tar
+ *.bin
+ cache*
+ batch*
+ *.jpg
+ *.png
+ *.mp4
+ *.gif
+ *.ckpt
+ *.safetensors
+ *.zip
+ *.csv
+
+ **/__pycache__/
+ samples/
+ wandb/
+ outputs/
+
+ !pia.png
+ .DS_Store
+ !__assets__/image_animation/magnitude/1.mp4
+ !__assets__/image_animation/magnitude/2.mp4
+ !__assets__/image_animation/magnitude/3.mp4
+ !__assets__/image_animation/style_transfer/anya/2.mp4
+ !__assets__/image_animation/magnitude/genshin/3.mp4
+ !__assets__/image_animation/rcnz/1.mp4
+ !__assets__/image_animation/rcnz/2.mp4
+ !__assets__/image_animation/rcnz/3.mp4
+ !__assets__/image_animation/real/1.mp4
+ !__assets__/image_animation/real/2.mp4
+ !__assets__/image_animation/real/3.mp4
+ !__assets__/image_animation/yiming/yiming.mp4
+ !__assets__/image_animation/yanhong/yanhong.mp4
+ !__assets__/image_animation/yanhong/yanhong.png
README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ title: Demo Space
+ emoji: 🤗
+ colorFrom: yellow
+ colorTo: orange
+ sdk: gradio
+ sdk_version: 4.7.1
+ app_file: app-huggingface.py
+ pinned: false
+ ---
__assets__/image_animation/magnitude/1.mp4 ADDED
Binary file (217 kB).
__assets__/image_animation/magnitude/2.mp4 ADDED
Binary file (201 kB).
__assets__/image_animation/magnitude/3.mp4 ADDED
Binary file (230 kB).
__assets__/image_animation/magnitude/genshin/3.mp4 ADDED
Binary file (247 kB).
__assets__/image_animation/rcnz/1.mp4 ADDED
Binary file (187 kB).
__assets__/image_animation/rcnz/2.mp4 ADDED
Binary file (241 kB).
__assets__/image_animation/rcnz/3.mp4 ADDED
Binary file (182 kB).
__assets__/image_animation/real/1.mp4 ADDED
Binary file (194 kB).
__assets__/image_animation/real/2.mp4 ADDED
Binary file (172 kB).
__assets__/image_animation/real/3.mp4 ADDED
Binary file (355 kB).
__assets__/image_animation/style_transfer/anya/2.mp4 ADDED
Binary file (106 kB).
__assets__/image_animation/yanhong/yanhong.mp4 ADDED
Binary file (221 kB).
__assets__/image_animation/yanhong/yanhong.png ADDED
__assets__/image_animation/yiming/yiming.jpeg ADDED
__assets__/image_animation/yiming/yiming.mp4 ADDED
Binary file (97.7 kB).
animatediff/data/dataset.py ADDED
@@ -0,0 +1,128 @@
+ import os, io, csv, math, random
+ import numpy as np
+ from einops import rearrange
+ from decord import VideoReader
+
+ import torch
+ import torchvision.transforms as transforms
+ from torch.utils.data.dataset import Dataset
+ from animatediff.utils.util import zero_rank_print, detect_edges
+ import cv2
+
+ def get_score(video_data,
+               cond_frame_idx,
+               weight=[1.0, 1.0, 1.0, 1.0],
+               use_edge=True):
+     """Compute per-frame HSV (+ edge) difference scores against the condition frame.
+
+     Similar to get_score / detect_edges under utils/util.py.
+     video_data is an np.ndarray of shape (f, h, w, c).
+     """
+     h, w = video_data.shape[1], video_data.shape[2]
+
+     cond_frame = video_data[cond_frame_idx]
+     cond_hsv_list = list(
+         cv2.split(
+             cv2.cvtColor(cond_frame.astype(np.float32), cv2.COLOR_RGB2HSV)))
+
+     if use_edge:
+         cond_frame_lum = cond_hsv_list[-1]
+         cond_frame_edge = detect_edges(cond_frame_lum.astype(np.uint8))
+         cond_hsv_list.append(cond_frame_edge)
+
+     score_sum = []
+
+     for frame_idx in range(video_data.shape[0]):
+         frame = video_data[frame_idx]
+         hsv_list = list(
+             cv2.split(cv2.cvtColor(frame.astype(np.float32),
+                                    cv2.COLOR_RGB2HSV)))
+
+         if use_edge:
+             frame_img_lum = hsv_list[-1]
+             frame_img_edge = detect_edges(lum=frame_img_lum.astype(np.uint8))
+             hsv_list.append(frame_img_edge)
+
+         hsv_diff = [
+             np.abs(hsv_list[c] - cond_hsv_list[c]) for c in range(len(weight))
+         ]
+         hsv_mse = [np.sum(hsv_diff[c]) * weight[c] for c in range(len(weight))]
+         score_sum.append(sum(hsv_mse) / (h * w) / (sum(weight)))
+
+     return score_sum
+
+ class WebVid10M(Dataset):
+     def __init__(
+         self,
+         csv_path, video_folder,
+         sample_size=256, sample_stride=4, sample_n_frames=16,
+         is_image=False,
+     ):
+         zero_rank_print(f"loading annotations from {csv_path} ...")
+         with open(csv_path, 'r') as csvfile:
+             self.dataset = list(csv.DictReader(csvfile))
+         self.length = len(self.dataset)
+         zero_rank_print(f"data scale: {self.length}")
+
+         self.video_folder = video_folder
+         self.sample_stride = sample_stride
+         self.sample_n_frames = sample_n_frames
+         self.is_image = is_image
+
+         sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
+         self.pixel_transforms = transforms.Compose([
+             transforms.RandomHorizontalFlip(),
+             transforms.Resize(sample_size[0]),
+             transforms.CenterCrop(sample_size),
+             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+         ])
+
+     def get_batch(self, idx):
+         video_dict = self.dataset[idx]
+         videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
+
+         video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
+         video_reader = VideoReader(video_dir)
+         video_length = len(video_reader)
+         total_frames = len(video_reader)
+         clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
+         start_idx = random.randint(0, video_length - clip_length)
+         # NOTE: batch_index is computed but not used below; only a single random frame is decoded.
+         batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
+
+         frame_indice = [random.randint(0, total_frames - 1)]
+         pixel_values_np = video_reader.get_batch(frame_indice).asnumpy()
+         cond_frames = random.randint(0, self.sample_n_frames - 1)
+
+         # f h w c -> f c h w
+         pixel_values = torch.from_numpy(pixel_values_np).permute(0, 3, 1, 2).contiguous()
+         pixel_values = pixel_values / 255.
+         del video_reader
+
+         if self.is_image:
+             pixel_values = pixel_values[0]
+
+         return pixel_values, name, cond_frames, videoid
+
+     def __len__(self):
+         return self.length
+
+     def __getitem__(self, idx):
+         while True:
+             try:
+                 video, name, cond_frames, videoid = self.get_batch(idx)
+                 break
+             except Exception as e:
+                 # zero_rank_print(e)
+                 idx = random.randint(0, self.length - 1)
+
+         video = self.pixel_transforms(video)
+         video_ = video.clone().permute(0, 2, 3, 1).numpy() / 2 + 0.5
+         video_ = video_ * 255
+         # video_ = video_.astype(np.uint8)
+         score = get_score(video_, cond_frame_idx=cond_frames)
+         del video_
+         sample = dict(pixel_values=video, text=name, score=score, cond_frames=cond_frames, vid=videoid)
+         return sample
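Not part of the commit: for readers skimming the diff, the snippet below is a minimal, self-contained sketch of the per-frame difference score that get_score above computes, with the optional edge channel disabled (detect_edges lives in animatediff/utils/util.py and is not shown in this view). Frame layout and normalization follow the code above; the random clip and the helper name hsv_score are only illustrative stand-ins.

# Sketch of the HSV-difference scoring used by get_score above, edge channel disabled.
import numpy as np
import cv2

def hsv_score(video, cond_idx, weight=(1.0, 1.0, 1.0)):
    # video: (f, h, w, c) RGB frames in [0, 255]; returns one score per frame.
    f, h, w, _ = video.shape
    cond_hsv = cv2.split(cv2.cvtColor(video[cond_idx].astype(np.float32), cv2.COLOR_RGB2HSV))
    scores = []
    for i in range(f):
        hsv = cv2.split(cv2.cvtColor(video[i].astype(np.float32), cv2.COLOR_RGB2HSV))
        diff = [np.abs(hsv[c] - cond_hsv[c]).sum() * weight[c] for c in range(len(weight))]
        scores.append(sum(diff) / (h * w) / sum(weight))
    return scores

clip = (np.random.rand(16, 64, 64, 3) * 255).astype(np.float32)  # fake 16-frame clip
print(hsv_score(clip, cond_idx=0))  # the condition frame scores ~0 against itself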
animatediff/data/dataset_web.py ADDED
@@ -0,0 +1,205 @@
+ import decord
+ import cv2
+
+ import os, io, csv, torch, math, random
+ from typing import Optional
+ from einops import rearrange
+ import numpy as np
+ from decord import VideoReader
+ from petrel_client.client import Client
+ from torch.utils.data.dataset import Dataset
+ import torchvision.transforms as transforms
+ from torch.utils.data.distributed import DistributedSampler
+
+ import animatediff.data.video_transformer as video_transforms
+ from animatediff.utils.util import zero_rank_print, detect_edges, prepare_mask_coef_by_score
+
+
+ def get_score(video_data,
+               cond_frame_idx,
+               weight=[1.0, 1.0, 1.0, 1.0],
+               use_edge=True):
+     """Compute per-frame HSV (+ edge) difference scores against the condition frame.
+
+     Similar to get_score / detect_edges under utils/util.py.
+     video_data is an np.ndarray of shape (f, h, w, c).
+     """
+     h, w = video_data.shape[1], video_data.shape[2]
+
+     cond_frame = video_data[cond_frame_idx]
+     cond_hsv_list = list(
+         cv2.split(
+             cv2.cvtColor(cond_frame.astype(np.float32), cv2.COLOR_RGB2HSV)))
+
+     if use_edge:
+         cond_frame_lum = cond_hsv_list[-1]
+         cond_frame_edge = detect_edges(cond_frame_lum.astype(np.uint8))
+         cond_hsv_list.append(cond_frame_edge)
+
+     score_sum = []
+
+     for frame_idx in range(video_data.shape[0]):
+         frame = video_data[frame_idx]
+         hsv_list = list(
+             cv2.split(cv2.cvtColor(frame.astype(np.float32),
+                                    cv2.COLOR_RGB2HSV)))
+
+         if use_edge:
+             frame_img_lum = hsv_list[-1]
+             frame_img_edge = detect_edges(lum=frame_img_lum.astype(np.uint8))
+             hsv_list.append(frame_img_edge)
+
+         hsv_diff = [
+             np.abs(hsv_list[c] - cond_hsv_list[c]) for c in range(len(weight))
+         ]
+         hsv_mse = [np.sum(hsv_diff[c]) * weight[c] for c in range(len(weight))]
+         score_sum.append(sum(hsv_mse) / (h * w) / (sum(weight)))
+
+     return score_sum
+
+
+ class WebVid10M(Dataset):
+     def __init__(
+         self,
+         csv_path,
+         sample_n_frames, sample_stride,
+         sample_size=[320, 512],
+         conf_path="~/petreloss.conf",
+         static_video=False,
+         is_image=False,
+     ):
+         zero_rank_print("initializing ceph client ...")
+         self._client = Client(conf_path=conf_path, enable_mc=True)
+         self.sample_n_frames = sample_n_frames
+         self.sample_stride = sample_stride
+         self.temporal_sampler = video_transforms.TemporalRandomCrop(sample_n_frames * sample_stride)
+         self.static_video = static_video
+         self.is_image = is_image
+
+         zero_rank_print(f"(~1 mins) loading annotations from {csv_path} ...")
+         with open(csv_path, 'r') as csvfile:
+             self.dataset = list(csv.DictReader(csvfile))
+         self.length = len(self.dataset)
+         zero_rank_print(f"data scale: {self.length}")
+
+         sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
+         self.pixel_transforms = transforms.Compose([
+             transforms.RandomHorizontalFlip(),
+             transforms.Resize(sample_size[0]),
+             transforms.CenterCrop(sample_size),
+             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+         ])
+
+     def get_batch(self, idx):
+         video_dict = self.dataset[idx]
+         videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
+         ceph_dir = f"webvideo:s3://WebVid10M/{page_dir}/{videoid}.mp4"
+
+         video_bytes = self._client.Get(ceph_dir)
+         video_bytes = io.BytesIO(video_bytes)
+
+         # ensure not reading zero byte
+         assert video_bytes.getbuffer().nbytes != 0
+
+         video_reader = VideoReader(video_bytes)
+         total_frames = len(video_reader)
+
+         if not self.is_image:
+             if self.static_video:
+                 frame_indice = random.randint(0, total_frames - 1)
+                 frame_indice = np.linspace(frame_indice, frame_indice, self.sample_n_frames, dtype=int)
+             else:
+                 start_frame_ind, end_frame_ind = self.temporal_sampler(total_frames)
+                 assert end_frame_ind - start_frame_ind >= self.sample_n_frames
+                 frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.sample_n_frames, dtype=int)
+         else:
+             frame_indice = [random.randint(0, total_frames - 1)]
+
+         pixel_values_np = video_reader.get_batch(frame_indice).asnumpy()
+
+         cond_frames = random.randint(0, self.sample_n_frames - 1)
+
+         # f h w c -> f c h w
+         pixel_values = torch.from_numpy(pixel_values_np).permute(0, 3, 1, 2).contiguous()
+         pixel_values = pixel_values / 255.
+         del video_reader
+
+         if self.is_image:
+             pixel_values = pixel_values[0]
+
+         return pixel_values, name, cond_frames, videoid
+
+     def __len__(self):
+         return self.length
+
+     def __getitem__(self, idx):
+         while True:
+             try:
+                 video, name, cond_frames, videoid = self.get_batch(idx)
+                 break
+             except Exception as e:
+                 # zero_rank_print(e)
+                 idx = random.randint(0, self.length - 1)
+
+         video = self.pixel_transforms(video)
+         video_ = video.clone().permute(0, 2, 3, 1).numpy() / 2 + 0.5
+         video_ = video_ * 255
+         # video_ = video_.astype(np.uint8)
+         score = get_score(video_, cond_frame_idx=cond_frames)
+         del video_
+         sample = dict(pixel_values=video, text=name, score=score, cond_frames=cond_frames, vid=videoid)
+         return sample
+
+
+ if __name__ == "__main__":
+     dataset = WebVid10M(
+         csv_path="results_10M_train.csv",
+         sample_size=(320, 512),
+         sample_n_frames=16,
+         sample_stride=4,
+         static_video=False,
+         is_image=False,
+     )
+
+     distributed_sampler = DistributedSampler(
+         dataset,
+         num_replicas=1,
+         rank=0,
+         shuffle=True,
+         seed=5,
+     )
+     batch_size = 1
+     dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=0, sampler=distributed_sampler)
+
+     STATISTIC = [[0., 0.],
+                  [0.3535855, 24.23687346],
+                  [0.91609545, 30.65091947],
+                  [1.41165152, 34.40093286],
+                  [1.56943881, 36.99639585],
+                  [1.73182842, 39.42044163],
+                  [1.82733002, 40.94703526],
+                  [1.88060527, 42.66233244],
+                  [1.96208071, 43.73070788],
+                  [2.02723091, 44.25965378],
+                  [2.10820894, 45.66120213],
+                  [2.21115041, 46.29561324],
+                  [2.23412351, 47.08810863],
+                  [2.29430165, 47.9515062],
+                  [2.32986362, 48.69085638],
+                  [2.37310751, 49.19931439]]
+
+     for idx, batch in enumerate(dataloader):
+         pixel_values, texts, vid = batch['pixel_values'], batch['text'], batch['vid']
+         pixel_values = (pixel_values.clone()) / 2. + 0.5
+         pixel_values *= 255
+         # condition-frame index for every sample in the batch (here simply frame 0)
+         cond_frames = [0] * batch_size
+         score = prepare_mask_coef_by_score(pixel_values, cond_frames, statistic=STATISTIC)
+         print(f'num: {idx}, diff: {score}')
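Not part of the commit: the frame sampling in WebVid10M.get_batch above combines TemporalRandomCrop with np.linspace. Below is a standalone sketch of that index math under the same window/stride convention; the helper name sample_frame_indices is hypothetical.

# Sketch of the evenly spaced frame-index sampling used above.
import random
import numpy as np

def sample_frame_indices(total_frames, sample_n_frames=16, sample_stride=4):
    window = sample_n_frames * sample_stride          # temporal crop length
    rand_end = max(0, total_frames - window - 1)
    start = random.randint(0, rand_end)               # random window start
    end = min(start + window, total_frames)
    # sample_n_frames evenly spaced indices inside the window
    return np.linspace(start, end - 1, sample_n_frames, dtype=int)

print(sample_frame_indices(total_frames=300))  # e.g. 16 indices spanning a 64-frame window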
animatediff/data/video_transformer.py ADDED
@@ -0,0 +1,407 @@
1
+ import torch
2
+ import random
3
+ import numbers
4
+ from torchvision.transforms import RandomCrop, RandomResizedCrop
5
+
6
+ def _is_tensor_video_clip(clip):
7
+ if not torch.is_tensor(clip):
8
+ raise TypeError("clip should be Tensor. Got %s" % type(clip))
9
+
10
+ if not clip.ndimension() == 4:
11
+ raise ValueError("clip should be 4D. Got %dD" % clip.dim())
12
+
13
+ return True
14
+
15
+
16
+ def crop(clip, i, j, h, w):
17
+ """
18
+ Args:
19
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
20
+ """
21
+ if len(clip.size()) != 4:
22
+ raise ValueError("clip should be a 4D tensor")
23
+ return clip[..., i : i + h, j : j + w]
24
+
25
+
26
+ def resize(clip, target_size, interpolation_mode):
27
+ if len(target_size) != 2:
28
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
29
+ return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
30
+
31
+ def resize_scale(clip, target_size, interpolation_mode):
32
+ if len(target_size) != 2:
33
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
34
+ _, _, H, W = clip.shape
35
+ scale_ = target_size[0] / min(H, W)
36
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
37
+
38
+
39
+ def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
40
+ """
41
+ Do spatial cropping and resizing to the video clip
42
+ Args:
43
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
44
+ i (int): i in (i,j) i.e coordinates of the upper left corner.
45
+ j (int): j in (i,j) i.e coordinates of the upper left corner.
46
+ h (int): Height of the cropped region.
47
+ w (int): Width of the cropped region.
48
+ size (tuple(int, int)): height and width of resized clip
49
+ Returns:
50
+ clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
51
+ """
52
+ if not _is_tensor_video_clip(clip):
53
+ raise ValueError("clip should be a 4D torch.tensor")
54
+ clip = crop(clip, i, j, h, w)
55
+ clip = resize(clip, size, interpolation_mode)
56
+ return clip
57
+
58
+
59
+ def center_crop(clip, crop_size):
60
+ if not _is_tensor_video_clip(clip):
61
+ raise ValueError("clip should be a 4D torch.tensor")
62
+ h, w = clip.size(-2), clip.size(-1)
63
+ th, tw = crop_size
64
+ if h < th or w < tw:
65
+ raise ValueError("height and width must be no smaller than crop_size")
66
+
67
+ i = int(round((h - th) / 2.0))
68
+ j = int(round((w - tw) / 2.0))
69
+ return crop(clip, i, j, th, tw)
70
+
71
+ def random_shift_crop(clip):
72
+ '''
73
+ Slide along the long edge, with the short edge as crop size
74
+ '''
75
+ if not _is_tensor_video_clip(clip):
76
+ raise ValueError("clip should be a 4D torch.tensor")
77
+ h, w = clip.size(-2), clip.size(-1)
78
+
79
+ if h <= w:
80
+ long_edge = w
81
+ short_edge = h
82
+ else:
83
+ long_edge = h
84
+ short_edge = w
85
+
86
+ th, tw = short_edge, short_edge
87
+
88
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
89
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
90
+ return crop(clip, i, j, th, tw)
91
+
92
+
93
+ def to_tensor(clip):
94
+ """
95
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
96
+ permute the dimensions of clip tensor
97
+ Args:
98
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
99
+ Return:
100
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
101
+ """
102
+ _is_tensor_video_clip(clip)
103
+ if not clip.dtype == torch.uint8:
104
+ raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
105
+ # return clip.float().permute(3, 0, 1, 2) / 255.0
106
+ return clip.float() / 255.0
107
+
108
+
109
+ def normalize(clip, mean, std, inplace=False):
110
+ """
111
+ Args:
112
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
113
+ mean (tuple): pixel RGB mean. Size is (3)
114
+ std (tuple): pixel standard deviation. Size is (3)
115
+ Returns:
116
+ normalized clip (torch.tensor): Size is (T, C, H, W)
117
+ """
118
+ if not _is_tensor_video_clip(clip):
119
+ raise ValueError("clip should be a 4D torch.tensor")
120
+ if not inplace:
121
+ clip = clip.clone()
122
+ mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
123
+ # print(mean)  # leftover debug print, disabled
124
+ std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
125
+ clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
126
+ return clip
127
+
128
+
129
+ def hflip(clip):
130
+ """
131
+ Args:
132
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
133
+ Returns:
134
+ flipped clip (torch.tensor): Size is (T, C, H, W)
135
+ """
136
+ if not _is_tensor_video_clip(clip):
137
+ raise ValueError("clip should be a 4D torch.tensor")
138
+ return clip.flip(-1)
139
+
140
+
141
+ class RandomCropVideo:
142
+ def __init__(self, size):
143
+ if isinstance(size, numbers.Number):
144
+ self.size = (int(size), int(size))
145
+ else:
146
+ self.size = size
147
+
148
+ def __call__(self, clip):
149
+ """
150
+ Args:
151
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
152
+ Returns:
153
+ torch.tensor: randomly cropped video clip.
154
+ size is (T, C, OH, OW)
155
+ """
156
+ i, j, h, w = self.get_params(clip)
157
+ return crop(clip, i, j, h, w)
158
+
159
+ def get_params(self, clip):
160
+ h, w = clip.shape[-2:]
161
+ th, tw = self.size
162
+
163
+ if h < th or w < tw:
164
+ raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
165
+
166
+ if w == tw and h == th:
167
+ return 0, 0, h, w
168
+
169
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
170
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
171
+
172
+ return i, j, th, tw
173
+
174
+ def __repr__(self) -> str:
175
+ return f"{self.__class__.__name__}(size={self.size})"
176
+
177
+
178
+ class UCFCenterCropVideo:
179
+ def __init__(
180
+ self,
181
+ size,
182
+ interpolation_mode="bilinear",
183
+ ):
184
+ if isinstance(size, tuple):
185
+ if len(size) != 2:
186
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
187
+ self.size = size
188
+ else:
189
+ self.size = (size, size)
190
+
191
+ self.interpolation_mode = interpolation_mode
192
+
193
+
194
+ def __call__(self, clip):
195
+ """
196
+ Args:
197
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
198
+ Returns:
199
+ torch.tensor: scale resized / center cropped video clip.
200
+ size is (T, C, crop_size, crop_size)
201
+ """
202
+ clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
203
+ clip_center_crop = center_crop(clip_resize, self.size)
204
+ return clip_center_crop
205
+
206
+ def __repr__(self) -> str:
207
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
208
+
209
+ class KineticsRandomCropResizeVideo:
210
+ '''
211
+ Slide along the long edge, with the short edge as crop size, and resize to the desired size.
212
+ '''
213
+ def __init__(
214
+ self,
215
+ size,
216
+ interpolation_mode="bilinear",
217
+ ):
218
+ if isinstance(size, tuple):
219
+ if len(size) != 2:
220
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
221
+ self.size = size
222
+ else:
223
+ self.size = (size, size)
224
+
225
+ self.interpolation_mode = interpolation_mode
226
+
227
+ def __call__(self, clip):
228
+ clip_random_crop = random_shift_crop(clip)
229
+ clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
230
+ return clip_resize
231
+
232
+
233
+ class CenterCropVideo:
234
+ def __init__(
235
+ self,
236
+ size,
237
+ interpolation_mode="bilinear",
238
+ ):
239
+ if isinstance(size, tuple):
240
+ if len(size) != 2:
241
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
242
+ self.size = size
243
+ else:
244
+ self.size = (size, size)
245
+
246
+ self.interpolation_mode = interpolation_mode
247
+
248
+
249
+ def __call__(self, clip):
250
+ """
251
+ Args:
252
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
253
+ Returns:
254
+ torch.tensor: center cropped video clip.
255
+ size is (T, C, crop_size, crop_size)
256
+ """
257
+ clip_center_crop = center_crop(clip, self.size)
258
+ return clip_center_crop
259
+
260
+ def __repr__(self) -> str:
261
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
262
+
263
+
264
+ class NormalizeVideo:
265
+ """
266
+ Normalize the video clip by mean subtraction and division by standard deviation
267
+ Args:
268
+ mean (3-tuple): pixel RGB mean
269
+ std (3-tuple): pixel RGB standard deviation
270
+ inplace (boolean): whether do in-place normalization
271
+ """
272
+
273
+ def __init__(self, mean, std, inplace=False):
274
+ self.mean = mean
275
+ self.std = std
276
+ self.inplace = inplace
277
+
278
+ def __call__(self, clip):
279
+ """
280
+ Args:
281
+ clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
282
+ """
283
+ return normalize(clip, self.mean, self.std, self.inplace)
284
+
285
+ def __repr__(self) -> str:
286
+ return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
287
+
288
+
289
+ class ToTensorVideo:
290
+ """
291
+ Convert tensor data type from uint8 to float, divide value by 255.0 and
292
+ permute the dimensions of clip tensor
293
+ """
294
+
295
+ def __init__(self):
296
+ pass
297
+
298
+ def __call__(self, clip):
299
+ """
300
+ Args:
301
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
302
+ Return:
303
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
304
+ """
305
+ return to_tensor(clip)
306
+
307
+ def __repr__(self) -> str:
308
+ return self.__class__.__name__
309
+
310
+
311
+ class RandomHorizontalFlipVideo:
312
+ """
313
+ Flip the video clip along the horizontal direction with a given probability
314
+ Args:
315
+ p (float): probability of the clip being flipped. Default value is 0.5
316
+ """
317
+
318
+ def __init__(self, p=0.5):
319
+ self.p = p
320
+
321
+ def __call__(self, clip):
322
+ """
323
+ Args:
324
+ clip (torch.tensor): Size is (T, C, H, W)
325
+ Return:
326
+ clip (torch.tensor): Size is (T, C, H, W)
327
+ """
328
+ if random.random() < self.p:
329
+ clip = hflip(clip)
330
+ return clip
331
+
332
+ def __repr__(self) -> str:
333
+ return f"{self.__class__.__name__}(p={self.p})"
334
+
335
+ # ------------------------------------------------------------
336
+ # --------------------- Sampling ---------------------------
337
+ # ------------------------------------------------------------
338
+ class TemporalRandomCrop(object):
339
+ """Temporally crop the given frame indices at a random location.
340
+
341
+ Args:
342
+ size (int): Desired length of frames will be seen in the model.
343
+ """
344
+
345
+ def __init__(self, size):
346
+ self.size = size
347
+
348
+ def __call__(self, total_frames):
349
+ rand_end = max(0, total_frames - self.size - 1)
350
+ begin_index = random.randint(0, rand_end)
351
+ end_index = min(begin_index + self.size, total_frames)
352
+ return begin_index, end_index
353
+
354
+
355
+ if __name__ == '__main__':
356
+ from torchvision import transforms
357
+ import torchvision.io as io
358
+ import numpy as np
359
+ from torchvision.utils import save_image
360
+ import os
361
+
362
+ vframes, aframes, info = io.read_video(
363
+ filename='./v_Archery_g01_c03.avi',
364
+ pts_unit='sec',
365
+ output_format='TCHW'
366
+ )
367
+
368
+ trans = transforms.Compose([
369
+ ToTensorVideo(),
370
+ RandomHorizontalFlipVideo(),
371
+ UCFCenterCropVideo(512),
372
+ # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
373
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
374
+ ])
375
+
376
+ target_video_len = 32
377
+ frame_interval = 1
378
+ total_frames = len(vframes)
379
+ print(total_frames)
380
+
381
+ temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
382
+
383
+
384
+ # Sampling video frames
385
+ start_frame_ind, end_frame_ind = temporal_sample(total_frames)
386
+ # print(start_frame_ind)
387
+ # print(end_frame_ind)
388
+ assert end_frame_ind - start_frame_ind >= target_video_len
389
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
390
+ print(frame_indice)
391
+
392
+ select_vframes = vframes[frame_indice]
393
+ print(select_vframes.shape)
394
+ print(select_vframes.dtype)
395
+
396
+ select_vframes_trans = trans(select_vframes)
397
+ print(select_vframes_trans.shape)
398
+ print(select_vframes_trans.dtype)
399
+
400
+ select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
401
+ print(select_vframes_trans_int.dtype)
402
+ print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)
403
+
404
+ io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
405
+
406
+ for i in range(target_video_len):
407
+ save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True, value_range=(-1, 1))
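Not part of the commit: a usage sketch for the clip transforms defined above, assuming this repository is importable as the animatediff package laid out in this commit. It mirrors the composition in the file's own __main__ block, but runs on a random uint8 clip instead of a decoded video.

# Compose the video transforms above on a fake (T, C, H, W) uint8 clip.
import torch
from torchvision import transforms
from animatediff.data.video_transformer import (
    ToTensorVideo, RandomHorizontalFlipVideo, UCFCenterCropVideo)

clip = torch.randint(0, 256, (16, 3, 320, 512), dtype=torch.uint8)  # 16 frames

pipeline = transforms.Compose([
    ToTensorVideo(),                   # uint8 [0, 255] -> float [0, 1]
    RandomHorizontalFlipVideo(p=0.5),
    UCFCenterCropVideo(256),           # scale shorter side, then center-crop to 256x256
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
out = pipeline(clip)
print(out.shape, out.dtype)  # torch.Size([16, 3, 256, 256]) torch.float32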
animatediff/models/__init__.py ADDED
File without changes
animatediff/models/attention.py ADDED
@@ -0,0 +1,559 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers.models import ModelMixin
12
+ from diffusers.models.attention import Attention
13
+ from diffusers.utils import BaseOutput
14
+ from diffusers.utils.import_utils import is_xformers_available
15
+ from diffusers.models.attention import FeedForward, AdaLayerNorm
16
+
17
+ from einops import rearrange, repeat
18
+ import pdb
19
+
20
+ @dataclass
21
+ class Transformer3DModelOutput(BaseOutput):
22
+ sample: torch.FloatTensor
23
+
24
+
25
+ if is_xformers_available():
26
+ import xformers
27
+ import xformers.ops
28
+ else:
29
+ xformers = None
30
+
31
+
32
+ class Transformer3DModel(ModelMixin, ConfigMixin):
33
+ @register_to_config
34
+ def __init__(
35
+ self,
36
+ num_attention_heads: int = 16,
37
+ attention_head_dim: int = 88,
38
+ in_channels: Optional[int] = None,
39
+ num_layers: int = 1,
40
+ dropout: float = 0.0,
41
+ norm_num_groups: int = 32,
42
+ cross_attention_dim: Optional[int] = None,
43
+ attention_bias: bool = False,
44
+ activation_fn: str = "geglu",
45
+ num_embeds_ada_norm: Optional[int] = None,
46
+ use_linear_projection: bool = False,
47
+ only_cross_attention: bool = False,
48
+ upcast_attention: bool = False,
49
+
50
+ unet_use_cross_frame_attention=None,
51
+ unet_use_temporal_attention=None,
52
+ ):
53
+ super().__init__()
54
+ self.use_linear_projection = use_linear_projection
55
+ self.num_attention_heads = num_attention_heads
56
+ self.attention_head_dim = attention_head_dim
57
+ inner_dim = num_attention_heads * attention_head_dim
58
+
59
+ # Define input layers
60
+ self.in_channels = in_channels
61
+
62
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
63
+ if use_linear_projection:
64
+ self.proj_in = nn.Linear(in_channels, inner_dim)
65
+ else:
66
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
67
+
68
+ # Define transformers blocks
69
+ self.transformer_blocks = nn.ModuleList(
70
+ [
71
+ BasicTransformerBlock(
72
+ inner_dim,
73
+ num_attention_heads,
74
+ attention_head_dim,
75
+ dropout=dropout,
76
+ cross_attention_dim=cross_attention_dim,
77
+ activation_fn=activation_fn,
78
+ num_embeds_ada_norm=num_embeds_ada_norm,
79
+ attention_bias=attention_bias,
80
+ only_cross_attention=only_cross_attention,
81
+ upcast_attention=upcast_attention,
82
+
83
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
84
+ unet_use_temporal_attention=unet_use_temporal_attention,
85
+ )
86
+ for d in range(num_layers)
87
+ ]
88
+ )
89
+
90
+ # 4. Define output layers
91
+ if use_linear_projection:
92
+ self.proj_out = nn.Linear(in_channels, inner_dim)
93
+ else:
94
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
95
+
96
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
97
+ # Input
98
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
99
+ video_length = hidden_states.shape[2]
100
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
101
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
102
+
103
+ batch, channel, height, weight = hidden_states.shape
104
+ residual = hidden_states
105
+
106
+ hidden_states = self.norm(hidden_states)
107
+ if not self.use_linear_projection:
108
+ hidden_states = self.proj_in(hidden_states)
109
+ inner_dim = hidden_states.shape[1]
110
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
111
+ else:
112
+ inner_dim = hidden_states.shape[1]
113
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
114
+ hidden_states = self.proj_in(hidden_states)
115
+
116
+ # Blocks
117
+ for block in self.transformer_blocks:
118
+ hidden_states = block(
119
+ hidden_states,
120
+ encoder_hidden_states=encoder_hidden_states,
121
+ timestep=timestep,
122
+ video_length=video_length
123
+ )
124
+
125
+ # Output
126
+ if not self.use_linear_projection:
127
+ hidden_states = (
128
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
129
+ )
130
+ hidden_states = self.proj_out(hidden_states)
131
+ else:
132
+ hidden_states = self.proj_out(hidden_states)
133
+ hidden_states = (
134
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
135
+ )
136
+
137
+ output = hidden_states + residual
138
+
139
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
140
+ if not return_dict:
141
+ return (output,)
142
+
143
+ return Transformer3DModelOutput(sample=output)
144
+
145
+
146
+ class BasicTransformerBlock(nn.Module):
147
+ def __init__(
148
+ self,
149
+ dim: int,
150
+ num_attention_heads: int,
151
+ attention_head_dim: int,
152
+ dropout=0.0,
153
+ cross_attention_dim: Optional[int] = None,
154
+ activation_fn: str = "geglu",
155
+ num_embeds_ada_norm: Optional[int] = None,
156
+ attention_bias: bool = False,
157
+ only_cross_attention: bool = False,
158
+ upcast_attention: bool = False,
159
+
160
+ unet_use_cross_frame_attention = None,
161
+ unet_use_temporal_attention = None,
162
+ ):
163
+ super().__init__()
164
+ self.only_cross_attention = only_cross_attention
165
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
166
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
167
+ self.unet_use_temporal_attention = unet_use_temporal_attention
168
+
169
+ # SC-Attn
170
+ assert unet_use_cross_frame_attention is not None
171
+ if unet_use_cross_frame_attention:
172
+ self.attn1 = SparseCausalAttention(
173
+ query_dim=dim,
174
+ heads=num_attention_heads,
175
+ dim_head=attention_head_dim,
176
+ dropout=dropout,
177
+ bias=attention_bias,
178
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
179
+ upcast_attention=upcast_attention,
180
+ )
181
+ else:
182
+ self.attn1 = Attention(
183
+ query_dim=dim,
184
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
185
+ heads=num_attention_heads,
186
+ dim_head=attention_head_dim,
187
+ dropout=dropout,
188
+ bias=attention_bias,
189
+ upcast_attention=upcast_attention,
190
+ )
191
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
192
+
193
+ # Cross-Attn
194
+ if cross_attention_dim is not None:
195
+ self.attn2 = Attention(
196
+ query_dim=dim,
197
+ cross_attention_dim=cross_attention_dim,
198
+ heads=num_attention_heads,
199
+ dim_head=attention_head_dim,
200
+ dropout=dropout,
201
+ bias=attention_bias,
202
+ upcast_attention=upcast_attention,
203
+ )
204
+ else:
205
+ self.attn2 = None
206
+
207
+ if cross_attention_dim is not None:
208
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
209
+ else:
210
+ self.norm2 = None
211
+
212
+ # Feed-forward
213
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
214
+ self.norm3 = nn.LayerNorm(dim)
215
+
216
+ # Temp-Attn
217
+ assert unet_use_temporal_attention is not None
218
+ if unet_use_temporal_attention:
219
+ self.attn_temp = Attention(
220
+ query_dim=dim,
221
+ heads=num_attention_heads,
222
+ dim_head=attention_head_dim,
223
+ dropout=dropout,
224
+ bias=attention_bias,
225
+ upcast_attention=upcast_attention,
226
+ )
227
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
228
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
229
+
230
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
231
+ # SparseCausal-Attention
232
+ norm_hidden_states = (
233
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
234
+ )
235
+
236
+ # if self.only_cross_attention:
237
+ # hidden_states = (
238
+ # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
239
+ # )
240
+ # else:
241
+ # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
242
+
243
+ # pdb.set_trace()
244
+ if self.unet_use_cross_frame_attention:
245
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
246
+ else:
247
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
248
+
249
+ if self.attn2 is not None:
250
+ # Cross-Attention
251
+ norm_hidden_states = (
252
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
253
+ )
254
+ hidden_states = (
255
+ self.attn2(
256
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
257
+ )
258
+ + hidden_states
259
+ )
260
+
261
+ # Feed-forward
262
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
263
+
264
+ # Temporal-Attention
265
+ if self.unet_use_temporal_attention:
266
+ d = hidden_states.shape[1]
267
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
268
+ norm_hidden_states = (
269
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
270
+ )
271
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
272
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
273
+
274
+ return hidden_states
275
+
276
+ class CrossAttention(nn.Module):
277
+ r"""
278
+ A cross attention layer.
279
+
280
+ Parameters:
281
+ query_dim (`int`): The number of channels in the query.
282
+ cross_attention_dim (`int`, *optional*):
283
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
284
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
285
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
286
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
287
+ bias (`bool`, *optional*, defaults to False):
288
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
289
+ """
290
+
291
+ def __init__(
292
+ self,
293
+ query_dim: int,
294
+ cross_attention_dim: Optional[int] = None,
295
+ heads: int = 8,
296
+ dim_head: int = 64,
297
+ dropout: float = 0.0,
298
+ bias=False,
299
+ upcast_attention: bool = False,
300
+ upcast_softmax: bool = False,
301
+ added_kv_proj_dim: Optional[int] = None,
302
+ norm_num_groups: Optional[int] = None,
303
+ ):
304
+ super().__init__()
305
+ inner_dim = dim_head * heads
306
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
307
+ self.upcast_attention = upcast_attention
308
+ self.upcast_softmax = upcast_softmax
309
+
310
+ self.scale = dim_head**-0.5
311
+
312
+ self.heads = heads
313
+ # for slice_size > 0 the attention score computation
314
+ # is split across the batch axis to save memory
315
+ # You can set slice_size with `set_attention_slice`
316
+ self.sliceable_head_dim = heads
317
+ self._slice_size = None
318
+ self._use_memory_efficient_attention_xformers = False
319
+ self.added_kv_proj_dim = added_kv_proj_dim
320
+
321
+ if norm_num_groups is not None:
322
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
323
+ else:
324
+ self.group_norm = None
325
+
326
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
327
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
328
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
329
+
330
+ if self.added_kv_proj_dim is not None:
331
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
332
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
333
+
334
+ self.to_out = nn.ModuleList([])
335
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
336
+ self.to_out.append(nn.Dropout(dropout))
337
+
338
+ def reshape_heads_to_batch_dim(self, tensor):
339
+ batch_size, seq_len, dim = tensor.shape
340
+ head_size = self.heads
341
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
342
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
343
+ return tensor
344
+
345
+ def reshape_batch_dim_to_heads(self, tensor):
346
+ batch_size, seq_len, dim = tensor.shape
347
+ head_size = self.heads
348
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
349
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
350
+ return tensor
351
+
352
+ def set_attention_slice(self, slice_size):
353
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
354
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
355
+
356
+ self._slice_size = slice_size
357
+
358
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
359
+ batch_size, sequence_length, _ = hidden_states.shape
360
+
361
+ encoder_hidden_states = encoder_hidden_states
362
+
363
+ if self.group_norm is not None:
364
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
365
+
366
+ query = self.to_q(hidden_states)
367
+ dim = query.shape[-1]
368
+ query = self.reshape_heads_to_batch_dim(query)
369
+
370
+ if self.added_kv_proj_dim is not None:
371
+ key = self.to_k(hidden_states)
372
+ value = self.to_v(hidden_states)
373
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
374
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
375
+
376
+ key = self.reshape_heads_to_batch_dim(key)
377
+ value = self.reshape_heads_to_batch_dim(value)
378
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
379
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
380
+
381
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
382
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
383
+ else:
384
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
385
+ key = self.to_k(encoder_hidden_states)
386
+ value = self.to_v(encoder_hidden_states)
387
+
388
+ key = self.reshape_heads_to_batch_dim(key)
389
+ value = self.reshape_heads_to_batch_dim(value)
390
+
391
+ if attention_mask is not None:
392
+ if attention_mask.shape[-1] != query.shape[1]:
393
+ target_length = query.shape[1]
394
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
395
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
396
+
397
+ # attention, what we cannot get enough of
398
+ if self._use_memory_efficient_attention_xformers:
399
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
400
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
401
+ hidden_states = hidden_states.to(query.dtype)
402
+ else:
403
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
404
+ hidden_states = self._attention(query, key, value, attention_mask)
405
+ else:
406
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
407
+
408
+ # linear proj
409
+ hidden_states = self.to_out[0](hidden_states)
410
+
411
+ # dropout
412
+ hidden_states = self.to_out[1](hidden_states)
413
+ return hidden_states
414
+
415
+ def _attention(self, query, key, value, attention_mask=None):
416
+ if self.upcast_attention:
417
+ query = query.float()
418
+ key = key.float()
419
+
420
+ attention_scores = torch.baddbmm(
421
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
422
+ query,
423
+ key.transpose(-1, -2),
424
+ beta=0,
425
+ alpha=self.scale,
426
+ )
427
+
428
+ if attention_mask is not None:
429
+ attention_scores = attention_scores + attention_mask
430
+
431
+ if self.upcast_softmax:
432
+ attention_scores = attention_scores.float()
433
+
434
+ attention_probs = attention_scores.softmax(dim=-1)
435
+
436
+ # cast back to the original dtype
437
+ attention_probs = attention_probs.to(value.dtype)
438
+
439
+ # compute attention output
440
+ hidden_states = torch.bmm(attention_probs, value)
441
+
442
+ # reshape hidden_states
443
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
444
+ return hidden_states
445
+
446
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
447
+ batch_size_attention = query.shape[0]
448
+ hidden_states = torch.zeros(
449
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
450
+ )
451
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
452
+ for i in range(hidden_states.shape[0] // slice_size):
453
+ start_idx = i * slice_size
454
+ end_idx = (i + 1) * slice_size
455
+
456
+ query_slice = query[start_idx:end_idx]
457
+ key_slice = key[start_idx:end_idx]
458
+
459
+ if self.upcast_attention:
460
+ query_slice = query_slice.float()
461
+ key_slice = key_slice.float()
462
+
463
+ attn_slice = torch.baddbmm(
464
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
465
+ query_slice,
466
+ key_slice.transpose(-1, -2),
467
+ beta=0,
468
+ alpha=self.scale,
469
+ )
470
+
471
+ if attention_mask is not None:
472
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
473
+
474
+ if self.upcast_softmax:
475
+ attn_slice = attn_slice.float()
476
+
477
+ attn_slice = attn_slice.softmax(dim=-1)
478
+
479
+ # cast back to the original dtype
480
+ attn_slice = attn_slice.to(value.dtype)
481
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
482
+
483
+ hidden_states[start_idx:end_idx] = attn_slice
484
+
485
+ # reshape hidden_states
486
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
487
+ return hidden_states
488
+
489
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
490
+ # TODO attention_mask
491
+ query = query.contiguous()
492
+ key = key.contiguous()
493
+ value = value.contiguous()
494
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
495
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
496
+ return hidden_states
497
+
498
+
499
+
500
+ class SparseCausalAttention(CrossAttention):
501
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
502
+ batch_size, sequence_length, _ = hidden_states.shape
503
+
504
+ encoder_hidden_states = encoder_hidden_states
505
+
506
+ if self.group_norm is not None:
507
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
508
+
509
+ query = self.to_q(hidden_states)
510
+ dim = query.shape[-1]
511
+ query = self.reshape_heads_to_batch_dim(query)
512
+
513
+ if self.added_kv_proj_dim is not None:
514
+ raise NotImplementedError
515
+
516
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
517
+ key = self.to_k(encoder_hidden_states)
518
+ value = self.to_v(encoder_hidden_states)
519
+
520
+ former_frame_index = torch.arange(video_length) - 1
521
+ former_frame_index[0] = 0
522
+
523
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
524
+ #key = torch.cat([key[:, [0] * video_length], key[:, [0] * video_length]], dim=2)
525
+ key = key[:, [0] * video_length]
526
+ key = rearrange(key, "b f d c -> (b f) d c")
527
+
528
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
529
+ #value = torch.cat([value[:, [0] * video_length], value[:, [0] * video_length]], dim=2)
530
+ #value = value[:, former_frame_index]
531
+ value = rearrange(value, "b f d c -> (b f) d c")
532
+
533
+ key = self.reshape_heads_to_batch_dim(key)
534
+ value = self.reshape_heads_to_batch_dim(value)
535
+
536
+ if attention_mask is not None:
537
+ if attention_mask.shape[-1] != query.shape[1]:
538
+ target_length = query.shape[1]
539
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
540
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
541
+
542
+ # attention, what we cannot get enough of
543
+ if self._use_memory_efficient_attention_xformers:
544
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
545
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
546
+ hidden_states = hidden_states.to(query.dtype)
547
+ else:
548
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
549
+ hidden_states = self._attention(query, key, value, attention_mask)
550
+ else:
551
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
552
+
553
+ # linear proj
554
+ hidden_states = self.to_out[0](hidden_states)
555
+
556
+ # dropout
557
+ hidden_states = self.to_out[1](hidden_states)
558
+ return hidden_states
559
+
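Not part of the commit: a shape-check sketch for Transformer3DModel above, assuming the repository and its pinned diffusers version (which must still expose Attention, AdaLayerNorm and FeedForward under diffusers.models.attention) are importable as in this commit. The two unet_use_* flags must be explicit booleans because BasicTransformerBlock asserts they are not None; all sizes below are illustrative.

# Run a 5-D (b, c, f, h, w) latent through the spatial transformer defined above.
import torch
from animatediff.models.attention import Transformer3DModel

model = Transformer3DModel(
    num_attention_heads=8,
    attention_head_dim=40,            # inner_dim = 8 * 40 = 320
    in_channels=320,
    num_layers=1,
    cross_attention_dim=768,          # e.g. CLIP text-embedding width
    unet_use_cross_frame_attention=False,
    unet_use_temporal_attention=False,
)

hidden_states = torch.randn(1, 320, 8, 32, 32)     # (b, c, f, h, w)
encoder_hidden_states = torch.randn(1, 77, 768)    # (b, tokens, cross_attention_dim)
out = model(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.sample.shape)  # torch.Size([1, 320, 8, 32, 32])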
animatediff/models/motion_module.py ADDED
@@ -0,0 +1,555 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import numpy as np
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+ import torchvision
10
+
11
+ from diffusers.utils import BaseOutput
12
+ from diffusers.utils.import_utils import is_xformers_available
13
+ from diffusers.models.attention import FeedForward
14
+
15
+ from einops import rearrange, repeat
16
+ import math
17
+
18
+
19
+ def zero_module(module):
20
+ # Zero out the parameters of a module and return it.
21
+ for p in module.parameters():
22
+ p.detach().zero_()
23
+ return module
24
+
25
+
26
+ @dataclass
27
+ class TemporalTransformer3DModelOutput(BaseOutput):
28
+ sample: torch.FloatTensor
29
+
30
+
31
+ if is_xformers_available():
32
+ import xformers
33
+ import xformers.ops
34
+ else:
35
+ xformers = None
36
+
37
+
38
+ def get_motion_module(
39
+ in_channels,
40
+ motion_module_type: str,
41
+ motion_module_kwargs: dict
42
+ ):
43
+ if motion_module_type == "Vanilla":
44
+ return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
45
+ else:
46
+ raise ValueError
47
+
48
+
49
+ class VanillaTemporalModule(nn.Module):
50
+ def __init__(
51
+ self,
52
+ in_channels,
53
+ num_attention_heads = 8,
54
+ num_transformer_block = 2,
55
+ attention_block_types = ("Temporal_Self", "Temporal_Self"),
56
+ cross_frame_attention_mode = None,
57
+ temporal_position_encoding = False,
58
+ temporal_position_encoding_max_len = 32,
59
+ temporal_attention_dim_div = 1,
60
+ zero_initialize = True,
61
+ ):
62
+ super().__init__()
63
+
64
+ self.temporal_transformer = TemporalTransformer3DModel(
65
+ in_channels=in_channels,
66
+ num_attention_heads=num_attention_heads,
67
+ attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
68
+ num_layers=num_transformer_block,
69
+ attention_block_types=attention_block_types,
70
+ cross_frame_attention_mode=cross_frame_attention_mode,
71
+ temporal_position_encoding=temporal_position_encoding,
72
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
73
+ )
74
+
75
+ if zero_initialize:
76
+ self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
77
+
78
+ def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
79
+ hidden_states = input_tensor
80
+ hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
81
+
82
+ output = hidden_states
83
+ return output
84
+
85
+
86
+ class TemporalTransformer3DModel(nn.Module):
87
+ def __init__(
88
+ self,
89
+ in_channels,
90
+ num_attention_heads,
91
+ attention_head_dim,
92
+
93
+ num_layers,
94
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
95
+ dropout = 0.0,
96
+ norm_num_groups = 32,
97
+ cross_attention_dim = 1280,
98
+ activation_fn = "geglu",
99
+ attention_bias = False,
100
+ upcast_attention = False,
101
+
102
+ cross_frame_attention_mode = None,
103
+ temporal_position_encoding = False,
104
+ temporal_position_encoding_max_len = 32,
105
+ ):
106
+ super().__init__()
107
+
108
+ inner_dim = num_attention_heads * attention_head_dim
109
+
110
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
111
+ self.proj_in = nn.Linear(in_channels, inner_dim)
112
+
113
+ self.transformer_blocks = nn.ModuleList(
114
+ [
115
+ TemporalTransformerBlock(
116
+ dim=inner_dim,
117
+ num_attention_heads=num_attention_heads,
118
+ attention_head_dim=attention_head_dim,
119
+ attention_block_types=attention_block_types,
120
+ dropout=dropout,
121
+ norm_num_groups=norm_num_groups,
122
+ cross_attention_dim=cross_attention_dim,
123
+ activation_fn=activation_fn,
124
+ attention_bias=attention_bias,
125
+ upcast_attention=upcast_attention,
126
+ cross_frame_attention_mode=cross_frame_attention_mode,
127
+ temporal_position_encoding=temporal_position_encoding,
128
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
129
+ )
130
+ for d in range(num_layers)
131
+ ]
132
+ )
133
+ self.proj_out = nn.Linear(inner_dim, in_channels)
134
+
135
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
136
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
137
+ video_length = hidden_states.shape[2]
138
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
139
+
140
+ batch, channel, height, weight = hidden_states.shape
141
+ residual = hidden_states
142
+
143
+ hidden_states = self.norm(hidden_states)
144
+ inner_dim = hidden_states.shape[1]
145
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
146
+ hidden_states = self.proj_in(hidden_states)
147
+
148
+ # Transformer Blocks
149
+ for block in self.transformer_blocks:
150
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
151
+
152
+ # output
153
+ hidden_states = self.proj_out(hidden_states)
154
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
155
+
156
+ output = hidden_states + residual
157
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
158
+
159
+ return output
160
+
161
+
162
+ class TemporalTransformerBlock(nn.Module):
163
+ def __init__(
164
+ self,
165
+ dim,
166
+ num_attention_heads,
167
+ attention_head_dim,
168
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
169
+ dropout = 0.0,
170
+ norm_num_groups = 32,
171
+ cross_attention_dim = 768,
172
+ activation_fn = "geglu",
173
+ attention_bias = False,
174
+ upcast_attention = False,
175
+ cross_frame_attention_mode = None,
176
+ temporal_position_encoding = False,
177
+ temporal_position_encoding_max_len = 32,
178
+ ):
179
+ super().__init__()
180
+
181
+ attention_blocks = []
182
+ norms = []
183
+
184
+ for block_name in attention_block_types:
185
+ attention_blocks.append(
186
+ VersatileAttention(
187
+ attention_mode=block_name.split("_")[0],
188
+ cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
189
+
190
+ query_dim=dim,
191
+ heads=num_attention_heads,
192
+ dim_head=attention_head_dim,
193
+ dropout=dropout,
194
+ bias=attention_bias,
195
+ upcast_attention=upcast_attention,
196
+
197
+ cross_frame_attention_mode=cross_frame_attention_mode,
198
+ temporal_position_encoding=temporal_position_encoding,
199
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
200
+ )
201
+ )
202
+ norms.append(nn.LayerNorm(dim))
203
+
204
+ self.attention_blocks = nn.ModuleList(attention_blocks)
205
+ self.norms = nn.ModuleList(norms)
206
+
207
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
208
+ self.ff_norm = nn.LayerNorm(dim)
209
+
210
+
211
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
212
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
213
+ norm_hidden_states = norm(hidden_states)
214
+ hidden_states = attention_block(
215
+ norm_hidden_states,
216
+ encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
217
+ video_length=video_length,
218
+ ) + hidden_states
219
+
220
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
221
+
222
+ output = hidden_states
223
+ return output
224
+
225
+
226
+ class PositionalEncoding(nn.Module):
227
+ def __init__(
228
+ self,
229
+ d_model,
230
+ dropout = 0.,
231
+ max_len = 32
232
+ ):
233
+ super().__init__()
234
+ self.dropout = nn.Dropout(p=dropout)
235
+ position = torch.arange(max_len).unsqueeze(1)
236
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
237
+ pe = torch.zeros(1, max_len, d_model)
238
+ pe[0, :, 0::2] = torch.sin(position * div_term)
239
+ pe[0, :, 1::2] = torch.cos(position * div_term)
240
+ self.register_buffer('pe', pe)
241
+
242
+ def forward(self, x):
243
+ x = x + self.pe[:, :x.size(1)]
244
+ return self.dropout(x)
245
+
246
+
247
+
248
+ class CrossAttention(nn.Module):
249
+ r"""
250
+ A cross attention layer.
251
+
252
+ Parameters:
253
+ query_dim (`int`): The number of channels in the query.
254
+ cross_attention_dim (`int`, *optional*):
255
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
256
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
257
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
258
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
259
+ bias (`bool`, *optional*, defaults to False):
260
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
261
+ """
262
+
263
+ def __init__(
264
+ self,
265
+ query_dim: int,
266
+ cross_attention_dim: Optional[int] = None,
267
+ heads: int = 8,
268
+ dim_head: int = 64,
269
+ dropout: float = 0.0,
270
+ bias=False,
271
+ upcast_attention: bool = False,
272
+ upcast_softmax: bool = False,
273
+ added_kv_proj_dim: Optional[int] = None,
274
+ norm_num_groups: Optional[int] = None,
275
+ ):
276
+ super().__init__()
277
+ inner_dim = dim_head * heads
278
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
279
+ self.upcast_attention = upcast_attention
280
+ self.upcast_softmax = upcast_softmax
281
+
282
+ self.scale = dim_head**-0.5
283
+
284
+ self.heads = heads
285
+ # for slice_size > 0 the attention score computation
286
+ # is split across the batch axis to save memory
287
+ # You can set slice_size with `set_attention_slice`
288
+ self.sliceable_head_dim = heads
289
+ self._slice_size = None
290
+ self._use_memory_efficient_attention_xformers = False
291
+ self.added_kv_proj_dim = added_kv_proj_dim
292
+
293
+ if norm_num_groups is not None:
294
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
295
+ else:
296
+ self.group_norm = None
297
+
298
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
299
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
300
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
301
+
302
+ if self.added_kv_proj_dim is not None:
303
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
304
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
305
+
306
+ self.to_out = nn.ModuleList([])
307
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
308
+ self.to_out.append(nn.Dropout(dropout))
309
+
310
+ def reshape_heads_to_batch_dim(self, tensor):
311
+ batch_size, seq_len, dim = tensor.shape
312
+ head_size = self.heads
313
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
314
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
315
+ return tensor
316
+
317
+ def reshape_batch_dim_to_heads(self, tensor):
318
+ batch_size, seq_len, dim = tensor.shape
319
+ head_size = self.heads
320
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
321
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
322
+ return tensor
323
+
324
+ def set_attention_slice(self, slice_size):
325
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
326
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
327
+
328
+ self._slice_size = slice_size
329
+
330
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
331
+ batch_size, sequence_length, _ = hidden_states.shape
332
+
333
+ encoder_hidden_states = encoder_hidden_states
334
+
335
+ if self.group_norm is not None:
336
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
337
+
338
+ query = self.to_q(hidden_states)
339
+ dim = query.shape[-1]
340
+ query = self.reshape_heads_to_batch_dim(query)
341
+
342
+ if self.added_kv_proj_dim is not None:
343
+ key = self.to_k(hidden_states)
344
+ value = self.to_v(hidden_states)
345
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
346
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
347
+
348
+ key = self.reshape_heads_to_batch_dim(key)
349
+ value = self.reshape_heads_to_batch_dim(value)
350
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
351
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
352
+
353
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
354
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
355
+ else:
356
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
357
+ key = self.to_k(encoder_hidden_states)
358
+ value = self.to_v(encoder_hidden_states)
359
+
360
+ key = self.reshape_heads_to_batch_dim(key)
361
+ value = self.reshape_heads_to_batch_dim(value)
362
+
363
+ if attention_mask is not None:
364
+ if attention_mask.shape[-1] != query.shape[1]:
365
+ target_length = query.shape[1]
366
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
367
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
368
+
369
+ # attention, what we cannot get enough of
370
+ if self._use_memory_efficient_attention_xformers:
371
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
372
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
373
+ hidden_states = hidden_states.to(query.dtype)
374
+ else:
375
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
376
+ hidden_states = self._attention(query, key, value, attention_mask)
377
+ else:
378
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
379
+
380
+ # linear proj
381
+ hidden_states = self.to_out[0](hidden_states)
382
+
383
+ # dropout
384
+ hidden_states = self.to_out[1](hidden_states)
385
+ return hidden_states
386
+
387
+ def _attention(self, query, key, value, attention_mask=None):
388
+ if self.upcast_attention:
389
+ query = query.float()
390
+ key = key.float()
391
+
392
+ attention_scores = torch.baddbmm(
393
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
394
+ query,
395
+ key.transpose(-1, -2),
396
+ beta=0,
397
+ alpha=self.scale,
398
+ )
399
+
400
+ if attention_mask is not None:
401
+ attention_scores = attention_scores + attention_mask
402
+
403
+ if self.upcast_softmax:
404
+ attention_scores = attention_scores.float()
405
+
406
+ attention_probs = attention_scores.softmax(dim=-1)
407
+
408
+ # cast back to the original dtype
409
+ attention_probs = attention_probs.to(value.dtype)
410
+
411
+ # compute attention output
412
+ hidden_states = torch.bmm(attention_probs, value)
413
+
414
+ # reshape hidden_states
415
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
416
+ return hidden_states
417
+
418
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
419
+ batch_size_attention = query.shape[0]
420
+ hidden_states = torch.zeros(
421
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
422
+ )
423
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
424
+ for i in range(hidden_states.shape[0] // slice_size):
425
+ start_idx = i * slice_size
426
+ end_idx = (i + 1) * slice_size
427
+
428
+ query_slice = query[start_idx:end_idx]
429
+ key_slice = key[start_idx:end_idx]
430
+
431
+ if self.upcast_attention:
432
+ query_slice = query_slice.float()
433
+ key_slice = key_slice.float()
434
+
435
+ attn_slice = torch.baddbmm(
436
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
437
+ query_slice,
438
+ key_slice.transpose(-1, -2),
439
+ beta=0,
440
+ alpha=self.scale,
441
+ )
442
+
443
+ if attention_mask is not None:
444
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
445
+
446
+ if self.upcast_softmax:
447
+ attn_slice = attn_slice.float()
448
+
449
+ attn_slice = attn_slice.softmax(dim=-1)
450
+
451
+ # cast back to the original dtype
452
+ attn_slice = attn_slice.to(value.dtype)
453
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
454
+
455
+ hidden_states[start_idx:end_idx] = attn_slice
456
+
457
+ # reshape hidden_states
458
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
459
+ return hidden_states
460
+
461
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
462
+ # TODO attention_mask
463
+ query = query.contiguous()
464
+ key = key.contiguous()
465
+ value = value.contiguous()
466
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
467
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
468
+ return hidden_states
469
+
470
+
471
+ class VersatileAttention(CrossAttention):
472
+ def __init__(
473
+ self,
474
+ attention_mode = None,
475
+ cross_frame_attention_mode = None,
476
+ temporal_position_encoding = False,
477
+ temporal_position_encoding_max_len = 32,
478
+ *args, **kwargs
479
+ ):
480
+ super().__init__(*args, **kwargs)
481
+ assert attention_mode == "Temporal"
482
+
483
+ self.attention_mode = attention_mode
484
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
485
+
486
+ self.pos_encoder = PositionalEncoding(
487
+ kwargs["query_dim"],
488
+ dropout=0.,
489
+ max_len=temporal_position_encoding_max_len
490
+ ) if (temporal_position_encoding and attention_mode == "Temporal") else None
491
+
492
+ def extra_repr(self):
493
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
494
+
495
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
496
+ batch_size, sequence_length, _ = hidden_states.shape
497
+
498
+ if self.attention_mode == "Temporal":
499
+ d = hidden_states.shape[1]
500
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
501
+
502
+ if self.pos_encoder is not None:
503
+ hidden_states = self.pos_encoder(hidden_states)
504
+
505
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
506
+ else:
507
+ raise NotImplementedError
508
+
509
+ encoder_hidden_states = encoder_hidden_states
510
+
511
+ if self.group_norm is not None:
512
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
513
+
514
+ query = self.to_q(hidden_states)
515
+ dim = query.shape[-1]
516
+ query = self.reshape_heads_to_batch_dim(query)
517
+
518
+ if self.added_kv_proj_dim is not None:
519
+ raise NotImplementedError
520
+
521
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
522
+ key = self.to_k(encoder_hidden_states)
523
+ value = self.to_v(encoder_hidden_states)
524
+
525
+ key = self.reshape_heads_to_batch_dim(key)
526
+ value = self.reshape_heads_to_batch_dim(value)
527
+
528
+ if attention_mask is not None:
529
+ if attention_mask.shape[-1] != query.shape[1]:
530
+ target_length = query.shape[1]
531
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
532
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
533
+
534
+ # attention, what we cannot get enough of
535
+ if self._use_memory_efficient_attention_xformers:
536
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
537
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
538
+ hidden_states = hidden_states.to(query.dtype)
539
+ else:
540
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
541
+ hidden_states = self._attention(query, key, value, attention_mask)
542
+ else:
543
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
544
+
545
+ # linear proj
546
+ hidden_states = self.to_out[0](hidden_states)
547
+
548
+ # dropout
549
+ hidden_states = self.to_out[1](hidden_states)
550
+
551
+ if self.attention_mode == "Temporal":
552
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
553
+
554
+ return hidden_states
555
+
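End of motion_module.py. The key move in `VersatileAttention` above is the `(b f) d c -> (b d) f c` rearrange: spatial positions are folded into the batch so self-attention runs along the frame axis, with the sinusoidal table from `PositionalEncoding` (max_len 32) added per frame index. A small sketch of just that reshaping and encoding; the tensor sizes are illustrative, not values from this commit's configs.

# Sketch of the temporal reshape + sinusoidal position table (illustrative sizes).
import math
import torch
from einops import rearrange

batch, frames, tokens, channels = 2, 16, 64, 320

hidden_states = torch.randn(batch * frames, tokens, channels)        # (b f) d c
temporal = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=frames)

# same construction as PositionalEncoding above, max_len=32
position = torch.arange(32).unsqueeze(1)
div_term = torch.exp(torch.arange(0, channels, 2) * (-math.log(10000.0) / channels))
pe = torch.zeros(1, 32, channels)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)

temporal = temporal + pe[:, :frames]            # every spatial location sees frame indices 0..frames-1
spatial_again = rearrange(temporal, "(b d) f c -> (b f) d c", d=tokens)
print(spatial_again.shape)                      # torch.Size([32, 64, 320])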
animatediff/models/resnet.py ADDED
@@ -0,0 +1,197 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from einops import rearrange
8
+
9
+
10
+ class InflatedConv3d(nn.Conv2d):
11
+ def forward(self, x):
12
+ video_length = x.shape[2]
13
+
14
+ x = rearrange(x, "b c f h w -> (b f) c h w")
15
+ x = super().forward(x)
16
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
17
+
18
+ return x
19
+
20
+
21
+ class Upsample3D(nn.Module):
22
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
23
+ super().__init__()
24
+ self.channels = channels
25
+ self.out_channels = out_channels or channels
26
+ self.use_conv = use_conv
27
+ self.use_conv_transpose = use_conv_transpose
28
+ self.name = name
29
+
30
+ conv = None
31
+ if use_conv_transpose:
32
+ raise NotImplementedError
33
+ elif use_conv:
34
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
35
+
36
+ def forward(self, hidden_states, output_size=None):
37
+ assert hidden_states.shape[1] == self.channels
38
+
39
+ if self.use_conv_transpose:
40
+ raise NotImplementedError
41
+
42
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
43
+ dtype = hidden_states.dtype
44
+ if dtype == torch.bfloat16:
45
+ hidden_states = hidden_states.to(torch.float32)
46
+
47
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
48
+ if hidden_states.shape[0] >= 64:
49
+ hidden_states = hidden_states.contiguous()
50
+
51
+ # if `output_size` is passed we force the interpolation output
52
+ # size and do not make use of `scale_factor=2`
53
+ if output_size is None:
54
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
55
+ else:
56
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
57
+
58
+ # If the input is bfloat16, we cast back to bfloat16
59
+ if dtype == torch.bfloat16:
60
+ hidden_states = hidden_states.to(dtype)
61
+
62
+ # if self.use_conv:
63
+ # if self.name == "conv":
64
+ # hidden_states = self.conv(hidden_states)
65
+ # else:
66
+ # hidden_states = self.Conv2d_0(hidden_states)
67
+ hidden_states = self.conv(hidden_states)
68
+
69
+ return hidden_states
70
+
71
+
72
+ class Downsample3D(nn.Module):
73
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
74
+ super().__init__()
75
+ self.channels = channels
76
+ self.out_channels = out_channels or channels
77
+ self.use_conv = use_conv
78
+ self.padding = padding
79
+ stride = 2
80
+ self.name = name
81
+
82
+ if use_conv:
83
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
84
+ else:
85
+ raise NotImplementedError
86
+
87
+ def forward(self, hidden_states):
88
+ assert hidden_states.shape[1] == self.channels
89
+ if self.use_conv and self.padding == 0:
90
+ raise NotImplementedError
91
+
92
+ assert hidden_states.shape[1] == self.channels
93
+ hidden_states = self.conv(hidden_states)
94
+
95
+ return hidden_states
96
+
97
+
98
+ class ResnetBlock3D(nn.Module):
99
+ def __init__(
100
+ self,
101
+ *,
102
+ in_channels,
103
+ out_channels=None,
104
+ conv_shortcut=False,
105
+ dropout=0.0,
106
+ temb_channels=512,
107
+ groups=32,
108
+ groups_out=None,
109
+ pre_norm=True,
110
+ eps=1e-6,
111
+ non_linearity="swish",
112
+ time_embedding_norm="default",
113
+ output_scale_factor=1.0,
114
+ use_in_shortcut=None,
115
+ ):
116
+ super().__init__()
117
+ self.pre_norm = pre_norm
118
+ self.pre_norm = True
119
+ self.in_channels = in_channels
120
+ out_channels = in_channels if out_channels is None else out_channels
121
+ self.out_channels = out_channels
122
+ self.use_conv_shortcut = conv_shortcut
123
+ self.time_embedding_norm = time_embedding_norm
124
+ self.output_scale_factor = output_scale_factor
125
+
126
+ if groups_out is None:
127
+ groups_out = groups
128
+
129
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
130
+
131
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
132
+
133
+ if temb_channels is not None:
134
+ if self.time_embedding_norm == "default":
135
+ time_emb_proj_out_channels = out_channels
136
+ elif self.time_embedding_norm == "scale_shift":
137
+ time_emb_proj_out_channels = out_channels * 2
138
+ else:
139
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
140
+
141
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
142
+ else:
143
+ self.time_emb_proj = None
144
+
145
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
146
+ self.dropout = torch.nn.Dropout(dropout)
147
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
148
+
149
+ if non_linearity == "swish":
150
+ self.nonlinearity = lambda x: F.silu(x)
151
+ elif non_linearity == "mish":
152
+ self.nonlinearity = Mish()
153
+ elif non_linearity == "silu":
154
+ self.nonlinearity = nn.SiLU()
155
+
156
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
157
+
158
+ self.conv_shortcut = None
159
+ if self.use_in_shortcut:
160
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
161
+
162
+ def forward(self, input_tensor, temb):
163
+ hidden_states = input_tensor
164
+
165
+ hidden_states = self.norm1(hidden_states)
166
+ hidden_states = self.nonlinearity(hidden_states)
167
+
168
+ hidden_states = self.conv1(hidden_states)
169
+
170
+ if temb is not None:
171
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
172
+
173
+ if temb is not None and self.time_embedding_norm == "default":
174
+ hidden_states = hidden_states + temb
175
+
176
+ hidden_states = self.norm2(hidden_states)
177
+
178
+ if temb is not None and self.time_embedding_norm == "scale_shift":
179
+ scale, shift = torch.chunk(temb, 2, dim=1)
180
+ hidden_states = hidden_states * (1 + scale) + shift
181
+
182
+ hidden_states = self.nonlinearity(hidden_states)
183
+
184
+ hidden_states = self.dropout(hidden_states)
185
+ hidden_states = self.conv2(hidden_states)
186
+
187
+ if self.conv_shortcut is not None:
188
+ input_tensor = self.conv_shortcut(input_tensor)
189
+
190
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
191
+
192
+ return output_tensor
193
+
194
+
195
+ class Mish(torch.nn.Module):
196
+ def forward(self, hidden_states):
197
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
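End of resnet.py. `InflatedConv3d` is the trick that lets the 3D blocks reuse pretrained 2D Stable Diffusion weights: it is an unmodified `nn.Conv2d` whose forward folds the frame axis into the batch, convolves each frame, and unfolds again. A tiny sketch of that reshaping, with made-up sizes:

# Sketch of frame-wise 2D convolution over a 5D video tensor (illustrative sizes).
import torch
import torch.nn as nn
from einops import rearrange

conv2d = nn.Conv2d(4, 8, kernel_size=3, padding=1)

video = torch.randn(2, 4, 16, 32, 32)                  # b c f h w
frames = rearrange(video, "b c f h w -> (b f) c h w")  # fold frames into the batch
out = conv2d(frames)                                   # plain 2D conv per frame
out = rearrange(out, "(b f) c h w -> b c f h w", f=16)
print(out.shape)                                       # torch.Size([2, 8, 16, 32, 32])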
animatediff/models/unet.py ADDED
@@ -0,0 +1,572 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import os
7
+ import json
8
+ import pdb
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.utils.checkpoint
13
+ try:
14
+ from diffusers.models.cross_attention import AttnProcessor
15
+ except:
16
+ from diffusers.models.attention_processor import AttnProcessor
17
+ from typing import Dict
18
+
19
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
20
+ from diffusers.models import ModelMixin
21
+ from diffusers.loaders import UNet2DConditionLoadersMixin
22
+ from diffusers.utils import BaseOutput, logging
23
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
24
+ from .unet_blocks import (
25
+ CrossAttnDownBlock3D,
26
+ CrossAttnUpBlock3D,
27
+ DownBlock3D,
28
+ UNetMidBlock3DCrossAttn,
29
+ UpBlock3D,
30
+ get_down_block,
31
+ get_up_block,
32
+ )
33
+ from .resnet import InflatedConv3d
34
+ from .motion_module import VersatileAttention
35
+ def zero_module(module):
36
+ # Zero out the parameters of a module and return it.
37
+ for p in module.parameters():
38
+ p.detach().zero_()
39
+ return module
40
+
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+
44
+ @dataclass
45
+ class UNet3DConditionOutput(BaseOutput):
46
+ sample: torch.FloatTensor
47
+
48
+
49
+ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
50
+ _supports_gradient_checkpointing = True
51
+
52
+ @register_to_config
53
+ def __init__(
54
+ self,
55
+ sample_size: Optional[int] = None,
56
+ in_channels: int = 4,
57
+ out_channels: int = 4,
58
+ center_input_sample: bool = False,
59
+ flip_sin_to_cos: bool = True,
60
+ freq_shift: int = 0,
61
+ down_block_types: Tuple[str] = (
62
+ "CrossAttnDownBlock3D",
63
+ "CrossAttnDownBlock3D",
64
+ "CrossAttnDownBlock3D",
65
+ "DownBlock3D",
66
+ ),
67
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
68
+ up_block_types: Tuple[str] = (
69
+ "UpBlock3D",
70
+ "CrossAttnUpBlock3D",
71
+ "CrossAttnUpBlock3D",
72
+ "CrossAttnUpBlock3D"
73
+ ),
74
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
75
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
76
+ layers_per_block: int = 2,
77
+ downsample_padding: int = 1,
78
+ mid_block_scale_factor: float = 1,
79
+ act_fn: str = "silu",
80
+ norm_num_groups: int = 32,
81
+ norm_eps: float = 1e-5,
82
+ cross_attention_dim: int = 1280,
83
+ attention_head_dim: Union[int, Tuple[int]] = 8,
84
+ dual_cross_attention: bool = False,
85
+ use_linear_projection: bool = False,
86
+ class_embed_type: Optional[str] = None,
87
+ num_class_embeds: Optional[int] = None,
88
+ upcast_attention: bool = False,
89
+ resnet_time_scale_shift: str = "default",
90
+
91
+ # Additional
92
+ use_motion_module = True,
93
+ motion_module_resolutions = ( 1,2,4,8 ),
94
+ motion_module_mid_block = False,
95
+ motion_module_decoder_only = False,
96
+ motion_module_type = None,
97
+ motion_module_kwargs = {},
98
+ unet_use_cross_frame_attention = None,
99
+ unet_use_temporal_attention = None,
100
+
101
+ ):
102
+ super().__init__()
103
+
104
+ self.sample_size = sample_size
105
+ time_embed_dim = block_out_channels[0] * 4
106
+
107
+ # Image to Video Conv
108
+ # input
109
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
110
+
111
+ # time
112
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
113
+ timestep_input_dim = block_out_channels[0]
114
+
115
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
116
+
117
+ # class embedding
118
+ if class_embed_type is None and num_class_embeds is not None:
119
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
120
+ elif class_embed_type == "timestep":
121
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
122
+ elif class_embed_type == "identity":
123
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
124
+ else:
125
+ self.class_embedding = None
126
+
127
+ self.down_blocks = nn.ModuleList([])
128
+ self.mid_block = None
129
+ self.up_blocks = nn.ModuleList([])
130
+
131
+ if isinstance(only_cross_attention, bool):
132
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
133
+
134
+ if isinstance(attention_head_dim, int):
135
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
136
+
137
+ # down
138
+ output_channel = block_out_channels[0]
139
+ for i, down_block_type in enumerate(down_block_types):
140
+ res = 2 ** i
141
+ input_channel = output_channel
142
+ output_channel = block_out_channels[i]
143
+ is_final_block = i == len(block_out_channels) - 1
144
+
145
+ down_block = get_down_block(
146
+ down_block_type,
147
+ num_layers=layers_per_block,
148
+ in_channels=input_channel,
149
+ out_channels=output_channel,
150
+ temb_channels=time_embed_dim,
151
+ add_downsample=not is_final_block,
152
+ resnet_eps=norm_eps,
153
+ resnet_act_fn=act_fn,
154
+ resnet_groups=norm_num_groups,
155
+ cross_attention_dim=cross_attention_dim,
156
+ attn_num_head_channels=attention_head_dim[i],
157
+ downsample_padding=downsample_padding,
158
+ dual_cross_attention=dual_cross_attention,
159
+ use_linear_projection=use_linear_projection,
160
+ only_cross_attention=only_cross_attention[i],
161
+ upcast_attention=upcast_attention,
162
+ resnet_time_scale_shift=resnet_time_scale_shift,
163
+
164
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
165
+ unet_use_temporal_attention=unet_use_temporal_attention,
166
+
167
+ use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
168
+ motion_module_type=motion_module_type,
169
+ motion_module_kwargs=motion_module_kwargs,
170
+ )
171
+ self.down_blocks.append(down_block)
172
+
173
+ # mid
174
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
175
+ self.mid_block = UNetMidBlock3DCrossAttn(
176
+ in_channels=block_out_channels[-1],
177
+ temb_channels=time_embed_dim,
178
+ resnet_eps=norm_eps,
179
+ resnet_act_fn=act_fn,
180
+ output_scale_factor=mid_block_scale_factor,
181
+ resnet_time_scale_shift=resnet_time_scale_shift,
182
+ cross_attention_dim=cross_attention_dim,
183
+ attn_num_head_channels=attention_head_dim[-1],
184
+ resnet_groups=norm_num_groups,
185
+ dual_cross_attention=dual_cross_attention,
186
+ use_linear_projection=use_linear_projection,
187
+ upcast_attention=upcast_attention,
188
+
189
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
190
+ unet_use_temporal_attention=unet_use_temporal_attention,
191
+
192
+ use_motion_module=use_motion_module and motion_module_mid_block,
193
+ motion_module_type=motion_module_type,
194
+ motion_module_kwargs=motion_module_kwargs,
195
+ )
196
+ else:
197
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
198
+
199
+ # count how many layers upsample the videos
200
+ self.num_upsamplers = 0
201
+
202
+ # up
203
+ reversed_block_out_channels = list(reversed(block_out_channels))
204
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
205
+ only_cross_attention = list(reversed(only_cross_attention))
206
+ output_channel = reversed_block_out_channels[0]
207
+ for i, up_block_type in enumerate(up_block_types):
208
+ res = 2 ** (3 - i)
209
+ is_final_block = i == len(block_out_channels) - 1
210
+
211
+ prev_output_channel = output_channel
212
+ output_channel = reversed_block_out_channels[i]
213
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
214
+
215
+ # add upsample block for all BUT final layer
216
+ if not is_final_block:
217
+ add_upsample = True
218
+ self.num_upsamplers += 1
219
+ else:
220
+ add_upsample = False
221
+
222
+ up_block = get_up_block(
223
+ up_block_type,
224
+ num_layers=layers_per_block + 1,
225
+ in_channels=input_channel,
226
+ out_channels=output_channel,
227
+ prev_output_channel=prev_output_channel,
228
+ temb_channels=time_embed_dim,
229
+ add_upsample=add_upsample,
230
+ resnet_eps=norm_eps,
231
+ resnet_act_fn=act_fn,
232
+ resnet_groups=norm_num_groups,
233
+ cross_attention_dim=cross_attention_dim,
234
+ attn_num_head_channels=reversed_attention_head_dim[i],
235
+ dual_cross_attention=dual_cross_attention,
236
+ use_linear_projection=use_linear_projection,
237
+ only_cross_attention=only_cross_attention[i],
238
+ upcast_attention=upcast_attention,
239
+ resnet_time_scale_shift=resnet_time_scale_shift,
240
+
241
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
242
+ unet_use_temporal_attention=unet_use_temporal_attention,
243
+
244
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
245
+ motion_module_type=motion_module_type,
246
+ motion_module_kwargs=motion_module_kwargs,
247
+ )
248
+ self.up_blocks.append(up_block)
249
+ prev_output_channel = output_channel
250
+
251
+ # out
252
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
253
+ self.conv_act = nn.SiLU()
254
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
255
+
256
+ @property
257
+ def attn_processors(self) -> Dict[str, AttnProcessor]:
258
+ r"""
259
+ Returns:
260
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
261
+ indexed by their weight names.
262
+ """
263
+ # set recursively
264
+ processors = {}
265
+
266
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
267
+ if hasattr(module, "set_processor"):
268
+ processors[f"{name}.processor"] = module.processor
269
+
270
+ for sub_name, child in module.named_children():
271
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
272
+
273
+ return processors
274
+
275
+ for name, module in self.named_children():
276
+ fn_recursive_add_processors(name, module, processors)
277
+
278
+ return processors
279
+
280
+ def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
281
+ r"""
282
+ Parameters:
283
+ processor (`dict` of `AttnProcessor` or `AttnProcessor`):
284
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
285
+ of **all** `CrossAttention` layers.
286
+ In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.
287
+
288
+ """
289
+ count = len(self.attn_processors.keys())
290
+
291
+ if isinstance(processor, dict) and len(processor) != count:
292
+ raise ValueError(
293
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
294
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
295
+ )
296
+
297
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
298
+ if hasattr(module, "set_processor"):
299
+ if not isinstance(processor, dict):
300
+ print(f'Set {module}')
301
+ module.set_processor(processor)
302
+ else:
303
+ print(f'Set {module}')
304
+ module.set_processor(processor.pop(f"{name}.processor"))
305
+
306
+ for sub_name, child in module.named_children():
307
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
308
+
309
+ for name, module in self.named_children():
310
+ fn_recursive_attn_processor(name, module, processor)
311
+
312
+ def set_attention_slice(self, slice_size):
313
+ r"""
314
+ Enable sliced attention computation.
315
+
316
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
317
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
318
+
319
+ Args:
320
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
321
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
322
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
323
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
324
+ must be a multiple of `slice_size`.
325
+ """
326
+ sliceable_head_dims = []
327
+
328
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
329
+ if hasattr(module, "set_attention_slice"):
330
+ sliceable_head_dims.append(module.sliceable_head_dim)
331
+
332
+ for child in module.children():
333
+ fn_recursive_retrieve_slicable_dims(child)
334
+
335
+ # retrieve number of attention layers
336
+ for module in self.children():
337
+ fn_recursive_retrieve_slicable_dims(module)
338
+
339
+ num_slicable_layers = len(sliceable_head_dims)
340
+
341
+ if slice_size == "auto":
342
+ # half the attention head size is usually a good trade-off between
343
+ # speed and memory
344
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
345
+ elif slice_size == "max":
346
+ # make smallest slice possible
347
+ slice_size = num_slicable_layers * [1]
348
+
349
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
350
+
351
+ if len(slice_size) != len(sliceable_head_dims):
352
+ raise ValueError(
353
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
354
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
355
+ )
356
+
357
+ for i in range(len(slice_size)):
358
+ size = slice_size[i]
359
+ dim = sliceable_head_dims[i]
360
+ if size is not None and size > dim:
361
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
362
+
363
+ # Recursively walk through all the children.
364
+ # Any children which exposes the set_attention_slice method
365
+ # gets the message
366
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
367
+ if hasattr(module, "set_attention_slice"):
368
+ module.set_attention_slice(slice_size.pop())
369
+
370
+ for child in module.children():
371
+ fn_recursive_set_attention_slice(child, slice_size)
372
+
373
+ reversed_slice_size = list(reversed(slice_size))
374
+ for module in self.children():
375
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
376
+
377
+ def _set_gradient_checkpointing(self, module, value=False):
378
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
379
+ module.gradient_checkpointing = value
380
+
381
+ def forward(
382
+ self,
383
+ sample: torch.FloatTensor,
384
+ mask_sample: torch.FloatTensor,
385
+ masked_sample: torch.FloatTensor,
386
+ timestep: Union[torch.Tensor, float, int],
387
+ encoder_hidden_states: torch.Tensor,
388
+ class_labels: Optional[torch.Tensor] = None,
389
+ attention_mask: Optional[torch.Tensor] = None,
390
+ image_embeds: Optional[torch.Tensor] = None,
391
+ return_dict: bool = True,
392
+ ) -> Union[UNet3DConditionOutput, Tuple]:
393
+ r"""
394
+ Args:
395
+ sample (`torch.FloatTensor`): (batch, channel, frames, height, width) noisy inputs tensor
396
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
397
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
398
+ return_dict (`bool`, *optional*, defaults to `True`):
399
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
400
+
401
+ Returns:
402
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
403
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
404
+ returning a tuple, the first element is the sample tensor.
405
+ """
406
+ # image to video b c f h w
407
+ sample = torch.cat([sample, mask_sample, masked_sample], dim=1).to(sample.device)
408
+
409
+ # By default samples have to be at least a multiple of the overall upsampling factor.
410
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
411
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
412
+ # on the fly if necessary.
413
+
414
+ default_overall_up_factor = 2**self.num_upsamplers
415
+
416
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
417
+ forward_upsample_size = False
418
+ upsample_size = None
419
+
420
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
421
+ logger.info("Forward upsample size to force interpolation output size.")
422
+ forward_upsample_size = True
423
+
424
+ # prepare attention_mask
425
+ if attention_mask is not None:
426
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * - 10000.0
427
+ attention_mask = attention_mask.unsqueeze(1)
428
+
429
+ # center input if necessary
430
+ if self.config.center_input_sample:
431
+ sample = 2 * sample - 1.0
432
+
433
+ # time
434
+ timesteps = timestep
435
+ if not torch.is_tensor(timesteps):
436
+ # This would be a good case for the `match` statement (Python 3.10+)
437
+ is_mps = sample.device.type == "mps"
438
+ if isinstance(timestep, float):
439
+ dtype = torch.float32 if is_mps else torch.float64
440
+ else:
441
+ dtype = torch.int32 if is_mps else torch.int64
442
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
443
+ elif len(timesteps.shape) == 0:
444
+ timesteps = timesteps[None].to(sample.device)
445
+
446
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
447
+ timesteps = timesteps.expand(sample.shape[0])
448
+
449
+ t_emb = self.time_proj(timesteps)
450
+
451
+ # timesteps does not contain any weights and will always return f32 tensors
452
+ # but time_embedding might actually be running in fp16. so we need to cast here.
453
+ # there might be better ways to encapsulate this.
454
+ t_emb = t_emb.to(dtype=self.dtype)
455
+ emb = self.time_embedding(t_emb)
456
+
457
+ if self.class_embedding is not None:
458
+ if class_labels is None:
459
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
460
+
461
+ if self.config.class_embed_type == "timestep":
462
+ class_labels = self.time_proj(class_labels)
463
+
464
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
465
+ emb = emb + class_emb
466
+
467
+ # prepare for ip-adapter
468
+ if image_embeds is not None:
469
+ image_embeds = self.encoder_hid_proj(
470
+ image_embeds).to(encoder_hidden_states.dtype)
471
+ encoder_hidden_states = torch.cat(
472
+ [encoder_hidden_states, image_embeds], dim=1)
473
+
474
+ # pre-process
475
+ # b c f h w
476
+ # 2 4 16 64 64
477
+ sample = self.conv_in(sample)
478
+ # down
479
+ down_block_res_samples = (sample,)
480
+ for downsample_block in self.down_blocks:
481
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
482
+ sample, res_samples = downsample_block(
483
+ hidden_states=sample,
484
+ temb=emb,
485
+ encoder_hidden_states=encoder_hidden_states,
486
+ attention_mask=attention_mask,
487
+ )
488
+ else:
489
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
490
+ down_block_res_samples += res_samples
491
+
492
+ # mid
493
+ sample = self.mid_block(
494
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
495
+ )
496
+
497
+ # up
498
+ for i, upsample_block in enumerate(self.up_blocks):
499
+ is_final_block = i == len(self.up_blocks) - 1
500
+
501
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
502
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
503
+
504
+ # if we have not reached the final block and need to forward the
505
+ # upsample size, we do it here
506
+ if not is_final_block and forward_upsample_size:
507
+ upsample_size = down_block_res_samples[-1].shape[2:]
508
+
509
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
510
+ sample = upsample_block(
511
+ hidden_states=sample,
512
+ temb=emb,
513
+ res_hidden_states_tuple=res_samples,
514
+ encoder_hidden_states=encoder_hidden_states,
515
+ upsample_size=upsample_size,
516
+ attention_mask=attention_mask,
517
+ )
518
+ else:
519
+ sample = upsample_block(
520
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
521
+ )
522
+
523
+ # post-process
524
+ sample = self.conv_norm_out(sample)
525
+ sample = self.conv_act(sample)
526
+ sample = self.conv_out(sample)
527
+
528
+ if not return_dict:
529
+ return (sample,)
530
+
531
+ return UNet3DConditionOutput(sample=sample)
532
+
533
+ @classmethod
534
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
535
+ if subfolder is not None:
536
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
537
+ print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")
538
+
539
+ config_file = os.path.join(pretrained_model_path, 'config.json')
540
+ if not os.path.isfile(config_file):
541
+ raise RuntimeError(f"{config_file} does not exist")
542
+ with open(config_file, "r") as f:
543
+ config = json.load(f)
544
+ config["_class_name"] = cls.__name__
545
+ config["down_block_types"] = [
546
+ "CrossAttnDownBlock3D",
547
+ "CrossAttnDownBlock3D",
548
+ "CrossAttnDownBlock3D",
549
+ "DownBlock3D"
550
+ ]
551
+ config["up_block_types"] = [
552
+ "UpBlock3D",
553
+ "CrossAttnUpBlock3D",
554
+ "CrossAttnUpBlock3D",
555
+ "CrossAttnUpBlock3D"
556
+ ]
557
+
558
+ from diffusers.utils import WEIGHTS_NAME
559
+ model = cls.from_config(config, **unet_additional_kwargs)
560
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
561
+ if not os.path.isfile(model_file):
562
+ raise RuntimeError(f"{model_file} does not exist")
563
+ state_dict = torch.load(model_file, map_location="cpu")
564
+
565
+ m, u = model.load_state_dict(state_dict, strict=False)
566
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
567
+ # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
568
+
569
+ params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
570
+ print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
571
+
572
+ return model
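`from_pretrained_2d` above reads a standard 2D UNet `config.json`, swaps the block types for their 3D counterparts, and loads the 2D weights with `strict=False`, so only the newly added motion-module parameters keep their fresh initialization (they show up as the printed missing keys). A hedged usage sketch; the checkpoint path and the `unet_additional_kwargs` values below are placeholders, not settings taken from this repository's configs.

# Hedged usage sketch; path and kwargs are placeholders.
from animatediff.models.unet import UNet3DConditionModel

# Assumed: a local Stable Diffusion checkpoint in diffusers layout
# (e.g. cloned from huggingface.co/runwayml/stable-diffusion-v1-5).
unet = UNet3DConditionModel.from_pretrained_2d(
    "models/stable-diffusion-v1-5",           # placeholder path
    subfolder="unet",
    unet_additional_kwargs={
        "use_motion_module": True,            # illustrative overrides; the repo's configs hold the real values
        "motion_module_type": "Vanilla",
        "motion_module_kwargs": {},
    },
)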
animatediff/models/unet_blocks.py ADDED
@@ -0,0 +1,733 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from .attention import Transformer3DModel
7
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
8
+ from .motion_module import get_motion_module
9
+
10
+ import pdb
11
+
12
+ def get_down_block(
13
+ down_block_type,
14
+ num_layers,
15
+ in_channels,
16
+ out_channels,
17
+ temb_channels,
18
+ add_downsample,
19
+ resnet_eps,
20
+ resnet_act_fn,
21
+ attn_num_head_channels,
22
+ resnet_groups=None,
23
+ cross_attention_dim=None,
24
+ downsample_padding=None,
25
+ dual_cross_attention=False,
26
+ use_linear_projection=False,
27
+ only_cross_attention=False,
28
+ upcast_attention=False,
29
+ resnet_time_scale_shift="default",
30
+
31
+ unet_use_cross_frame_attention=None,
32
+ unet_use_temporal_attention=None,
33
+
34
+ use_motion_module=None,
35
+
36
+ motion_module_type=None,
37
+ motion_module_kwargs=None,
38
+ ):
39
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
40
+ if down_block_type == "DownBlock3D":
41
+ return DownBlock3D(
42
+ num_layers=num_layers,
43
+ in_channels=in_channels,
44
+ out_channels=out_channels,
45
+ temb_channels=temb_channels,
46
+ add_downsample=add_downsample,
47
+ resnet_eps=resnet_eps,
48
+ resnet_act_fn=resnet_act_fn,
49
+ resnet_groups=resnet_groups,
50
+ downsample_padding=downsample_padding,
51
+ resnet_time_scale_shift=resnet_time_scale_shift,
52
+
53
+ use_motion_module=use_motion_module,
54
+ motion_module_type=motion_module_type,
55
+ motion_module_kwargs=motion_module_kwargs,
56
+ )
57
+ elif down_block_type == "CrossAttnDownBlock3D":
58
+ if cross_attention_dim is None:
59
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
60
+ return CrossAttnDownBlock3D(
61
+ num_layers=num_layers,
62
+ in_channels=in_channels,
63
+ out_channels=out_channels,
64
+ temb_channels=temb_channels,
65
+ add_downsample=add_downsample,
66
+ resnet_eps=resnet_eps,
67
+ resnet_act_fn=resnet_act_fn,
68
+ resnet_groups=resnet_groups,
69
+ downsample_padding=downsample_padding,
70
+ cross_attention_dim=cross_attention_dim,
71
+ attn_num_head_channels=attn_num_head_channels,
72
+ dual_cross_attention=dual_cross_attention,
73
+ use_linear_projection=use_linear_projection,
74
+ only_cross_attention=only_cross_attention,
75
+ upcast_attention=upcast_attention,
76
+ resnet_time_scale_shift=resnet_time_scale_shift,
77
+
78
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
79
+ unet_use_temporal_attention=unet_use_temporal_attention,
80
+
81
+ use_motion_module=use_motion_module,
82
+ motion_module_type=motion_module_type,
83
+ motion_module_kwargs=motion_module_kwargs,
84
+ )
85
+ raise ValueError(f"{down_block_type} does not exist.")
86
+
87
+
88
+ def get_up_block(
89
+ up_block_type,
90
+ num_layers,
91
+ in_channels,
92
+ out_channels,
93
+ prev_output_channel,
94
+ temb_channels,
95
+ add_upsample,
96
+ resnet_eps,
97
+ resnet_act_fn,
98
+ attn_num_head_channels,
99
+ resnet_groups=None,
100
+ cross_attention_dim=None,
101
+ dual_cross_attention=False,
102
+ use_linear_projection=False,
103
+ only_cross_attention=False,
104
+ upcast_attention=False,
105
+ resnet_time_scale_shift="default",
106
+
107
+ unet_use_cross_frame_attention=None,
108
+ unet_use_temporal_attention=None,
109
+
110
+ use_motion_module=None,
111
+ motion_module_type=None,
112
+ motion_module_kwargs=None,
113
+ ):
114
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
115
+ if up_block_type == "UpBlock3D":
116
+ return UpBlock3D(
117
+ num_layers=num_layers,
118
+ in_channels=in_channels,
119
+ out_channels=out_channels,
120
+ prev_output_channel=prev_output_channel,
121
+ temb_channels=temb_channels,
122
+ add_upsample=add_upsample,
123
+ resnet_eps=resnet_eps,
124
+ resnet_act_fn=resnet_act_fn,
125
+ resnet_groups=resnet_groups,
126
+ resnet_time_scale_shift=resnet_time_scale_shift,
127
+
128
+ use_motion_module=use_motion_module,
129
+ motion_module_type=motion_module_type,
130
+ motion_module_kwargs=motion_module_kwargs,
131
+ )
132
+ elif up_block_type == "CrossAttnUpBlock3D":
133
+ if cross_attention_dim is None:
134
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
135
+ return CrossAttnUpBlock3D(
136
+ num_layers=num_layers,
137
+ in_channels=in_channels,
138
+ out_channels=out_channels,
139
+ prev_output_channel=prev_output_channel,
140
+ temb_channels=temb_channels,
141
+ add_upsample=add_upsample,
142
+ resnet_eps=resnet_eps,
143
+ resnet_act_fn=resnet_act_fn,
144
+ resnet_groups=resnet_groups,
145
+ cross_attention_dim=cross_attention_dim,
146
+ attn_num_head_channels=attn_num_head_channels,
147
+ dual_cross_attention=dual_cross_attention,
148
+ use_linear_projection=use_linear_projection,
149
+ only_cross_attention=only_cross_attention,
150
+ upcast_attention=upcast_attention,
151
+ resnet_time_scale_shift=resnet_time_scale_shift,
152
+
153
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
154
+ unet_use_temporal_attention=unet_use_temporal_attention,
155
+
156
+ use_motion_module=use_motion_module,
157
+ motion_module_type=motion_module_type,
158
+ motion_module_kwargs=motion_module_kwargs,
159
+ )
160
+ raise ValueError(f"{up_block_type} does not exist.")
161
+
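+ # Illustrative factory call (a minimal sketch, not executed anywhere in this module; the
+ # real caller is UNet3DConditionModel, which fills every argument from its config). The
+ # concrete channel counts, cross_attention_dim and the "Vanilla" motion_module_type below
+ # are illustrative assumptions, not values taken from this file:
+ #
+ # up_block = get_up_block(
+ #     "CrossAttnUpBlock3D",
+ #     num_layers=3, in_channels=320, out_channels=640, prev_output_channel=640,
+ #     temb_channels=1280, add_upsample=True, resnet_eps=1e-5, resnet_act_fn="silu",
+ #     attn_num_head_channels=8, cross_attention_dim=768,
+ #     use_motion_module=True, motion_module_type="Vanilla", motion_module_kwargs={},
+ # )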
162
+
163
+ class UNetMidBlock3DCrossAttn(nn.Module):
164
+ def __init__(
165
+ self,
166
+ in_channels: int,
167
+ temb_channels: int,
168
+ dropout: float = 0.0,
169
+ num_layers: int = 1,
170
+ resnet_eps: float = 1e-6,
171
+ resnet_time_scale_shift: str = "default",
172
+ resnet_act_fn: str = "swish",
173
+ resnet_groups: int = 32,
174
+ resnet_pre_norm: bool = True,
175
+ attn_num_head_channels=1,
176
+ output_scale_factor=1.0,
177
+ cross_attention_dim=1280,
178
+ dual_cross_attention=False,
179
+ use_linear_projection=False,
180
+ upcast_attention=False,
181
+
182
+ unet_use_cross_frame_attention=None,
183
+ unet_use_temporal_attention=None,
184
+
185
+ use_motion_module=None,
186
+
187
+ motion_module_type=None,
188
+ motion_module_kwargs=None,
189
+ ):
190
+ super().__init__()
191
+
192
+ self.has_cross_attention = True
193
+ self.attn_num_head_channels = attn_num_head_channels
194
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
195
+
196
+ # there is always at least one resnet
197
+ resnets = [
198
+ ResnetBlock3D(
199
+ in_channels=in_channels,
200
+ out_channels=in_channels,
201
+ temb_channels=temb_channels,
202
+ eps=resnet_eps,
203
+ groups=resnet_groups,
204
+ dropout=dropout,
205
+ time_embedding_norm=resnet_time_scale_shift,
206
+ non_linearity=resnet_act_fn,
207
+ output_scale_factor=output_scale_factor,
208
+ pre_norm=resnet_pre_norm,
209
+ )
210
+ ]
211
+ attentions = []
212
+ motion_modules = []
213
+
214
+ for _ in range(num_layers):
215
+ if dual_cross_attention:
216
+ raise NotImplementedError
217
+ attentions.append(
218
+ Transformer3DModel(
219
+ attn_num_head_channels,
220
+ in_channels // attn_num_head_channels,
221
+ in_channels=in_channels,
222
+ num_layers=1,
223
+ cross_attention_dim=cross_attention_dim,
224
+ norm_num_groups=resnet_groups,
225
+ use_linear_projection=use_linear_projection,
226
+ upcast_attention=upcast_attention,
227
+
228
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
229
+ unet_use_temporal_attention=unet_use_temporal_attention,
230
+ )
231
+ )
232
+ motion_modules.append(
233
+ get_motion_module(
234
+ in_channels=in_channels,
235
+ motion_module_type=motion_module_type,
236
+ motion_module_kwargs=motion_module_kwargs,
237
+ ) if use_motion_module else None
238
+ )
239
+ resnets.append(
240
+ ResnetBlock3D(
241
+ in_channels=in_channels,
242
+ out_channels=in_channels,
243
+ temb_channels=temb_channels,
244
+ eps=resnet_eps,
245
+ groups=resnet_groups,
246
+ dropout=dropout,
247
+ time_embedding_norm=resnet_time_scale_shift,
248
+ non_linearity=resnet_act_fn,
249
+ output_scale_factor=output_scale_factor,
250
+ pre_norm=resnet_pre_norm,
251
+ )
252
+ )
253
+
254
+ self.attentions = nn.ModuleList(attentions)
255
+ self.resnets = nn.ModuleList(resnets)
256
+ self.motion_modules = nn.ModuleList(motion_modules)
257
+
258
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
259
+ hidden_states = self.resnets[0](hidden_states, temb)
260
+ for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
261
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
262
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
263
+ hidden_states = resnet(hidden_states, temb)
264
+
265
+ return hidden_states
266
+
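+ # Shape note (assumed from the 3D blocks in this file and the rearranges used elsewhere in
+ # the repo): hidden_states moves through the mid block as a 5-D tensor
+ # (batch, channels, frames, height, width); Transformer3DModel presumably applies spatial
+ # attention per frame, while the optional motion module attends across the frame axis.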
267
+
268
+ class CrossAttnDownBlock3D(nn.Module):
269
+ def __init__(
270
+ self,
271
+ in_channels: int,
272
+ out_channels: int,
273
+ temb_channels: int,
274
+ dropout: float = 0.0,
275
+ num_layers: int = 1,
276
+ resnet_eps: float = 1e-6,
277
+ resnet_time_scale_shift: str = "default",
278
+ resnet_act_fn: str = "swish",
279
+ resnet_groups: int = 32,
280
+ resnet_pre_norm: bool = True,
281
+ attn_num_head_channels=1,
282
+ cross_attention_dim=1280,
283
+ output_scale_factor=1.0,
284
+ downsample_padding=1,
285
+ add_downsample=True,
286
+ dual_cross_attention=False,
287
+ use_linear_projection=False,
288
+ only_cross_attention=False,
289
+ upcast_attention=False,
290
+
291
+ unet_use_cross_frame_attention=None,
292
+ unet_use_temporal_attention=None,
293
+
294
+ use_motion_module=None,
295
+
296
+ motion_module_type=None,
297
+ motion_module_kwargs=None,
298
+ ):
299
+ super().__init__()
300
+ resnets = []
301
+ attentions = []
302
+ motion_modules = []
303
+
304
+ self.has_cross_attention = True
305
+ self.attn_num_head_channels = attn_num_head_channels
306
+
307
+ for i in range(num_layers):
308
+ in_channels = in_channels if i == 0 else out_channels
309
+ resnets.append(
310
+ ResnetBlock3D(
311
+ in_channels=in_channels,
312
+ out_channels=out_channels,
313
+ temb_channels=temb_channels,
314
+ eps=resnet_eps,
315
+ groups=resnet_groups,
316
+ dropout=dropout,
317
+ time_embedding_norm=resnet_time_scale_shift,
318
+ non_linearity=resnet_act_fn,
319
+ output_scale_factor=output_scale_factor,
320
+ pre_norm=resnet_pre_norm,
321
+ )
322
+ )
323
+ if dual_cross_attention:
324
+ raise NotImplementedError
325
+ attentions.append(
326
+ Transformer3DModel(
327
+ attn_num_head_channels,
328
+ out_channels // attn_num_head_channels,
329
+ in_channels=out_channels,
330
+ num_layers=1,
331
+ cross_attention_dim=cross_attention_dim,
332
+ norm_num_groups=resnet_groups,
333
+ use_linear_projection=use_linear_projection,
334
+ only_cross_attention=only_cross_attention,
335
+ upcast_attention=upcast_attention,
336
+
337
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
338
+ unet_use_temporal_attention=unet_use_temporal_attention,
339
+ )
340
+ )
341
+ motion_modules.append(
342
+ get_motion_module(
343
+ in_channels=out_channels,
344
+ motion_module_type=motion_module_type,
345
+ motion_module_kwargs=motion_module_kwargs,
346
+ ) if use_motion_module else None
347
+ )
348
+
349
+ self.attentions = nn.ModuleList(attentions)
350
+ self.resnets = nn.ModuleList(resnets)
351
+ self.motion_modules = nn.ModuleList(motion_modules)
352
+
353
+ if add_downsample:
354
+ self.downsamplers = nn.ModuleList(
355
+ [
356
+ Downsample3D(
357
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
358
+ )
359
+ ]
360
+ )
361
+ else:
362
+ self.downsamplers = None
363
+
364
+ self.gradient_checkpointing = False
365
+
366
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
367
+ output_states = ()
368
+
369
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
370
+ if self.training and self.gradient_checkpointing:
371
+
372
+ def create_custom_forward(module, return_dict=None):
373
+ def custom_forward(*inputs):
374
+ if return_dict is not None:
375
+ return module(*inputs, return_dict=return_dict)
376
+ else:
377
+ return module(*inputs)
378
+
379
+ return custom_forward
380
+
381
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
382
+ hidden_states = torch.utils.checkpoint.checkpoint(
383
+ create_custom_forward(attn, return_dict=False),
384
+ hidden_states,
385
+ encoder_hidden_states,
386
+ )[0]
387
+ if motion_module is not None:
388
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
389
+
390
+ else:
391
+ hidden_states = resnet(hidden_states, temb)
392
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
393
+
394
+ # add motion module
395
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
396
+
397
+ output_states += (hidden_states,)
398
+
399
+ if self.downsamplers is not None:
400
+ for downsampler in self.downsamplers:
401
+ hidden_states = downsampler(hidden_states)
402
+
403
+ output_states += (hidden_states,)
404
+
405
+ return hidden_states, output_states
406
+
407
+
408
+ class DownBlock3D(nn.Module):
409
+ def __init__(
410
+ self,
411
+ in_channels: int,
412
+ out_channels: int,
413
+ temb_channels: int,
414
+ dropout: float = 0.0,
415
+ num_layers: int = 1,
416
+ resnet_eps: float = 1e-6,
417
+ resnet_time_scale_shift: str = "default",
418
+ resnet_act_fn: str = "swish",
419
+ resnet_groups: int = 32,
420
+ resnet_pre_norm: bool = True,
421
+ output_scale_factor=1.0,
422
+ add_downsample=True,
423
+ downsample_padding=1,
424
+
425
+ use_motion_module=None,
426
+ motion_module_type=None,
427
+ motion_module_kwargs=None,
428
+ ):
429
+ super().__init__()
430
+ resnets = []
431
+ motion_modules = []
432
+
433
+ for i in range(num_layers):
434
+ in_channels = in_channels if i == 0 else out_channels
435
+ resnets.append(
436
+ ResnetBlock3D(
437
+ in_channels=in_channels,
438
+ out_channels=out_channels,
439
+ temb_channels=temb_channels,
440
+ eps=resnet_eps,
441
+ groups=resnet_groups,
442
+ dropout=dropout,
443
+ time_embedding_norm=resnet_time_scale_shift,
444
+ non_linearity=resnet_act_fn,
445
+ output_scale_factor=output_scale_factor,
446
+ pre_norm=resnet_pre_norm,
447
+ )
448
+ )
449
+ motion_modules.append(
450
+ get_motion_module(
451
+ in_channels=out_channels,
452
+ motion_module_type=motion_module_type,
453
+ motion_module_kwargs=motion_module_kwargs,
454
+ ) if use_motion_module else None
455
+ )
456
+
457
+ self.resnets = nn.ModuleList(resnets)
458
+ self.motion_modules = nn.ModuleList(motion_modules)
459
+
460
+ if add_downsample:
461
+ self.downsamplers = nn.ModuleList(
462
+ [
463
+ Downsample3D(
464
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
465
+ )
466
+ ]
467
+ )
468
+ else:
469
+ self.downsamplers = None
470
+
471
+ self.gradient_checkpointing = False
472
+
473
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
474
+ output_states = ()
475
+
476
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
477
+ if self.training and self.gradient_checkpointing:
478
+ def create_custom_forward(module):
479
+ def custom_forward(*inputs):
480
+ return module(*inputs)
481
+
482
+ return custom_forward
483
+
484
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
485
+ if motion_module is not None:
486
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
487
+ else:
488
+ hidden_states = resnet(hidden_states, temb)
489
+
490
+ # add motion module
491
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
492
+
493
+ output_states += (hidden_states,)
494
+
495
+ if self.downsamplers is not None:
496
+ for downsampler in self.downsamplers:
497
+ hidden_states = downsampler(hidden_states)
498
+
499
+ output_states += (hidden_states,)
500
+
501
+ return hidden_states, output_states
502
+
503
+
504
+ class CrossAttnUpBlock3D(nn.Module):
505
+ def __init__(
506
+ self,
507
+ in_channels: int,
508
+ out_channels: int,
509
+ prev_output_channel: int,
510
+ temb_channels: int,
511
+ dropout: float = 0.0,
512
+ num_layers: int = 1,
513
+ resnet_eps: float = 1e-6,
514
+ resnet_time_scale_shift: str = "default",
515
+ resnet_act_fn: str = "swish",
516
+ resnet_groups: int = 32,
517
+ resnet_pre_norm: bool = True,
518
+ attn_num_head_channels=1,
519
+ cross_attention_dim=1280,
520
+ output_scale_factor=1.0,
521
+ add_upsample=True,
522
+ dual_cross_attention=False,
523
+ use_linear_projection=False,
524
+ only_cross_attention=False,
525
+ upcast_attention=False,
526
+
527
+ unet_use_cross_frame_attention=None,
528
+ unet_use_temporal_attention=None,
529
+
530
+ use_motion_module=None,
531
+
532
+ motion_module_type=None,
533
+ motion_module_kwargs=None,
534
+ ):
535
+ super().__init__()
536
+ resnets = []
537
+ attentions = []
538
+ motion_modules = []
539
+
540
+ self.has_cross_attention = True
541
+ self.attn_num_head_channels = attn_num_head_channels
542
+
543
+ for i in range(num_layers):
544
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
545
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
546
+
547
+ resnets.append(
548
+ ResnetBlock3D(
549
+ in_channels=resnet_in_channels + res_skip_channels,
550
+ out_channels=out_channels,
551
+ temb_channels=temb_channels,
552
+ eps=resnet_eps,
553
+ groups=resnet_groups,
554
+ dropout=dropout,
555
+ time_embedding_norm=resnet_time_scale_shift,
556
+ non_linearity=resnet_act_fn,
557
+ output_scale_factor=output_scale_factor,
558
+ pre_norm=resnet_pre_norm,
559
+ )
560
+ )
561
+ if dual_cross_attention:
562
+ raise NotImplementedError
563
+ attentions.append(
564
+ Transformer3DModel(
565
+ attn_num_head_channels,
566
+ out_channels // attn_num_head_channels,
567
+ in_channels=out_channels,
568
+ num_layers=1,
569
+ cross_attention_dim=cross_attention_dim,
570
+ norm_num_groups=resnet_groups,
571
+ use_linear_projection=use_linear_projection,
572
+ only_cross_attention=only_cross_attention,
573
+ upcast_attention=upcast_attention,
574
+
575
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
576
+ unet_use_temporal_attention=unet_use_temporal_attention,
577
+ )
578
+ )
579
+ motion_modules.append(
580
+ get_motion_module(
581
+ in_channels=out_channels,
582
+ motion_module_type=motion_module_type,
583
+ motion_module_kwargs=motion_module_kwargs,
584
+ ) if use_motion_module else None
585
+ )
586
+
587
+ self.attentions = nn.ModuleList(attentions)
588
+ self.resnets = nn.ModuleList(resnets)
589
+ self.motion_modules = nn.ModuleList(motion_modules)
590
+
591
+ if add_upsample:
592
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
593
+ else:
594
+ self.upsamplers = None
595
+
596
+ self.gradient_checkpointing = False
597
+
598
+ def forward(
599
+ self,
600
+ hidden_states,
601
+ res_hidden_states_tuple,
602
+ temb=None,
603
+ encoder_hidden_states=None,
604
+ upsample_size=None,
605
+ attention_mask=None,
606
+ ):
607
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
608
+ # pop res hidden states
609
+ res_hidden_states = res_hidden_states_tuple[-1]
610
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
611
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
612
+
613
+ if self.training and self.gradient_checkpointing:
614
+
615
+ def create_custom_forward(module, return_dict=None):
616
+ def custom_forward(*inputs):
617
+ if return_dict is not None:
618
+ return module(*inputs, return_dict=return_dict)
619
+ else:
620
+ return module(*inputs)
621
+
622
+ return custom_forward
623
+
624
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
625
+ hidden_states = torch.utils.checkpoint.checkpoint(
626
+ create_custom_forward(attn, return_dict=False),
627
+ hidden_states,
628
+ encoder_hidden_states,
629
+ )[0]
630
+ if motion_module is not None:
631
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
632
+
633
+ else:
634
+ hidden_states = resnet(hidden_states, temb)
635
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
636
+
637
+ # add motion module
638
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
639
+
640
+ if self.upsamplers is not None:
641
+ for upsampler in self.upsamplers:
642
+ hidden_states = upsampler(hidden_states, upsample_size)
643
+
644
+ return hidden_states
645
+
646
+
647
+ class UpBlock3D(nn.Module):
648
+ def __init__(
649
+ self,
650
+ in_channels: int,
651
+ prev_output_channel: int,
652
+ out_channels: int,
653
+ temb_channels: int,
654
+ dropout: float = 0.0,
655
+ num_layers: int = 1,
656
+ resnet_eps: float = 1e-6,
657
+ resnet_time_scale_shift: str = "default",
658
+ resnet_act_fn: str = "swish",
659
+ resnet_groups: int = 32,
660
+ resnet_pre_norm: bool = True,
661
+ output_scale_factor=1.0,
662
+ add_upsample=True,
663
+
664
+ use_motion_module=None,
665
+ motion_module_type=None,
666
+ motion_module_kwargs=None,
667
+ ):
668
+ super().__init__()
669
+ resnets = []
670
+ motion_modules = []
671
+
672
+ for i in range(num_layers):
673
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
674
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
675
+
676
+ resnets.append(
677
+ ResnetBlock3D(
678
+ in_channels=resnet_in_channels + res_skip_channels,
679
+ out_channels=out_channels,
680
+ temb_channels=temb_channels,
681
+ eps=resnet_eps,
682
+ groups=resnet_groups,
683
+ dropout=dropout,
684
+ time_embedding_norm=resnet_time_scale_shift,
685
+ non_linearity=resnet_act_fn,
686
+ output_scale_factor=output_scale_factor,
687
+ pre_norm=resnet_pre_norm,
688
+ )
689
+ )
690
+ motion_modules.append(
691
+ get_motion_module(
692
+ in_channels=out_channels,
693
+ motion_module_type=motion_module_type,
694
+ motion_module_kwargs=motion_module_kwargs,
695
+ ) if use_motion_module else None
696
+ )
697
+
698
+ self.resnets = nn.ModuleList(resnets)
699
+ self.motion_modules = nn.ModuleList(motion_modules)
700
+
701
+ if add_upsample:
702
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
703
+ else:
704
+ self.upsamplers = None
705
+
706
+ self.gradient_checkpointing = False
707
+
708
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,):
709
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
710
+ # pop res hidden states
711
+ res_hidden_states = res_hidden_states_tuple[-1]
712
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
713
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
714
+
715
+ if self.training and self.gradient_checkpointing:
716
+ def create_custom_forward(module):
717
+ def custom_forward(*inputs):
718
+ return module(*inputs)
719
+
720
+ return custom_forward
721
+
722
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
723
+ if motion_module is not None:
724
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
725
+ else:
726
+ hidden_states = resnet(hidden_states, temb)
727
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
728
+
729
+ if self.upsamplers is not None:
730
+ for upsampler in self.upsamplers:
731
+ hidden_states = upsampler(hidden_states, upsample_size)
732
+
733
+ return hidden_states
animatediff/pipelines/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .i2v_pipeline import I2VPipeline
2
+ from .pipeline_animation import AnimationPipeline
3
+ from .validation_pipeline import ValidationPipeline
4
+
5
+ __all__ = ['I2VPipeline', 'AnimationPipeline', 'ValidationPipeline']
animatediff/pipelines/i2v_pipeline.py ADDED
@@ -0,0 +1,775 @@
1
+ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2
+ import inspect
3
+ import os.path as osp
4
+ from dataclasses import dataclass
5
+ from typing import Callable, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from diffusers.configuration_utils import FrozenDict
10
+ from diffusers.loaders import IPAdapterMixin
11
+ from diffusers.models import AutoencoderKL
12
+ from diffusers.pipelines import DiffusionPipeline
13
+ from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
14
+ EulerAncestralDiscreteScheduler,
15
+ EulerDiscreteScheduler, LMSDiscreteScheduler,
16
+ PNDMScheduler)
17
+ from diffusers.utils import (BaseOutput, deprecate, is_accelerate_available,
18
+ logging)
19
+ from diffusers.utils.import_utils import is_xformers_available
20
+ from einops import rearrange
21
+ from omegaconf import OmegaConf
22
+ from packaging import version
23
+ from safetensors import safe_open
24
+ from tqdm import tqdm
25
+ from transformers import (CLIPImageProcessor, CLIPTextModel, CLIPTokenizer,
26
+ CLIPVisionModelWithProjection)
27
+
28
+ from animatediff.models.resnet import InflatedConv3d
29
+ from animatediff.models.unet import UNet3DConditionModel
30
+ from animatediff.utils.convert_from_ckpt import (convert_ldm_clip_checkpoint,
31
+ convert_ldm_unet_checkpoint,
32
+ convert_ldm_vae_checkpoint)
33
+ from animatediff.utils.convert_lora_safetensor_to_diffusers import \
34
+ convert_lora_model_level
35
+ from animatediff.utils.util import prepare_mask_coef_by_statistics
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
+ DEFAULT_N_PROMPT = ('wrong white balance, dark, sketches,worst quality,'
41
+ 'low quality, deformed, distorted, disfigured, bad eyes, '
42
+ 'wrong lips,weird mouth, bad teeth, mutated hands and fingers, '
43
+ 'bad anatomy,wrong anatomy, amputation, extra limb, '
44
+ 'missing limb, floating,limbs, disconnected limbs, mutation, '
45
+ 'ugly, disgusting, bad_pictures, negative_hand-neg')
46
+
47
+
48
+ @dataclass
49
+ class AnimationPipelineOutput(BaseOutput):
50
+ videos: Union[torch.Tensor, np.ndarray]
51
+
52
+
53
+ class I2VPipeline(DiffusionPipeline, IPAdapterMixin):
54
+ _optional_components = []
55
+
56
+ def __init__(
57
+ self,
58
+ vae: AutoencoderKL,
59
+ text_encoder: CLIPTextModel,
60
+ tokenizer: CLIPTokenizer,
61
+ unet: UNet3DConditionModel,
62
+ scheduler: Union[
63
+ DDIMScheduler,
64
+ PNDMScheduler,
65
+ LMSDiscreteScheduler,
66
+ EulerDiscreteScheduler,
67
+ EulerAncestralDiscreteScheduler,
68
+ DPMSolverMultistepScheduler,
69
+ ],
70
+ feature_extractor: CLIPImageProcessor = None,
71
+ image_encoder: CLIPVisionModelWithProjection = None,
72
+ ):
73
+ super().__init__()
74
+
75
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
76
+ deprecation_message = (
77
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
78
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
79
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
80
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
81
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
82
+ " file"
83
+ )
84
+ deprecate("steps_offset!=1", "1.0.0",
85
+ deprecation_message, standard_warn=False)
86
+ new_config = dict(scheduler.config)
87
+ new_config["steps_offset"] = 1
88
+ scheduler._internal_dict = FrozenDict(new_config)
89
+
90
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
91
+ deprecation_message = (
92
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
93
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
94
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
95
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
96
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
97
+ )
98
+ deprecate("clip_sample not set", "1.0.0",
99
+ deprecation_message, standard_warn=False)
100
+ new_config = dict(scheduler.config)
101
+ new_config["clip_sample"] = False
102
+ scheduler._internal_dict = FrozenDict(new_config)
103
+
104
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
105
+ version.parse(unet.config._diffusers_version).base_version
106
+ ) < version.parse("0.9.0.dev0")
107
+ is_unet_sample_size_less_64 = hasattr(
108
+ unet.config, "sample_size") and unet.config.sample_size < 64
109
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
110
+ deprecation_message = (
111
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
112
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
113
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
114
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
115
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
116
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
117
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
118
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
119
+ " the `unet/config.json` file"
120
+ )
121
+ deprecate("sample_size<64", "1.0.0",
122
+ deprecation_message, standard_warn=False)
123
+ new_config = dict(unet.config)
124
+ new_config["sample_size"] = 64
125
+ unet._internal_dict = FrozenDict(new_config)
126
+
127
+ self.register_modules(
128
+ vae=vae,
129
+ text_encoder=text_encoder,
130
+ tokenizer=tokenizer,
131
+ unet=unet,
132
+ image_encoder=image_encoder,
133
+ feature_extractor=feature_extractor,
134
+ scheduler=scheduler,
135
+ )
136
+ self.vae_scale_factor = 2 ** (
137
+ len(self.vae.config.block_out_channels) - 1)
138
+ self.use_ip_adapter = False
139
+ self.st_motion = None
140
+
141
+ def set_st_motion(self, st_motion: List):
142
+ """Set style transfer motion."""
143
+ self.st_motion = st_motion
144
+
145
+ @classmethod
146
+ def build_pipeline(cls,
147
+ base_cfg,
148
+ base_model: str,
149
+ unet_path: str,
150
+ dreambooth_path: Optional[str] = None,
151
+ lora_path: Optional[str] = None,
152
+ lora_alpha: int = 0,
153
+ vae_path: Optional[str] = None,
154
+ ip_adapter_path: Optional[str] = None,
155
+ ip_adapter_scale: float = 0.0,
156
+ only_load_vae_decoder: bool = False,
157
+ only_load_vae_encoder: bool = False) -> 'I2VPipeline':
158
+ """Method to build pipeline in a faster way~
159
+ Args:
160
+ base_cfg: The config to build model
161
+ base_model: The model id used to initialize the Stable Diffusion weights
162
+ unet_path: Path for i2v unet
163
+
164
+ dreambooth_path: path for dreambooth model
165
+ lora_path: path for lora model
166
+ lora_alpha: value for lora scale
167
+
168
+ only_load_vae_decoder: Only load VAE decoder from dreambooth / VAE ckpt
169
+ and maintain the original encoder.
170
+
171
+ """
172
+ # build unet
173
+ unet = UNet3DConditionModel.from_pretrained_2d(
174
+ base_model, subfolder="unet",
175
+ unet_additional_kwargs=OmegaConf.to_container(
176
+ base_cfg.unet_additional_kwargs))
177
+
178
+ old_weights = unet.conv_in.weight
179
+ old_bias = unet.conv_in.bias
180
+ new_conv1 = InflatedConv3d(
181
+ 9, old_weights.shape[0],
182
+ kernel_size=unet.conv_in.kernel_size,
183
+ stride=unet.conv_in.stride,
184
+ padding=unet.conv_in.padding,
185
+ bias=True if old_bias is not None else False)
186
+ param = torch.zeros((320, 5, 3, 3), requires_grad=True)
187
+ new_conv1.weight = torch.nn.Parameter(
188
+ torch.cat((old_weights, param), dim=1))
189
+ if old_bias is not None:
190
+ new_conv1.bias = old_bias
191
+ unet.conv_in = new_conv1
192
+ unet.config["in_channels"] = 9
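+ # The inflated conv_in now expects 9 input channels; judging from __call__ below these are
+ # the 4 noisy latent channels, 1 mask channel and 4 masked-image latent channels
+ # (4 + 1 + 4 = 9), so `param` above contributes the 5 extra, zero-initialized input
+ # channels on top of the pretrained 4-channel weights.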
193
+
194
+ unet_ckpt = torch.load(unet_path, map_location='cpu')
195
+ # filter unet ckpt, only load motion module and conv_inv
196
+ unet_ckpt = {k: v for k, v in unet_ckpt.items()
197
+ if 'motion_module' in k or 'conv_in' in k}
198
+ print('Prefix in loaded UNet checkpoint:')
199
+ print(set([k.split('.')[0] for k in unet_ckpt.keys()]))
200
+ unet.load_state_dict(unet_ckpt, strict=False)
201
+
202
+ # load vae, tokenizer, text encoder
203
+ vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae")
204
+ tokenizer = CLIPTokenizer.from_pretrained(
205
+ base_model, subfolder="tokenizer")
206
+ text_encoder = CLIPTextModel.from_pretrained(
207
+ base_model, subfolder="text_encoder")
208
+ noise_scheduler = DDIMScheduler(
209
+ **OmegaConf.to_container(base_cfg.noise_scheduler_kwargs))
210
+
211
+ if dreambooth_path and dreambooth_path.upper() != 'NONE':
212
+
213
+ print(" >>> Begin loading DreamBooth >>>")
214
+ base_model_state_dict = {}
215
+ with safe_open(dreambooth_path, framework="pt", device="cpu") as f:
216
+ for key in f.keys():
217
+ base_model_state_dict[key] = f.get_tensor(key)
218
+
219
+ # load unet
220
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(
221
+ base_model_state_dict, unet.config)
222
+
223
+ old_value = converted_unet_checkpoint['conv_in.weight']
224
+ new_param = unet_ckpt['conv_in.weight'][:, 4:, :, :].clone().cpu()
225
+ new_value = torch.nn.Parameter(
226
+ torch.cat((old_value, new_param), dim=1))
227
+ converted_unet_checkpoint['conv_in.weight'] = new_value
228
+ unet.load_state_dict(converted_unet_checkpoint, strict=False)
229
+
230
+ # load vae
231
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(
232
+ base_model_state_dict, vae.config,
233
+ only_decoder=only_load_vae_decoder,
234
+ only_encoder=only_load_vae_encoder,)
235
+ need_strict = not (only_load_vae_decoder or only_load_vae_encoder)
236
+ vae.load_state_dict(converted_vae_checkpoint, strict=need_strict)
237
+ print('Prefix in loaded VAE checkpoint: ')
238
+ print(set([k.split('.')[0]
239
+ for k in converted_vae_checkpoint.keys()]))
240
+
241
+ # load text encoder
242
+ text_encoder_checkpoint = convert_ldm_clip_checkpoint(
243
+ base_model_state_dict)
244
+ if text_encoder_checkpoint:
245
+ text_encoder.load_state_dict(text_encoder_checkpoint)
246
+
247
+ print(" <<< Loaded DreamBooth <<<")
248
+
249
+ if vae_path:
250
+ print(' >>> Begin loading VAE >>>')
251
+ vae_state_dict = {}
252
+ if vae_path.endswith('safetensors'):
253
+ with safe_open(vae_path, framework="pt", device="cpu") as f:
254
+ for key in f.keys():
255
+ vae_state_dict[key] = f.get_tensor(key)
256
+ elif vae_path.endswith('ckpt') or vae_path.endswith('pt'):
257
+ vae_state_dict = torch.load(vae_path, map_location='cpu')
258
+ if 'state_dict' in vae_state_dict:
259
+ vae_state_dict = vae_state_dict['state_dict']
260
+
261
+ vae_state_dict = {
262
+ f'first_stage_model.{k}': v for k, v in vae_state_dict.items()}
263
+
264
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(
265
+ vae_state_dict, vae.config,
266
+ only_decoder=only_load_vae_decoder,
267
+ only_encoder=only_load_vae_encoder,)
268
+ print('Prefix in loaded VAE checkpoint: ')
269
+ print(set([k.split('.')[0]
270
+ for k in converted_vae_checkpoint.keys()]))
271
+ need_strict = not (only_load_vae_decoder or only_load_vae_encoder)
272
+ vae.load_state_dict(converted_vae_checkpoint, strict=need_strict)
273
+ print(" <<< Loaded VAE <<<")
274
+
275
+ if lora_path:
276
+
277
+ print(" >>> Begin loading LoRA >>>")
278
+
279
+ lora_dict = {}
280
+ with safe_open(lora_path, framework='pt', device='cpu') as file:
281
+ for k in file.keys():
282
+ lora_dict[k] = file.get_tensor(k)
283
+ unet, text_encoder = convert_lora_model_level(
284
+ lora_dict, unet, text_encoder, alpha=lora_alpha)
285
+
286
+ print(" <<< Loaded LoRA <<<")
287
+
288
+ # move model to device
289
+ device = torch.device('cuda')
290
+ unet_dtype = torch.float16
291
+ tenc_dtype = torch.float16
292
+ vae_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
293
+
294
+ unet = unet.to(device=device, dtype=unet_dtype)
295
+ text_encoder = text_encoder.to(device=device, dtype=tenc_dtype)
296
+ vae = vae.to(device=device, dtype=vae_dtype)
297
+ print(f'Set Unet to {unet_dtype}')
298
+ print(f'Set text encoder to {tenc_dtype}')
299
+ print(f'Set vae to {vae_dtype}')
300
+
301
+ if is_xformers_available():
302
+ unet.enable_xformers_memory_efficient_attention()
303
+
304
+ pipeline = cls(unet=unet,
305
+ vae=vae,
306
+ tokenizer=tokenizer,
307
+ text_encoder=text_encoder,
308
+ scheduler=noise_scheduler)
309
+
310
+ # ip_adapter_path = 'h94/IP-Adapter'
311
+ if ip_adapter_path and ip_adapter_scale > 0:
312
+ ip_adapter_name = 'ip-adapter_sd15.bin'
313
+ # only an online repo id needs the `models` subfolder; a local directory does not
314
+ if not osp.isdir(ip_adapter_path):
315
+ subfolder = 'models'
316
+ else:
317
+ subfolder = ''
318
+ pipeline.load_ip_adapter(
319
+ ip_adapter_path, subfolder, ip_adapter_name)
320
+ pipeline.set_ip_adapter_scale(ip_adapter_scale)
321
+ pipeline.use_ip_adapter = True
322
+ print(f'Load IP-Adapter, scale: {ip_adapter_scale}')
323
+
324
+ # text_inversion_path = './models/TextualInversion/easynegative.safetensors'
325
+ # if text_inversion_path:
326
+ # pipeline.load_textual_inversion(text_inversion_path, 'easynegative')
327
+
328
+ return pipeline
329
+
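+ # Minimal usage sketch (the paths and the base-model id below are placeholders/assumptions,
+ # not values shipped with this repo):
+ #
+ # base_cfg = OmegaConf.load('path/to/inference_config.yaml')
+ # pipe = I2VPipeline.build_pipeline(
+ #     base_cfg,
+ #     base_model='runwayml/stable-diffusion-v1-5',
+ #     unet_path='path/to/i2v_unet.ckpt',
+ # )
+ # out = pipe(image=np.asarray(Image.open('cond.png').convert('RGB')),
+ #            prompt='a cat', video_length=16, height=512, width=512)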
330
+ def enable_vae_slicing(self):
331
+ self.vae.enable_slicing()
332
+
333
+ def disable_vae_slicing(self):
334
+ self.vae.disable_slicing()
335
+
336
+ def enable_sequential_cpu_offload(self, gpu_id=0):
337
+ if is_accelerate_available():
338
+ from accelerate import cpu_offload
339
+ else:
340
+ raise ImportError(
341
+ "Please install accelerate via `pip install accelerate`")
342
+
343
+ device = torch.device(f"cuda:{gpu_id}")
344
+
345
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
346
+ if cpu_offloaded_model is not None:
347
+ cpu_offload(cpu_offloaded_model, device)
348
+
349
+ @property
350
+ def _execution_device(self):
351
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
352
+ return self.device
353
+ for module in self.unet.modules():
354
+ if (
355
+ hasattr(module, "_hf_hook")
356
+ and hasattr(module._hf_hook, "execution_device")
357
+ and module._hf_hook.execution_device is not None
358
+ ):
359
+ return torch.device(module._hf_hook.execution_device)
360
+ return self.device
361
+
362
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
363
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
364
+
365
+ text_inputs = self.tokenizer(
366
+ prompt,
367
+ padding="max_length",
368
+ max_length=self.tokenizer.model_max_length,
369
+ truncation=True,
370
+ return_tensors="pt",
371
+ )
372
+ text_input_ids = text_inputs.input_ids
373
+ untruncated_ids = self.tokenizer(
374
+ prompt, padding="longest", return_tensors="pt").input_ids
375
+
376
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
377
+ removed_text = self.tokenizer.batch_decode(
378
+ untruncated_ids[:, self.tokenizer.model_max_length - 1: -1])
379
+ logger.warning(
380
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
381
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
382
+ )
383
+
384
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
385
+ attention_mask = text_inputs.attention_mask.to(device)
386
+ else:
387
+ attention_mask = None
388
+
389
+ text_embeddings = self.text_encoder(
390
+ text_input_ids.to(device),
391
+ attention_mask=attention_mask,
392
+ )
393
+ text_embeddings = text_embeddings[0]
394
+
395
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
396
+ bs_embed, seq_len, _ = text_embeddings.shape
397
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
398
+ text_embeddings = text_embeddings.view(
399
+ bs_embed * num_videos_per_prompt, seq_len, -1)
400
+
401
+ # get unconditional embeddings for classifier free guidance
402
+ if do_classifier_free_guidance:
403
+ uncond_tokens: List[str]
404
+ if negative_prompt is None:
405
+ uncond_tokens = [""] * batch_size
406
+ elif type(prompt) is not type(negative_prompt):
407
+ raise TypeError(
408
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
409
+ f" {type(prompt)}."
410
+ )
411
+ elif isinstance(negative_prompt, str):
412
+ uncond_tokens = [negative_prompt]
413
+ elif batch_size != len(negative_prompt):
414
+ raise ValueError(
415
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
416
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
417
+ " the batch size of `prompt`."
418
+ )
419
+ else:
420
+ uncond_tokens = negative_prompt
421
+
422
+ max_length = text_input_ids.shape[-1]
423
+ uncond_input = self.tokenizer(
424
+ uncond_tokens,
425
+ padding="max_length",
426
+ max_length=max_length,
427
+ truncation=True,
428
+ return_tensors="pt",
429
+ )
430
+
431
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
432
+ attention_mask = uncond_input.attention_mask.to(device)
433
+ else:
434
+ attention_mask = None
435
+
436
+ uncond_embeddings = self.text_encoder(
437
+ uncond_input.input_ids.to(device),
438
+ attention_mask=attention_mask,
439
+ )
440
+ uncond_embeddings = uncond_embeddings[0]
441
+
442
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
443
+ seq_len = uncond_embeddings.shape[1]
444
+ uncond_embeddings = uncond_embeddings.repeat(
445
+ 1, num_videos_per_prompt, 1)
446
+ uncond_embeddings = uncond_embeddings.view(
447
+ batch_size * num_videos_per_prompt, seq_len, -1)
448
+
449
+ # For classifier free guidance, we need to do two forward passes.
450
+ # Here we concatenate the unconditional and text embeddings into a single batch
451
+ # to avoid doing two forward passes
452
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
453
+
454
+ return text_embeddings
455
+
456
+ def decode_latents(self, latents):
457
+ video_length = latents.shape[2]
458
+ latents = 1 / 0.18215 * latents
459
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
460
+ # video = self.vae.decode(latents).sample
461
+ video = []
462
+ for frame_idx in tqdm(range(latents.shape[0])):
463
+ video.append(self.vae.decode(
464
+ latents[frame_idx:frame_idx+1]).sample)
465
+ video = torch.cat(video)
466
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
467
+ video = (video / 2 + 0.5).clamp(0, 1)
468
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
469
+ video = video.cpu().float().numpy()
470
+ return video
471
+
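+ # decode_latents returns a float32 numpy array of shape
+ # (batch, channels, frames, height, width) with values in [0, 1]; __call__ optionally
+ # converts it back to a torch tensor when output_type == "tensor".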
472
+ def prepare_extra_step_kwargs(self, generator, eta):
473
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
474
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
475
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
476
+ # and should be between [0, 1]
477
+
478
+ accepts_eta = "eta" in set(inspect.signature(
479
+ self.scheduler.step).parameters.keys())
480
+ extra_step_kwargs = {}
481
+ if accepts_eta:
482
+ extra_step_kwargs["eta"] = eta
483
+
484
+ # check if the scheduler accepts generator
485
+ accepts_generator = "generator" in set(
486
+ inspect.signature(self.scheduler.step).parameters.keys())
487
+ if accepts_generator:
488
+ extra_step_kwargs["generator"] = generator
489
+ return extra_step_kwargs
490
+
491
+ def check_inputs(self, prompt, height, width, callback_steps):
492
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
493
+ raise ValueError(
494
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
495
+
496
+ if height % 8 != 0 or width % 8 != 0:
497
+ raise ValueError(
498
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
499
+
500
+ if (callback_steps is None) or (
501
+ callback_steps is not None and (not isinstance(
502
+ callback_steps, int) or callback_steps <= 0)
503
+ ):
504
+ raise ValueError(
505
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
506
+ f" {type(callback_steps)}."
507
+ )
508
+
509
+ def get_timesteps(self, num_inference_steps, strength, device):
510
+ # get the original timestep using init_timestep
511
+ init_timestep = min(
512
+ int(num_inference_steps * strength), num_inference_steps)
513
+
514
+ t_start = max(num_inference_steps - init_timestep, 0)
515
+ timesteps = self.scheduler.timesteps[t_start:]
516
+
517
+ return timesteps, num_inference_steps - t_start
518
+
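+ # Worked example: with num_inference_steps=50 and strength=0.8, init_timestep=40 and
+ # t_start=10, so only the last 40 scheduler timesteps are used (a smaller strength keeps
+ # the result closer to the conditioning image); strength=1 runs the full schedule.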
519
+ def prepare_latents(self, add_noise_time_step, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
520
+ shape = (batch_size, num_channels_latents, video_length, height //
521
+ self.vae_scale_factor, width // self.vae_scale_factor)
522
+
523
+ if isinstance(generator, list) and len(generator) != batch_size:
524
+ raise ValueError(
525
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
526
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
527
+ )
528
+ if latents is None:
529
+ rand_device = "cpu" if device.type == "mps" else device
530
+
531
+ if isinstance(generator, list):
532
+ shape = shape
533
+ # shape = (1,) + shape[1:]
534
+ latents = [
535
+ torch.randn(
536
+ shape, generator=generator[i], device=rand_device, dtype=dtype)
537
+ for i in range(batch_size)
538
+ ]
539
+ latents = torch.cat(latents, dim=0).to(device)
540
+ else:
541
+ latents = torch.randn(
542
+ shape, generator=generator, device=rand_device, dtype=dtype).to(device)
543
+ else:
544
+ if latents.shape != shape:
545
+ raise ValueError(
546
+ f"Unexpected latents shape, got {latents.shape}, expected {shape}")
547
+ latents = latents.to(device)
548
+
549
+ return latents
550
+
551
+ def encode_image(self, image, device, num_images_per_prompt):
552
+ """Encode image for ip-adapter. Copied from
553
+ https://github.com/huggingface/diffusers/blob/f9487783228cd500a21555da3346db40e8f05992/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L492-L514 # noqa
554
+ """
555
+ dtype = next(self.image_encoder.parameters()).dtype
556
+
557
+ if not isinstance(image, torch.Tensor):
558
+ image = self.feature_extractor(
559
+ image, return_tensors="pt").pixel_values
560
+
561
+ image = image.to(device=device, dtype=dtype)
562
+ image_embeds = self.image_encoder(image).image_embeds
563
+ image_embeds = image_embeds.repeat_interleave(
564
+ num_images_per_prompt, dim=0)
565
+
566
+ uncond_image_embeds = torch.zeros_like(image_embeds)
567
+ return image_embeds, uncond_image_embeds
568
+
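+ # A zero tensor serves as the unconditional image embedding; __call__ concatenates it with
+ # the conditional embedding so that classifier-free guidance also applies to the
+ # IP-Adapter image conditioning.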
569
+ @torch.no_grad()
570
+ def __call__(
571
+ self,
572
+ image: np.ndarray,
573
+ prompt: Union[str, List[str]],
574
+ video_length: Optional[int],
575
+ height: Optional[int] = None,
576
+ width: Optional[int] = None,
577
+ global_inf_num: int = 0,
578
+ num_inference_steps: int = 50,
579
+ guidance_scale: float = 7.5,
580
+ negative_prompt: Optional[Union[str, List[str]]] = None,
581
+ num_videos_per_prompt: Optional[int] = 1,
582
+ eta: float = 0.0,
583
+ generator: Optional[Union[torch.Generator,
584
+ List[torch.Generator]]] = None,
585
+ latents: Optional[torch.FloatTensor] = None,
586
+ output_type: Optional[str] = "tensor",
587
+ return_dict: bool = True,
588
+ callback: Optional[Callable[[
589
+ int, int, torch.FloatTensor], None]] = None,
590
+ callback_steps: Optional[int] = 1,
591
+
592
+ cond_frame: int = 0,
593
+ mask_sim_template_idx: int = 0,
594
+ ip_adapter_scale: float = 0,
595
+ strength: float = 1,
596
+ is_real_img: bool = False,
597
+ progress_fn=None,
598
+ **kwargs,
599
+ ):
600
+ # Default height and width to unet
601
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
602
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
603
+
604
+ assert strength > 0 and strength <= 1, (
605
+ f'"strength" for img2vid must be in (0, 1], but received {strength}.')
606
+
607
+ # Check inputs. Raise error if not correct
608
+ self.check_inputs(prompt, height, width, callback_steps)
609
+
610
+ # Define call parameters
611
+ # batch_size = 1 if isinstance(prompt, str) else len(prompt)
612
+ batch_size = 1
613
+ if latents is not None:
614
+ batch_size = latents.shape[0]
615
+ if isinstance(prompt, list):
616
+ batch_size = len(prompt)
617
+
618
+ device = self._execution_device
619
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
620
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
621
+ # corresponds to doing no classifier free guidance.
622
+ do_classifier_free_guidance = guidance_scale > 1.0
623
+
624
+ # Encode input prompt
625
+ prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
626
+
627
+ if negative_prompt is None:
628
+ negative_prompt = DEFAULT_N_PROMPT
629
+ negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [
630
+ negative_prompt] * batch_size
631
+ text_embeddings = self._encode_prompt(
632
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
633
+ )
634
+
635
+ # Prepare timesteps
636
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
637
+ timesteps, num_inference_steps = self.get_timesteps(
638
+ num_inference_steps, strength, device)
639
+ latent_timestep = timesteps[:1].repeat(batch_size)
640
+
641
+ # Prepare latent variables
642
+ num_channels_latents = self.unet.in_channels
643
+ latents = self.prepare_latents(
644
+ latent_timestep,
645
+ batch_size * num_videos_per_prompt,
646
+ 4,
647
+ video_length,
648
+ height,
649
+ width,
650
+ text_embeddings.dtype,
651
+ device,
652
+ generator,
653
+ latents,
654
+ )
655
+
656
+ shape = (batch_size, num_channels_latents, video_length, height //
657
+ self.vae_scale_factor, width // self.vae_scale_factor)
658
+
659
+ raw_image = image.copy()
660
+ image = torch.from_numpy(image)[None, ...].permute(0, 3, 1, 2)
661
+ image = image / 255 # [0, 1]
662
+ image = image * 2 - 1 # [-1, 1]
663
+ image = image.to(device=device, dtype=self.vae.dtype)
664
+
665
+ if isinstance(generator, list):
666
+ image_latent = [
667
+ self.vae.encode(image[k: k + 1]).latent_dist.sample(generator[k]) for k in range(batch_size)
668
+ ]
669
+ image_latent = torch.cat(image_latent, dim=0)
670
+ else:
671
+ image_latent = self.vae.encode(image).latent_dist.sample(generator)
672
+
673
+ image_latent = image_latent.to(device=device, dtype=self.unet.dtype)
674
+ image_latent = torch.nn.functional.interpolate(
675
+ image_latent, size=[shape[-2], shape[-1]])
676
+ image_latent_padding = image_latent.clone() * 0.18215
677
+ mask = torch.zeros((shape[0], 1, shape[2], shape[3], shape[4])).to(
678
+ device=device, dtype=self.unet.dtype)
679
+
680
+ # prepare mask
681
+ # NOTE: pass specific st_motion for real image style transfer
682
+ if mask_sim_template_idx == -1 and is_real_img:
683
+ mask_coef = prepare_mask_coef_by_statistics(
684
+ video_length, cond_frame, mask_sim_template_idx, self.st_motion)
685
+ else:
686
+ mask_coef = prepare_mask_coef_by_statistics(
687
+ video_length, cond_frame, mask_sim_template_idx)
688
+
689
+ masked_image = torch.zeros(shape[0], 4, shape[2], shape[3], shape[4]).to(
690
+ device=device, dtype=self.unet.dtype)
691
+ for f in range(video_length):
692
+ mask[:, :, f, :, :] = mask_coef[f]
693
+ masked_image[:, :, f, :, :] = image_latent_padding.clone()
694
+
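+ # mask_coef holds one scalar per frame, broadcast over the spatial dimensions: every frame
+ # is paired with the same (scaled) conditioning-image latent, and the per-frame coefficient
+ # roughly controls how strongly that frame should stay close to the input image.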
695
+ # Prepare extra step kwargs.
696
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
697
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
698
+ masked_image = torch.cat(
699
+ [masked_image] * 2) if do_classifier_free_guidance else masked_image
700
+ # Denoising loop
701
+ num_warmup_steps = len(timesteps) - \
702
+ num_inference_steps * self.scheduler.order
703
+
704
+ # prepare for ip-adapter
705
+ if self.use_ip_adapter:
706
+ image_embeds, neg_image_embeds = self.encode_image(
707
+ raw_image, device, num_videos_per_prompt)
708
+ image_embeds = torch.cat([neg_image_embeds, image_embeds])
709
+ image_embeds = image_embeds.to(device, self.unet.dtype)
710
+
711
+ self.set_ip_adapter_scale(ip_adapter_scale)
712
+ print(f'Set IP-Adapter Scale as {ip_adapter_scale}')
713
+
714
+ else:
715
+ image_embeds = None
716
+
717
+ # if strength < 1, start from the masked-image latents with added noise instead of pure Gaussian latents
718
+ if strength < 1:
719
+ noise = torch.randn_like(latents)
720
+ latents = self.scheduler.add_noise(
721
+ masked_image[0], noise, timesteps[0])
722
+
723
+ if progress_fn is None:
724
+ progress_bar = tqdm(timesteps)
725
+ terminal_pbar = None
726
+ else:
727
+ progress_bar = progress_fn.tqdm(timesteps)
728
+ terminal_pbar = tqdm(total=len(timesteps))
729
+
730
+ for i, t in enumerate(progress_bar):
731
+ # expand the latents if we are doing classifier free guidance
732
+ latent_model_input = torch.cat(
733
+ [latents] * 2) if do_classifier_free_guidance else latents
734
+ latent_model_input = self.scheduler.scale_model_input(
735
+ latent_model_input, t)
736
+
737
+ # predict the noise residual
738
+ noise_pred = self.unet(
739
+ latent_model_input,
740
+ mask,
741
+ masked_image,
742
+ t,
743
+ encoder_hidden_states=text_embeddings,
744
+ image_embeds=image_embeds
745
+ )['sample']
746
+
747
+ # perform guidance
748
+ if do_classifier_free_guidance:
749
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
750
+ noise_pred = noise_pred_uncond + guidance_scale * \
751
+ (noise_pred_text - noise_pred_uncond)
752
+
753
+ # compute the previous noisy sample x_t -> x_t-1
754
+ latents = self.scheduler.step(
755
+ noise_pred, t, latents, **extra_step_kwargs).prev_sample
756
+
757
+ # call the callback, if provided
758
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
759
+ if callback is not None and i % callback_steps == 0:
760
+ callback(i, t, latents)
761
+
762
+ if terminal_pbar is not None:
763
+ terminal_pbar.update(1)
764
+
765
+ # Post-processing
766
+ video = self.decode_latents(latents.to(device, dtype=self.vae.dtype))
767
+
768
+ # Convert to tensor
769
+ if output_type == "tensor":
770
+ video = torch.from_numpy(video)
771
+
772
+ if not return_dict:
773
+ return video
774
+
775
+ return AnimationPipelineOutput(videos=video)
animatediff/pipelines/pipeline_animation.py ADDED
@@ -0,0 +1,446 @@
1
+ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2
+
3
+ import inspect
4
+ from typing import Callable, List, Optional, Union
5
+ from dataclasses import dataclass
6
+
7
+ import numpy as np
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ from diffusers.utils import is_accelerate_available
12
+ from packaging import version
13
+ from transformers import CLIPTextModel, CLIPTokenizer
14
+
15
+ from diffusers.configuration_utils import FrozenDict
16
+ from diffusers.models import AutoencoderKL
17
+ from diffusers.pipelines import DiffusionPipeline
18
+ from diffusers.schedulers import (
19
+ DDIMScheduler,
20
+ DPMSolverMultistepScheduler,
21
+ EulerAncestralDiscreteScheduler,
22
+ EulerDiscreteScheduler,
23
+ LMSDiscreteScheduler,
24
+ PNDMScheduler,
25
+ )
26
+ from diffusers.utils import deprecate, logging, BaseOutput
27
+
28
+ from einops import rearrange
29
+
30
+ from ..models.unet import UNet3DConditionModel
31
+
32
+
33
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
34
+
35
+
36
+ @dataclass
37
+ class AnimationPipelineOutput(BaseOutput):
38
+ videos: Union[torch.Tensor, np.ndarray]
39
+
40
+
41
+ class AnimationPipeline(DiffusionPipeline):
42
+ _optional_components = []
43
+
44
+ def __init__(
45
+ self,
46
+ vae: AutoencoderKL,
47
+ text_encoder: CLIPTextModel,
48
+ tokenizer: CLIPTokenizer,
49
+ unet: UNet3DConditionModel,
50
+ scheduler: Union[
51
+ DDIMScheduler,
52
+ PNDMScheduler,
53
+ LMSDiscreteScheduler,
54
+ EulerDiscreteScheduler,
55
+ EulerAncestralDiscreteScheduler,
56
+ DPMSolverMultistepScheduler,
57
+ ],
58
+ ):
59
+ super().__init__()
60
+
61
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
62
+ deprecation_message = (
63
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
64
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
65
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
66
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
67
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
68
+ " file"
69
+ )
70
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
71
+ new_config = dict(scheduler.config)
72
+ new_config["steps_offset"] = 1
73
+ scheduler._internal_dict = FrozenDict(new_config)
74
+
75
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
76
+ deprecation_message = (
77
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
78
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
79
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
80
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
81
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
82
+ )
83
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
84
+ new_config = dict(scheduler.config)
85
+ new_config["clip_sample"] = False
86
+ scheduler._internal_dict = FrozenDict(new_config)
87
+
88
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
89
+ version.parse(unet.config._diffusers_version).base_version
90
+ ) < version.parse("0.9.0.dev0")
91
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
92
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
93
+ deprecation_message = (
94
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
95
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
96
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
97
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
98
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
99
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
100
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
101
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
102
+ " the `unet/config.json` file"
103
+ )
104
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
105
+ new_config = dict(unet.config)
106
+ new_config["sample_size"] = 64
107
+ unet._internal_dict = FrozenDict(new_config)
108
+
109
+ self.register_modules(
110
+ vae=vae,
111
+ text_encoder=text_encoder,
112
+ tokenizer=tokenizer,
113
+ unet=unet,
114
+ scheduler=scheduler,
115
+ )
116
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
117
+
118
+ def enable_vae_slicing(self):
119
+ self.vae.enable_slicing()
120
+
121
+ def disable_vae_slicing(self):
122
+ self.vae.disable_slicing()
123
+
124
+ def enable_sequential_cpu_offload(self, gpu_id=0):
125
+ if is_accelerate_available():
126
+ from accelerate import cpu_offload
127
+ else:
128
+ raise ImportError("Please install accelerate via `pip install accelerate`")
129
+
130
+ device = torch.device(f"cuda:{gpu_id}")
131
+
132
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
133
+ if cpu_offloaded_model is not None:
134
+ cpu_offload(cpu_offloaded_model, device)
135
+
136
+
137
+ @property
138
+ def _execution_device(self):
139
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
140
+ return self.device
141
+ for module in self.unet.modules():
142
+ if (
143
+ hasattr(module, "_hf_hook")
144
+ and hasattr(module._hf_hook, "execution_device")
145
+ and module._hf_hook.execution_device is not None
146
+ ):
147
+ return torch.device(module._hf_hook.execution_device)
148
+ return self.device
149
+
150
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
151
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
152
+
153
+ text_inputs = self.tokenizer(
154
+ prompt,
155
+ padding="max_length",
156
+ max_length=self.tokenizer.model_max_length,
157
+ truncation=True,
158
+ return_tensors="pt",
159
+ )
160
+ text_input_ids = text_inputs.input_ids
161
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
162
+
163
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
164
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
165
+ logger.warning(
166
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
167
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
168
+ )
169
+
170
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
171
+ attention_mask = text_inputs.attention_mask.to(device)
172
+ else:
173
+ attention_mask = None
174
+
175
+ text_embeddings = self.text_encoder(
176
+ text_input_ids.to(device),
177
+ attention_mask=attention_mask,
178
+ )
179
+ text_embeddings = text_embeddings[0]
180
+
181
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
182
+ bs_embed, seq_len, _ = text_embeddings.shape
183
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
184
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
185
+
186
+ # get unconditional embeddings for classifier free guidance
187
+ if do_classifier_free_guidance:
188
+ uncond_tokens: List[str]
189
+ if negative_prompt is None:
190
+ uncond_tokens = [""] * batch_size
191
+ elif type(prompt) is not type(negative_prompt):
192
+ raise TypeError(
193
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
194
+ f" {type(prompt)}."
195
+ )
196
+ elif isinstance(negative_prompt, str):
197
+ uncond_tokens = [negative_prompt]
198
+ elif batch_size != len(negative_prompt):
199
+ raise ValueError(
200
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
201
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
202
+ " the batch size of `prompt`."
203
+ )
204
+ else:
205
+ uncond_tokens = negative_prompt
206
+
207
+ max_length = text_input_ids.shape[-1]
208
+ uncond_input = self.tokenizer(
209
+ uncond_tokens,
210
+ padding="max_length",
211
+ max_length=max_length,
212
+ truncation=True,
213
+ return_tensors="pt",
214
+ )
215
+
216
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
217
+ attention_mask = uncond_input.attention_mask.to(device)
218
+ else:
219
+ attention_mask = None
220
+
221
+ uncond_embeddings = self.text_encoder(
222
+ uncond_input.input_ids.to(device),
223
+ attention_mask=attention_mask,
224
+ )
225
+ uncond_embeddings = uncond_embeddings[0]
226
+
227
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
228
+ seq_len = uncond_embeddings.shape[1]
229
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
230
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
231
+
232
+ # For classifier free guidance, we need to do two forward passes.
233
+ # Here we concatenate the unconditional and text embeddings into a single batch
234
+ # to avoid doing two forward passes
235
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
236
+
237
+ return text_embeddings
238
+
239
+ def decode_latents(self, latents):
240
+ video_length = latents.shape[2]
241
+ latents = 1 / 0.18215 * latents
242
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
243
+ # video = self.vae.decode(latents).sample
244
+ video = []
245
+ for frame_idx in tqdm(range(latents.shape[0])):
246
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
247
+ video = torch.cat(video)
248
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
249
+ video = (video / 2 + 0.5).clamp(0, 1)
250
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
251
+ video = video.cpu().float().numpy()
252
+ return video
253
+
254
+ def prepare_extra_step_kwargs(self, generator, eta):
255
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
256
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
257
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
258
+ # and should be between [0, 1]
259
+
260
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
261
+ extra_step_kwargs = {}
262
+ if accepts_eta:
263
+ extra_step_kwargs["eta"] = eta
264
+
265
+ # check if the scheduler accepts generator
266
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
267
+ if accepts_generator:
268
+ extra_step_kwargs["generator"] = generator
269
+ return extra_step_kwargs
270
+
271
+ def check_inputs(self, prompt, height, width, callback_steps):
272
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
273
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
274
+
275
+ if height % 8 != 0 or width % 8 != 0:
276
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
277
+
278
+ if (callback_steps is None) or (
279
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
280
+ ):
281
+ raise ValueError(
282
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
283
+ f" {type(callback_steps)}."
284
+ )
285
+
286
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
287
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
288
+ if isinstance(generator, list) and len(generator) != batch_size:
289
+ raise ValueError(
290
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
291
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
292
+ )
293
+ if latents is None:
294
+ rand_device = "cpu" if device.type == "mps" else device
295
+
296
+ if isinstance(generator, list):
297
+ shape = shape
298
+ # shape = (1,) + shape[1:]
299
+ latents = [
300
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
301
+ for i in range(batch_size)
302
+ ]
303
+ latents = torch.cat(latents, dim=0).to(device)
304
+ else:
305
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
306
+ else:
307
+ if latents.shape != shape:
308
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
309
+ latents = latents.to(device)
310
+
311
+ # scale the initial noise by the standard deviation required by the scheduler
312
+ latents = latents * self.scheduler.init_noise_sigma
313
+ return latents
314
+
315
+ @torch.no_grad()
316
+ def __call__(
317
+ self,
318
+ prompt: Union[str, List[str]],
319
+ video_length: Optional[int],
320
+ height: Optional[int] = None,
321
+ width: Optional[int] = None,
322
+ num_inference_steps: int = 50,
323
+ guidance_scale: float = 7.5,
324
+ negative_prompt: Optional[Union[str, List[str]]] = None,
325
+ num_videos_per_prompt: Optional[int] = 1,
326
+ eta: float = 0.0,
327
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
328
+ latents: Optional[torch.FloatTensor] = None,
329
+ output_type: Optional[str] = "tensor",
330
+ return_dict: bool = True,
331
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
332
+ callback_steps: Optional[int] = 1,
333
+ **kwargs,
334
+ ):
335
+ # Default height and width to unet
336
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
337
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
338
+
339
+ # Check inputs. Raise error if not correct
340
+ self.check_inputs(prompt, height, width, callback_steps)
341
+
342
+ # Define call parameters
343
+ # batch_size = 1 if isinstance(prompt, str) else len(prompt)
344
+ batch_size = 1
345
+ if latents is not None:
346
+ batch_size = latents.shape[0]
347
+ if isinstance(prompt, list):
348
+ batch_size = len(prompt)
349
+
350
+ device = self._execution_device
351
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
352
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
353
+ # corresponds to doing no classifier free guidance.
354
+ do_classifier_free_guidance = guidance_scale > 1.0
355
+
356
+ # Encode input prompt
357
+ prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
358
+ if negative_prompt is not None:
359
+ negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
360
+ text_embeddings = self._encode_prompt(
361
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
362
+ )
363
+
364
+ # Prepare timesteps
365
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
366
+ timesteps = self.scheduler.timesteps
367
+
368
+ # Prepare latent variables
369
+ num_channels_latents = self.unet.in_channels
370
+ latents = self.prepare_latents(
371
+ batch_size * num_videos_per_prompt,
372
+ num_channels_latents,
373
+ video_length,
374
+ height,
375
+ width,
376
+ text_embeddings.dtype,
377
+ device,
378
+ generator,
379
+ latents,
380
+ )
381
+ latents_dtype = latents.dtype
382
+
383
+ # Prepare extra step kwargs.
384
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
385
+
386
+ # Denoising loop
387
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
388
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
389
+ for i, t in enumerate(timesteps):
390
+
391
+ # import os
392
+ # import cv2 as cv
393
+ # feature_path = f'feature/timestep{t}'
394
+ # if not os.path.exists(feature_path):
395
+ # os.makedirs(feature_path)
396
+ # # latents B C F H W -> B C H W -> B H W -> H W
397
+ # #
398
+ # features = latents.sum(dim=1, keepdim=False)
399
+ # features = [features[:,frame,:,:] for frame in range(video_length)]
400
+ # features = [feature.squeeze(0) for feature in features]
401
+ # features = [feature.detach().cpu().numpy() for feature in features]
402
+
403
+ # features = [((feature - feature.min()) / (feature.max() - feature.min()) * 255) for feature in features]
404
+
405
+ # for feature_num in range(len(features)):
406
+ # cv.imwrite(f'{feature_path}/{feature_num}.jpg', features[feature_num])
407
+
408
+ # expand the latents if we are doing classifier free guidance
409
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
410
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
411
+
412
+ # predict the noise residual
413
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
414
+ # noise_pred = []
415
+ # import pdb
416
+ # pdb.set_trace()
417
+ # for batch_idx in range(latent_model_input.shape[0]):
418
+ # noise_pred_single = self.unet(latent_model_input[batch_idx:batch_idx+1], t, encoder_hidden_states=text_embeddings[batch_idx:batch_idx+1]).sample.to(dtype=latents_dtype)
419
+ # noise_pred.append(noise_pred_single)
420
+ # noise_pred = torch.cat(noise_pred)
421
+
422
+ # perform guidance
423
+ if do_classifier_free_guidance:
424
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
425
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
426
+
427
+ # compute the previous noisy sample x_t -> x_t-1
428
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
429
+
430
+ # call the callback, if provided
431
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
432
+ progress_bar.update()
433
+ if callback is not None and i % callback_steps == 0:
434
+ callback(i, t, latents)
435
+
436
+ # Post-processing
437
+ video = self.decode_latents(latents)
438
+
439
+ # Convert to tensor
440
+ if output_type == "tensor":
441
+ video = torch.from_numpy(video)
442
+
443
+ if not return_dict:
444
+ return video
445
+
446
+ return AnimationPipelineOutput(videos=video)
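A minimal usage sketch for the AnimationPipeline above (illustrative, not part of this commit). It assumes a Stable Diffusion v1.5-style checkpoint layout on the Hub and an already-constructed `unet` of type UNet3DConditionModel with the motion module loaded; the model id, prompt, seed, and output path are placeholders.

import torch
from diffusers import AutoencoderKL, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid

base = "runwayml/stable-diffusion-v1-5"  # assumed base checkpoint
vae = AutoencoderKL.from_pretrained(base, subfolder="vae")
tokenizer = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(base, subfolder="text_encoder")
scheduler = DDIMScheduler.from_pretrained(base, subfolder="scheduler")

# `unet` is assumed to be a UNet3DConditionModel built from animatediff/models/unet.py
# (e.g. via a from_pretrained_2d-style helper, if this repo provides one).
pipeline = AnimationPipeline(
    vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
    unet=unet, scheduler=scheduler,
).to("cuda")

out = pipeline(
    prompt="a corgi running on the beach, best quality",
    negative_prompt="low quality, blurry",
    video_length=16, height=512, width=512,
    num_inference_steps=25, guidance_scale=7.5,
    generator=torch.Generator("cuda").manual_seed(42),
)
# out.videos is a (batch, channel, frame, height, width) tensor in [0, 1]
save_videos_grid(out.videos, "samples/corgi.gif")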
animatediff/pipelines/validation_pipeline.py ADDED
@@ -0,0 +1,504 @@
1
+ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2
+
3
+ import inspect
4
+ from typing import Callable, List, Optional, Union
5
+ from dataclasses import dataclass
6
+ import random
7
+ import argparse
8
+
9
+ import numpy as np
10
+ import torch
11
+ from tqdm import tqdm
12
+ from omegaconf import OmegaConf
13
+
14
+ from diffusers.utils import is_accelerate_available
15
+ from packaging import version
16
+ from transformers import CLIPTextModel, CLIPTokenizer
17
+
18
+ import os
19
+ from safetensors import safe_open
20
+
21
+ from diffusers.configuration_utils import FrozenDict
22
+ from diffusers.models import AutoencoderKL
23
+ from diffusers.pipelines import DiffusionPipeline
24
+ from diffusers.schedulers import (
25
+ DDIMScheduler,
26
+ DPMSolverMultistepScheduler,
27
+ EulerAncestralDiscreteScheduler,
28
+ EulerDiscreteScheduler,
29
+ LMSDiscreteScheduler,
30
+ PNDMScheduler,
31
+ )
32
+ from diffusers.utils import deprecate, logging, BaseOutput
33
+
34
+ from einops import rearrange
35
+
36
+ from animatediff.models.unet import UNet3DConditionModel
37
+
38
+ from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
39
+ from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
40
+
41
+ from animatediff.utils.util import prepare_mask_coef, save_videos_grid
42
+ from animatediff.models.resnet import InflatedConv3d
43
+
44
+ from PIL import Image
45
+
46
+ PIL_INTERPOLATION = {
47
+ "linear": Image.Resampling.BILINEAR,
48
+ "bilinear": Image.Resampling.BILINEAR,
49
+ "bicubic": Image.Resampling.BICUBIC,
50
+ "lanczos": Image.Resampling.LANCZOS,
51
+ "nearest": Image.Resampling.NEAREST,
52
+ }
53
+ def preprocess_image(image):
54
+ if isinstance(image, torch.Tensor):
55
+ return image
56
+ elif isinstance(image, Image.Image):
57
+ image = [image]
58
+
59
+ if isinstance(image[0], Image.Image):
60
+ w, h = image[0].size
61
+ w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8
62
+
63
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
64
+ image = np.concatenate(image, axis=0)
65
+ if len(image.shape) == 3:
66
+ image = image.reshape(image.shape[0], image.shape[1], image.shape[2], 1)
67
+ image = np.array(image).astype(np.float32) / 255.0
68
+ image = image.transpose(0, 3, 1, 2)
69
+ image = 2.0 * image - 1.0
70
+ image = torch.from_numpy(image)
71
+ elif isinstance(image[0], torch.Tensor):
72
+ image = torch.cat(image, dim=0)
73
+ return image
74
+
75
+
76
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
77
+
78
+
79
+ @dataclass
80
+ class AnimationPipelineOutput(BaseOutput):
81
+ videos: Union[torch.Tensor, np.ndarray]
82
+
83
+
84
+ class ValidationPipeline(DiffusionPipeline):
85
+ _optional_components = []
86
+
87
+ def __init__(
88
+ self,
89
+ vae: AutoencoderKL,
90
+ text_encoder: CLIPTextModel,
91
+ tokenizer: CLIPTokenizer,
92
+ unet: UNet3DConditionModel,
93
+ scheduler: Union[
94
+ DDIMScheduler,
95
+ PNDMScheduler,
96
+ LMSDiscreteScheduler,
97
+ EulerDiscreteScheduler,
98
+ EulerAncestralDiscreteScheduler,
99
+ DPMSolverMultistepScheduler,
100
+ ],
101
+ ):
102
+ super().__init__()
103
+
104
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
105
+ deprecation_message = (
106
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
107
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
108
+ "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
109
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
110
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
111
+ " file"
112
+ )
113
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
114
+ new_config = dict(scheduler.config)
115
+ new_config["steps_offset"] = 1
116
+ scheduler._internal_dict = FrozenDict(new_config)
117
+
118
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
119
+ deprecation_message = (
120
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
121
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
122
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
123
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
124
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
125
+ )
126
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
127
+ new_config = dict(scheduler.config)
128
+ new_config["clip_sample"] = False
129
+ scheduler._internal_dict = FrozenDict(new_config)
130
+
131
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
132
+ version.parse(unet.config._diffusers_version).base_version
133
+ ) < version.parse("0.9.0.dev0")
134
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
135
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
136
+ deprecation_message = (
137
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
138
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
139
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
140
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
141
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
142
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
143
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
144
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
145
+ " the `unet/config.json` file"
146
+ )
147
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
148
+ new_config = dict(unet.config)
149
+ new_config["sample_size"] = 64
150
+ unet._internal_dict = FrozenDict(new_config)
151
+
152
+ self.register_modules(
153
+ vae=vae,
154
+ text_encoder=text_encoder,
155
+ tokenizer=tokenizer,
156
+ unet=unet,
157
+ scheduler=scheduler,
158
+ )
159
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
160
+
161
+ def enable_vae_slicing(self):
162
+ self.vae.enable_slicing()
163
+
164
+ def disable_vae_slicing(self):
165
+ self.vae.disable_slicing()
166
+
167
+ def enable_sequential_cpu_offload(self, gpu_id=0):
168
+ if is_accelerate_available():
169
+ from accelerate import cpu_offload
170
+ else:
171
+ raise ImportError("Please install accelerate via `pip install accelerate`")
172
+
173
+ device = torch.device(f"cuda:{gpu_id}")
174
+
175
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
176
+ if cpu_offloaded_model is not None:
177
+ cpu_offload(cpu_offloaded_model, device)
178
+
179
+
180
+ @property
181
+ def _execution_device(self):
182
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
183
+ return self.device
184
+ for module in self.unet.modules():
185
+ if (
186
+ hasattr(module, "_hf_hook")
187
+ and hasattr(module._hf_hook, "execution_device")
188
+ and module._hf_hook.execution_device is not None
189
+ ):
190
+ return torch.device(module._hf_hook.execution_device)
191
+ return self.device
192
+
193
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
194
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
195
+
196
+ text_inputs = self.tokenizer(
197
+ prompt,
198
+ padding="max_length",
199
+ max_length=self.tokenizer.model_max_length,
200
+ truncation=True,
201
+ return_tensors="pt",
202
+ )
203
+ text_input_ids = text_inputs.input_ids
204
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
205
+
206
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
207
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
208
+ logger.warning(
209
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
210
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
211
+ )
212
+
213
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
214
+ attention_mask = text_inputs.attention_mask.to(device)
215
+ else:
216
+ attention_mask = None
217
+
218
+ text_embeddings = self.text_encoder(
219
+ text_input_ids.to(device),
220
+ attention_mask=attention_mask,
221
+ )
222
+ text_embeddings = text_embeddings[0]
223
+
224
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
225
+ bs_embed, seq_len, _ = text_embeddings.shape
226
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
227
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
228
+
229
+ # get unconditional embeddings for classifier free guidance
230
+ if do_classifier_free_guidance:
231
+ uncond_tokens: List[str]
232
+ if negative_prompt is None:
233
+ uncond_tokens = [""] * batch_size
234
+ elif type(prompt) is not type(negative_prompt):
235
+ raise TypeError(
236
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
237
+ f" {type(prompt)}."
238
+ )
239
+ elif isinstance(negative_prompt, str):
240
+ uncond_tokens = [negative_prompt]
241
+ elif batch_size != len(negative_prompt):
242
+ raise ValueError(
243
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
244
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
245
+ " the batch size of `prompt`."
246
+ )
247
+ else:
248
+ uncond_tokens = negative_prompt
249
+
250
+ max_length = text_input_ids.shape[-1]
251
+ uncond_input = self.tokenizer(
252
+ uncond_tokens,
253
+ padding="max_length",
254
+ max_length=max_length,
255
+ truncation=True,
256
+ return_tensors="pt",
257
+ )
258
+
259
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
260
+ attention_mask = uncond_input.attention_mask.to(device)
261
+ else:
262
+ attention_mask = None
263
+
264
+ uncond_embeddings = self.text_encoder(
265
+ uncond_input.input_ids.to(device),
266
+ attention_mask=attention_mask,
267
+ )
268
+ uncond_embeddings = uncond_embeddings[0]
269
+
270
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
271
+ seq_len = uncond_embeddings.shape[1]
272
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
273
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
274
+
275
+ # For classifier free guidance, we need to do two forward passes.
276
+ # Here we concatenate the unconditional and text embeddings into a single batch
277
+ # to avoid doing two forward passes
278
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
279
+
280
+ return text_embeddings
281
+
282
+ def decode_latents(self, latents):
283
+ video_length = latents.shape[2]
284
+ latents = 1 / 0.18215 * latents
285
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
286
+ # video = self.vae.decode(latents).sample
287
+ video = []
288
+ for frame_idx in tqdm(range(latents.shape[0])):
289
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
290
+ video = torch.cat(video)
291
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
292
+ video = (video / 2 + 0.5).clamp(0, 1)
293
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
294
+ video = video.cpu().float().numpy()
295
+ return video
296
+
297
+ def prepare_extra_step_kwargs(self, generator, eta):
298
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
299
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
300
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
301
+ # and should be between [0, 1]
302
+
303
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
304
+ extra_step_kwargs = {}
305
+ if accepts_eta:
306
+ extra_step_kwargs["eta"] = eta
307
+
308
+ # check if the scheduler accepts generator
309
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
310
+ if accepts_generator:
311
+ extra_step_kwargs["generator"] = generator
312
+ return extra_step_kwargs
313
+
314
+ def check_inputs(self, prompt, height, width, callback_steps):
315
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
316
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
317
+
318
+ if height % 8 != 0 or width % 8 != 0:
319
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
320
+
321
+ if (callback_steps is None) or (
322
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
323
+ ):
324
+ raise ValueError(
325
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
326
+ f" {type(callback_steps)}."
327
+ )
328
+
329
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
330
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
331
+
332
+ if isinstance(generator, list) and len(generator) != batch_size:
333
+ raise ValueError(
334
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
335
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
336
+ )
337
+ if latents is None:
338
+ rand_device = "cpu" if device.type == "mps" else device
339
+
340
+ if isinstance(generator, list):
341
+ shape = shape
342
+ # shape = (1,) + shape[1:]
343
+ latents = [
344
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
345
+ for i in range(batch_size)
346
+ ]
347
+ latents = torch.cat(latents, dim=0).to(device)
348
+ else:
349
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
350
+ else:
351
+ if latents.shape != shape:
352
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
353
+ latents = latents.to(device)
354
+
355
+ # scale the initial noise by the standard deviation required by the scheduler
356
+ latents = latents * self.scheduler.init_noise_sigma
357
+ return latents
358
+
359
+ @torch.no_grad()
360
+ def __call__(
361
+ self,
362
+ prompt: Union[str, List[str]],
363
+ use_image: bool,
364
+ video_length: Optional[int],
365
+ height: Optional[int] = None,
366
+ width: Optional[int] = None,
367
+ num_inference_steps: int = 50,
368
+ guidance_scale: float = 7.5,
369
+ negative_prompt: Optional[Union[str, List[str]]] = None,
370
+ num_videos_per_prompt: Optional[int] = 1,
371
+ eta: float = 0.0,
372
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
373
+ latents: Optional[torch.FloatTensor] = None,
374
+ output_type: Optional[str] = "tensor",
375
+ return_dict: bool = True,
376
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
377
+ callback_steps: Optional[int] = 1,
378
+ **kwargs,
379
+ ):
380
+ # Default height and width to unet
381
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
382
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
383
+
384
+ # Check inputs. Raise error if not correct
385
+ self.check_inputs(prompt, height, width, callback_steps)
386
+
387
+ # Define call parameters
388
+ # batch_size = 1 if isinstance(prompt, str) else len(prompt)
389
+ batch_size = 1
390
+ if latents is not None:
391
+ batch_size = latents.shape[0]
392
+ if isinstance(prompt, list):
393
+ batch_size = len(prompt)
394
+
395
+ device = self._execution_device
396
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
397
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
398
+ # corresponds to doing no classifier free guidance.
399
+ do_classifier_free_guidance = guidance_scale > 1.0
400
+
401
+ # Encode input prompt
402
+ prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
403
+ if negative_prompt is not None:
404
+ negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
405
+ text_embeddings = self._encode_prompt(
406
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
407
+ )
408
+
409
+ # Prepare timesteps
410
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
411
+ timesteps = self.scheduler.timesteps
412
+
413
+ # Prepare latent variables
414
+ num_channels_latents = self.unet.in_channels
415
+ latents = self.prepare_latents(
416
+ batch_size * num_videos_per_prompt,
417
+ num_channels_latents,
418
+ video_length,
419
+ height,
420
+ width,
421
+ text_embeddings.dtype,
422
+ device,
423
+ generator,
424
+ latents,
425
+ )
426
+ latents_dtype = latents.dtype
427
+
428
+ if use_image != False:
429
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
430
+
431
+ image = Image.open(f'test_image/init_image{use_image}.png').convert('RGB')
432
+ image = preprocess_image(image).to(device)
433
+ if isinstance(generator, list):
434
+ image_latent = [
435
+ self.vae.encode(image[k : k + 1]).latent_dist.sample(generator[k]) for k in range(batch_size)
436
+ ]
437
+ image_latent = torch.cat(image_latent, dim=0).to(device=device)
438
+ else:
439
+ image_latent = self.vae.encode(image).latent_dist.sample(generator).to(device=device)
440
+
441
+ image_latent = torch.nn.functional.interpolate(image_latent, size=[shape[-2], shape[-1]])
442
+ image_latent_padding = image_latent.clone() * 0.18215
443
+ mask = torch.zeros((shape[0], 1, shape[2], shape[3], shape[4])).to(device)
444
+ mask_coef = prepare_mask_coef(video_length, 0, kwargs['mask_sim_range'])
445
+
446
+ add_noise = torch.randn(shape).to(device)
447
+ masked_image = torch.zeros(shape).to(device)
448
+ for f in range(video_length):
449
+ mask[:,:,f,:,:] = mask_coef[f]
450
+ masked_image[:,:,f,:,:] = image_latent_padding.clone()
451
+ mask = mask.to(device)
452
+ else:
453
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
454
+ add_noise = torch.zeros_like(latents).to(device)
455
+ masked_image = add_noise
456
+ mask = torch.zeros((shape[0], 1, shape[2], shape[3], shape[4])).to(device)
457
+
458
+ # Prepare extra step kwargs.
459
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
460
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
461
+ masked_image = torch.cat([masked_image] * 2) if do_classifier_free_guidance else masked_image
462
+ # Denoising loop
463
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
464
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
465
+ for i, t in enumerate(timesteps):
466
+ # expand the latents if we are doing classifier free guidance
467
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
468
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
469
+
470
+ # predict the noise residual
471
+ noise_pred = self.unet(latent_model_input, mask, masked_image, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
472
+ # noise_pred = []
473
+ # import pdb
474
+ # pdb.set_trace()
475
+ # for batch_idx in range(latent_model_input.shape[0]):
476
+ # noise_pred_single = self.unet(latent_model_input[batch_idx:batch_idx+1], t, encoder_hidden_states=text_embeddings[batch_idx:batch_idx+1]).sample.to(dtype=latents_dtype)
477
+ # noise_pred.append(noise_pred_single)
478
+ # noise_pred = torch.cat(noise_pred)
479
+
480
+ # perform guidance
481
+ if do_classifier_free_guidance:
482
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
483
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
484
+
485
+ # compute the previous noisy sample x_t -> x_t-1
486
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
487
+
488
+ # call the callback, if provided
489
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
490
+ progress_bar.update()
491
+ if callback is not None and i % callback_steps == 0:
492
+ callback(i, t, latents)
493
+
494
+ # Post-processing
495
+ video = self.decode_latents(latents)
496
+
497
+ # Convert to tensor
498
+ if output_type == "tensor":
499
+ video = torch.from_numpy(video)
500
+
501
+ if not return_dict:
502
+ return video
503
+
504
+ return AnimationPipelineOutput(videos=video)
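The ValidationPipeline differs from AnimationPipeline mainly in the image-conditioned branch of `__call__`: when `use_image` is truthy it loads `test_image/init_image{use_image}.png`, encodes it with the VAE, and blends that latent into every frame using per-frame coefficients from `prepare_mask_coef`, passing the resulting mask and masked image to the inflated UNet. A hedged call sketch follows; `pipeline` is assumed to be assembled like the AnimationPipeline example above, and the value and format of `mask_sim_range` are assumptions about `prepare_mask_coef`, not something pinned down in this file.

import torch
from animatediff.utils.util import save_videos_grid

out = pipeline(                          # assumed: a ValidationPipeline built as in the previous sketch
    prompt="a girl smiling, best quality",
    use_image=1,                         # reads test_image/init_image1.png (hard-coded path in __call__)
    video_length=16, height=512, width=512,
    num_inference_steps=25, guidance_scale=7.5,
    generator=torch.Generator("cuda").manual_seed(42),
    mask_sim_range=[0.2, 1.0],           # hypothetical similarity range forwarded to prepare_mask_coef
)
save_videos_grid(out.videos, "samples/validation.gif")  # (B, C, F, H, W) tensor in [0, 1]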
animatediff/utils/convert_from_ckpt.py ADDED
@@ -0,0 +1,964 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Conversion script for the Stable Diffusion checkpoints."""
16
+
17
+ import re
18
+ from io import BytesIO
19
+ from typing import Optional
20
+
21
+ import requests
22
+ import torch
23
+ from transformers import (
24
+ AutoFeatureExtractor,
25
+ BertTokenizerFast,
26
+ CLIPImageProcessor,
27
+ CLIPTextModel,
28
+ CLIPTextModelWithProjection,
29
+ CLIPTokenizer,
30
+ CLIPVisionConfig,
31
+ CLIPVisionModelWithProjection,
32
+ )
33
+
34
+ from diffusers.models import (
35
+ AutoencoderKL,
36
+ PriorTransformer,
37
+ UNet2DConditionModel,
38
+ )
39
+ from diffusers.schedulers import (
40
+ DDIMScheduler,
41
+ DDPMScheduler,
42
+ DPMSolverMultistepScheduler,
43
+ EulerAncestralDiscreteScheduler,
44
+ EulerDiscreteScheduler,
45
+ HeunDiscreteScheduler,
46
+ LMSDiscreteScheduler,
47
+ PNDMScheduler,
48
+ UnCLIPScheduler,
49
+ )
50
+ from diffusers.utils.import_utils import BACKENDS_MAPPING
51
+
52
+
53
+ def shave_segments(path, n_shave_prefix_segments=1):
54
+ """
55
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
56
+ """
57
+ if n_shave_prefix_segments >= 0:
58
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
59
+ else:
60
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
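For reference, `shave_segments` simply drops dot-separated segments from a checkpoint key: a non-negative count removes leading segments, a negative count removes trailing ones. A small sketch, assuming the function is imported from this module:

from animatediff.utils.convert_from_ckpt import shave_segments

key = "model.diffusion_model.input_blocks.3.0.in_layers.2.weight"
print(shave_segments(key, 2))   # "input_blocks.3.0.in_layers.2.weight"
print(shave_segments(key, -1))  # "model.diffusion_model.input_blocks.3.0.in_layers.2"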
61
+
62
+
63
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
64
+ """
65
+ Updates paths inside resnets to the new naming scheme (local renaming)
66
+ """
67
+ mapping = []
68
+ for old_item in old_list:
69
+ new_item = old_item.replace("in_layers.0", "norm1")
70
+ new_item = new_item.replace("in_layers.2", "conv1")
71
+
72
+ new_item = new_item.replace("out_layers.0", "norm2")
73
+ new_item = new_item.replace("out_layers.3", "conv2")
74
+
75
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
76
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
77
+
78
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
79
+
80
+ mapping.append({"old": old_item, "new": new_item})
81
+
82
+ return mapping
83
+
84
+
85
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
86
+ """
87
+ Updates paths inside resnets to the new naming scheme (local renaming)
88
+ """
89
+ mapping = []
90
+ for old_item in old_list:
91
+ new_item = old_item
92
+
93
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
94
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
95
+
96
+ mapping.append({"old": old_item, "new": new_item})
97
+
98
+ return mapping
99
+
100
+
101
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
102
+ """
103
+ Updates paths inside attentions to the new naming scheme (local renaming)
104
+ """
105
+ mapping = []
106
+ for old_item in old_list:
107
+ new_item = old_item
108
+
109
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
110
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
111
+
112
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
113
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
114
+
115
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
116
+
117
+ mapping.append({"old": old_item, "new": new_item})
118
+
119
+ return mapping
120
+
121
+
122
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
123
+ """
124
+ Updates paths inside attentions to the new naming scheme (local renaming)
125
+ """
126
+ mapping = []
127
+ for old_item in old_list:
128
+ new_item = old_item
129
+
130
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
131
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
132
+
133
+ new_item = new_item.replace("q.weight", "to_q.weight")
134
+ new_item = new_item.replace("q.bias", "to_q.bias")
135
+
136
+ new_item = new_item.replace("k.weight", "to_k.weight")
137
+ new_item = new_item.replace("k.bias", "to_k.bias")
138
+
139
+ new_item = new_item.replace("v.weight", "to_v.weight")
140
+ new_item = new_item.replace("v.bias", "to_v.bias")
141
+
142
+ new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
143
+ new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
144
+
145
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
146
+
147
+ mapping.append({"old": old_item, "new": new_item})
148
+ return mapping
149
+
150
+
151
+ def assign_to_checkpoint(
152
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
153
+ ):
154
+ """
155
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
156
+ attention layers, and takes into account additional replacements that may arise.
157
+
158
+ Assigns the weights to the new checkpoint.
159
+ """
160
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
161
+
162
+ # Splits the attention layers into three variables.
163
+ if attention_paths_to_split is not None:
164
+ for path, path_map in attention_paths_to_split.items():
165
+ old_tensor = old_checkpoint[path]
166
+ channels = old_tensor.shape[0] // 3
167
+
168
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
169
+
170
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
171
+
172
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
173
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
174
+
175
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
176
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
177
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
178
+
179
+ for path in paths:
180
+ new_path = path["new"]
181
+
182
+ # These have already been assigned
183
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
184
+ continue
185
+
186
+ # Global renaming happens here
187
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
188
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
189
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
190
+
191
+ if additional_replacements is not None:
192
+ for replacement in additional_replacements:
193
+ new_path = new_path.replace(replacement["old"], replacement["new"])
194
+
195
+ # proj_attn.weight has to be converted from conv 1D to linear
196
+ if "proj_attn.weight" in new_path:
197
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
198
+ elif 'to_out.0.weight' in new_path:
199
+ checkpoint[new_path] = old_checkpoint[path['old']].squeeze()
200
+ elif any([qkv in new_path for qkv in ['to_q', 'to_k', 'to_v']]):
201
+ checkpoint[new_path] = old_checkpoint[path['old']].squeeze()
202
+ else:
203
+ checkpoint[new_path] = old_checkpoint[path["old"]]
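Taken together, `renew_resnet_paths` renames a key locally (e.g. `in_layers.2` -> `conv1`) and `assign_to_checkpoint` then applies the block-level replacement supplied via `additional_replacements` and copies the tensor under the new name. A small worked sketch, assuming both functions are imported from this module and using a dummy tensor:

import torch
from animatediff.utils.convert_from_ckpt import assign_to_checkpoint, renew_resnet_paths

old = {"input_blocks.1.0.in_layers.2.weight": torch.zeros(320, 320, 3, 3)}
paths = renew_resnet_paths(list(old))  # local rename: in_layers.2 -> conv1
new = {}
assign_to_checkpoint(
    paths, new, old,
    additional_replacements=[{"old": "input_blocks.1.0", "new": "down_blocks.0.resnets.0"}],
)
assert "down_blocks.0.resnets.0.conv1.weight" in new  # global rename applied on top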
204
+
205
+
206
+ def conv_attn_to_linear(checkpoint):
207
+ keys = list(checkpoint.keys())
208
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
209
+ for key in keys:
210
+ if ".".join(key.split(".")[-2:]) in attn_keys:
211
+ if checkpoint[key].ndim > 2:
212
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
213
+ elif "proj_attn.weight" in key:
214
+ if checkpoint[key].ndim > 2:
215
+ checkpoint[key] = checkpoint[key][:, :, 0]
216
+
217
+
218
+ def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
219
+ """
220
+ Creates a config for the diffusers based on the config of the LDM model.
221
+ """
222
+ if controlnet:
223
+ unet_params = original_config.model.params.control_stage_config.params
224
+ else:
225
+ unet_params = original_config.model.params.unet_config.params
226
+
227
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
228
+
229
+ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
230
+
231
+ down_block_types = []
232
+ resolution = 1
233
+ for i in range(len(block_out_channels)):
234
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
235
+ down_block_types.append(block_type)
236
+ if i != len(block_out_channels) - 1:
237
+ resolution *= 2
238
+
239
+ up_block_types = []
240
+ for i in range(len(block_out_channels)):
241
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
242
+ up_block_types.append(block_type)
243
+ resolution //= 2
244
+
245
+ vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
246
+
247
+ head_dim = unet_params.num_heads if "num_heads" in unet_params else None
248
+ use_linear_projection = (
249
+ unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
250
+ )
251
+ if use_linear_projection:
252
+ # stable diffusion 2-base-512 and 2-768
253
+ if head_dim is None:
254
+ head_dim = [5, 10, 20, 20]
255
+
256
+ class_embed_type = None
257
+ projection_class_embeddings_input_dim = None
258
+
259
+ if "num_classes" in unet_params:
260
+ if unet_params.num_classes == "sequential":
261
+ class_embed_type = "projection"
262
+ assert "adm_in_channels" in unet_params
263
+ projection_class_embeddings_input_dim = unet_params.adm_in_channels
264
+ else:
265
+ raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
266
+
267
+ config = {
268
+ "sample_size": image_size // vae_scale_factor,
269
+ "in_channels": unet_params.in_channels,
270
+ "down_block_types": tuple(down_block_types),
271
+ "block_out_channels": tuple(block_out_channels),
272
+ "layers_per_block": unet_params.num_res_blocks,
273
+ "cross_attention_dim": unet_params.context_dim,
274
+ "attention_head_dim": head_dim,
275
+ "use_linear_projection": use_linear_projection,
276
+ "class_embed_type": class_embed_type,
277
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
278
+ }
279
+
280
+ if not controlnet:
281
+ config["out_channels"] = unet_params.out_channels
282
+ config["up_block_types"] = tuple(up_block_types)
283
+
284
+ return config
285
+
286
+
287
+ def create_vae_diffusers_config(original_config, image_size: int):
288
+ """
289
+ Creates a config for the diffusers based on the config of the LDM model.
290
+ """
291
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
292
+ _ = original_config.model.params.first_stage_config.params.embed_dim
293
+
294
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
295
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
296
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
297
+
298
+ config = {
299
+ "sample_size": image_size,
300
+ "in_channels": vae_params.in_channels,
301
+ "out_channels": vae_params.out_ch,
302
+ "down_block_types": tuple(down_block_types),
303
+ "up_block_types": tuple(up_block_types),
304
+ "block_out_channels": tuple(block_out_channels),
305
+ "latent_channels": vae_params.z_channels,
306
+ "layers_per_block": vae_params.num_res_blocks,
307
+ }
308
+ return config
309
+
310
+
311
+ def create_diffusers_schedular(original_config):
312
+ schedular = DDIMScheduler(
313
+ num_train_timesteps=original_config.model.params.timesteps,
314
+ beta_start=original_config.model.params.linear_start,
315
+ beta_end=original_config.model.params.linear_end,
316
+ beta_schedule="scaled_linear",
317
+ )
318
+ return schedular
319
+
320
+
321
+ def create_ldm_bert_config(original_config):
322
+ bert_params = original_config.model.params.cond_stage_config.params
323
+ config = LDMBertConfig(
324
+ d_model=bert_params.n_embed,
325
+ encoder_layers=bert_params.n_layer,
326
+ encoder_ffn_dim=bert_params.n_embed * 4,
327
+ )
328
+ return config
329
+
330
+
331
+ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
332
+ """
333
+ Takes a state dict and a config, and returns a converted checkpoint.
334
+ """
335
+
336
+ # extract state_dict for UNet
337
+ unet_state_dict = {}
338
+ keys = list(checkpoint.keys())
339
+
340
+ if controlnet:
341
+ unet_key = "control_model."
342
+ else:
343
+ unet_key = "model.diffusion_model."
344
+
345
+ # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
346
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
347
+ print(f"Checkpoint {path} has both EMA and non-EMA weights.")
348
+ print(
349
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
350
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
351
+ )
352
+ for key in keys:
353
+ if key.startswith("model.diffusion_model"):
354
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
355
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
356
+ else:
357
+ if sum(k.startswith("model_ema") for k in keys) > 100:
358
+ print(
359
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
360
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
361
+ )
362
+
363
+ for key in keys:
364
+ if key.startswith(unet_key):
365
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
366
+
367
+ new_checkpoint = {}
368
+
369
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
370
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
371
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
372
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
373
+
374
+ if config["class_embed_type"] is None:
375
+ # No parameters to port
376
+ ...
377
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
378
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
379
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
380
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
381
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
382
+ else:
383
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
384
+
385
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
386
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
387
+
388
+ if not controlnet:
389
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
390
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
391
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
392
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
393
+
394
+ # Retrieves the keys for the input blocks only
395
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
396
+ input_blocks = {
397
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
398
+ for layer_id in range(num_input_blocks)
399
+ }
400
+
401
+ # Retrieves the keys for the middle blocks only
402
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
403
+ middle_blocks = {
404
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
405
+ for layer_id in range(num_middle_blocks)
406
+ }
407
+
408
+ # Retrieves the keys for the output blocks only
409
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
410
+ output_blocks = {
411
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
412
+ for layer_id in range(num_output_blocks)
413
+ }
414
+
415
+ for i in range(1, num_input_blocks):
416
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
417
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
418
+
419
+ resnets = [
420
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
421
+ ]
422
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
423
+
424
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
425
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
426
+ f"input_blocks.{i}.0.op.weight"
427
+ )
428
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
429
+ f"input_blocks.{i}.0.op.bias"
430
+ )
431
+
432
+ paths = renew_resnet_paths(resnets)
433
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
434
+ assign_to_checkpoint(
435
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
436
+ )
437
+
438
+ if len(attentions):
439
+ paths = renew_attention_paths(attentions)
440
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
441
+ assign_to_checkpoint(
442
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
443
+ )
444
+
445
+ resnet_0 = middle_blocks[0]
446
+ attentions = middle_blocks[1]
447
+ resnet_1 = middle_blocks[2]
448
+
449
+ resnet_0_paths = renew_resnet_paths(resnet_0)
450
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
451
+
452
+ resnet_1_paths = renew_resnet_paths(resnet_1)
453
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
454
+
455
+ attentions_paths = renew_attention_paths(attentions)
456
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
457
+ assign_to_checkpoint(
458
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
459
+ )
460
+
461
+ for i in range(num_output_blocks):
462
+ block_id = i // (config["layers_per_block"] + 1)
463
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
464
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
465
+ output_block_list = {}
466
+
467
+ for layer in output_block_layers:
468
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
469
+ if layer_id in output_block_list:
470
+ output_block_list[layer_id].append(layer_name)
471
+ else:
472
+ output_block_list[layer_id] = [layer_name]
473
+
474
+ if len(output_block_list) > 1:
475
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
476
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
477
+
478
+ resnet_0_paths = renew_resnet_paths(resnets)
479
+ paths = renew_resnet_paths(resnets)
480
+
481
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
482
+ assign_to_checkpoint(
483
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
484
+ )
485
+
486
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
487
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
488
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
489
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
490
+ f"output_blocks.{i}.{index}.conv.weight"
491
+ ]
492
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
493
+ f"output_blocks.{i}.{index}.conv.bias"
494
+ ]
495
+
496
+ # Clear attentions as they have been attributed above.
497
+ if len(attentions) == 2:
498
+ attentions = []
499
+
500
+ if len(attentions):
501
+ paths = renew_attention_paths(attentions)
502
+ meta_path = {
503
+ "old": f"output_blocks.{i}.1",
504
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
505
+ }
506
+ assign_to_checkpoint(
507
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
508
+ )
509
+ else:
510
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
511
+ for path in resnet_0_paths:
512
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
513
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
514
+
515
+ new_checkpoint[new_path] = unet_state_dict[old_path]
516
+
517
+ if controlnet:
518
+ # conditioning embedding
519
+
520
+ orig_index = 0
521
+
522
+ new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
523
+ f"input_hint_block.{orig_index}.weight"
524
+ )
525
+ new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
526
+ f"input_hint_block.{orig_index}.bias"
527
+ )
528
+
529
+ orig_index += 2
530
+
531
+ diffusers_index = 0
532
+
533
+ while diffusers_index < 6:
534
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
535
+ f"input_hint_block.{orig_index}.weight"
536
+ )
537
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
538
+ f"input_hint_block.{orig_index}.bias"
539
+ )
540
+ diffusers_index += 1
541
+ orig_index += 2
542
+
543
+ new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
544
+ f"input_hint_block.{orig_index}.weight"
545
+ )
546
+ new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
547
+ f"input_hint_block.{orig_index}.bias"
548
+ )
549
+
550
+ # down blocks
551
+ for i in range(num_input_blocks):
552
+ new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
553
+ new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
554
+
555
+ # mid block
556
+ new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
557
+ new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
558
+
559
+ return new_checkpoint
560
+
561
+
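For reference, a minimal usage sketch of this UNet converter (the paths and the "state_dict" key are assumptions; `create_unet_diffusers_config` is defined earlier in this file):

    import torch
    from omegaconf import OmegaConf
    from diffusers import UNet2DConditionModel

    # Assumed inputs: the original LDM YAML config and a .ckpt whose weights sit under "state_dict".
    original_config = OmegaConf.load("v1-inference.yaml")
    checkpoint = torch.load("model.ckpt", map_location="cpu")["state_dict"]

    unet_config = create_unet_diffusers_config(original_config, image_size=512)
    unet_state_dict = convert_ldm_unet_checkpoint(checkpoint, unet_config, path="model.ckpt")

    unet = UNet2DConditionModel(**unet_config)
    unet.load_state_dict(unet_state_dict)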
562
+ def convert_ldm_vae_checkpoint(checkpoint, config, only_decoder=False, only_encoder=False):
563
+ # extract state dict for VAE
564
+ vae_state_dict = {}
565
+ vae_key = "first_stage_model."
566
+ keys = list(checkpoint.keys())
567
+ for key in keys:
568
+ if key.startswith(vae_key):
569
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
570
+
571
+ new_checkpoint = {}
572
+
573
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
574
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
575
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
576
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
577
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
578
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
579
+
580
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
581
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
582
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
583
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
584
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
585
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
586
+
587
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
588
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
589
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
590
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
591
+
592
+ # Retrieves the keys for the encoder down blocks only
593
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
594
+ down_blocks = {
595
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
596
+ }
597
+
598
+ # Retrieves the keys for the decoder up blocks only
599
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
600
+ up_blocks = {
601
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
602
+ }
603
+
604
+ for i in range(num_down_blocks):
605
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
606
+
607
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
608
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
609
+ f"encoder.down.{i}.downsample.conv.weight"
610
+ )
611
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
612
+ f"encoder.down.{i}.downsample.conv.bias"
613
+ )
614
+
615
+ paths = renew_vae_resnet_paths(resnets)
616
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
617
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
618
+
619
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
620
+ num_mid_res_blocks = 2
621
+ for i in range(1, num_mid_res_blocks + 1):
622
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
623
+
624
+ paths = renew_vae_resnet_paths(resnets)
625
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
626
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
627
+
628
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
629
+ paths = renew_vae_attention_paths(mid_attentions)
630
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
631
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
632
+ conv_attn_to_linear(new_checkpoint)
633
+
634
+ for i in range(num_up_blocks):
635
+ block_id = num_up_blocks - 1 - i
636
+ resnets = [
637
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
638
+ ]
639
+
640
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
641
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
642
+ f"decoder.up.{block_id}.upsample.conv.weight"
643
+ ]
644
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
645
+ f"decoder.up.{block_id}.upsample.conv.bias"
646
+ ]
647
+
648
+ paths = renew_vae_resnet_paths(resnets)
649
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
650
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
651
+
652
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
653
+ num_mid_res_blocks = 2
654
+ for i in range(1, num_mid_res_blocks + 1):
655
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
656
+
657
+ paths = renew_vae_resnet_paths(resnets)
658
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
659
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
660
+
661
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
662
+ paths = renew_vae_attention_paths(mid_attentions)
663
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
664
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
665
+ conv_attn_to_linear(new_checkpoint)
666
+
667
+ if only_decoder:
668
+ new_checkpoint = {k: v for k, v in new_checkpoint.items() if k.startswith('decoder') or k.startswith('post_quant')}
669
+ elif only_encoder:
670
+ new_checkpoint = {k: v for k, v in new_checkpoint.items() if k.startswith('encoder') or k.startswith('quant')}
671
+
672
+ return new_checkpoint
673
+
674
+
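Continuing the sketch above for the VAE converter, using `only_decoder` to port just a fine-tuned decoder (`create_vae_diffusers_config` is assumed to be defined earlier in this file, as in the upstream diffusers script):

    from diffusers import AutoencoderKL

    vae_config = create_vae_diffusers_config(original_config, image_size=512)
    decoder_state_dict = convert_ldm_vae_checkpoint(checkpoint, vae_config, only_decoder=True)

    vae = AutoencoderKL(**vae_config)
    # strict=False because the encoder / quant_conv keys were filtered out above.
    vae.load_state_dict(decoder_state_dict, strict=False)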
675
+ def convert_ldm_bert_checkpoint(checkpoint, config):
676
+ def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
677
+ hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
678
+ hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
679
+ hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
680
+
681
+ hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
682
+ hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
683
+
684
+ def _copy_linear(hf_linear, pt_linear):
685
+ hf_linear.weight = pt_linear.weight
686
+ hf_linear.bias = pt_linear.bias
687
+
688
+ def _copy_layer(hf_layer, pt_layer):
689
+ # copy layer norms
690
+ _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
691
+ _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
692
+
693
+ # copy attn
694
+ _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
695
+
696
+ # copy MLP
697
+ pt_mlp = pt_layer[1][1]
698
+ _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
699
+ _copy_linear(hf_layer.fc2, pt_mlp.net[2])
700
+
701
+ def _copy_layers(hf_layers, pt_layers):
702
+ for i, hf_layer in enumerate(hf_layers):
703
+ if i != 0:
704
+ i += i
705
+ pt_layer = pt_layers[i : i + 2]
706
+ _copy_layer(hf_layer, pt_layer)
707
+
708
+ hf_model = LDMBertModel(config).eval()
709
+
710
+ # copy embeds
711
+ hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
712
+ hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
713
+
714
+ # copy layer norm
715
+ _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
716
+
717
+ # copy hidden layers
718
+ _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
719
+
720
+ _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
721
+
722
+ return hf_model
723
+
724
+
725
+ def convert_ldm_clip_checkpoint(checkpoint):
726
+ keys = list(checkpoint.keys())
727
+
728
+ text_model_dict = {}
729
+ for key in keys:
730
+ if key.startswith("cond_stage_model.transformer"):
731
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
732
+
733
+ return text_model_dict
734
+
735
+
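The CLIP text-encoder converter simply strips the `cond_stage_model.transformer.` prefix; a sketch of loading its output (strictness depends on whether the installed transformers version still expects a `position_ids` buffer):

    from transformers import CLIPTextModel

    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    text_state_dict = convert_ldm_clip_checkpoint(checkpoint)
    text_encoder.load_state_dict(text_state_dict, strict=False)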
736
+ textenc_conversion_lst = [
737
+ ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
738
+ ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
739
+ ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
740
+ ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
741
+ ]
742
+ textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
743
+
744
+ textenc_transformer_conversion_lst = [
745
+ # (stable-diffusion, HF Diffusers)
746
+ ("resblocks.", "text_model.encoder.layers."),
747
+ ("ln_1", "layer_norm1"),
748
+ ("ln_2", "layer_norm2"),
749
+ (".c_fc.", ".fc1."),
750
+ (".c_proj.", ".fc2."),
751
+ (".attn", ".self_attn"),
752
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
753
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
754
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
755
+ ]
756
+ protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
757
+ textenc_pattern = re.compile("|".join(protected.keys()))
758
+
759
+
760
+ def convert_paint_by_example_checkpoint(checkpoint):
761
+ config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
762
+ model = PaintByExampleImageEncoder(config)
763
+
764
+ keys = list(checkpoint.keys())
765
+
766
+ text_model_dict = {}
767
+
768
+ for key in keys:
769
+ if key.startswith("cond_stage_model.transformer"):
770
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
771
+
772
+ # load clip vision
773
+ model.model.load_state_dict(text_model_dict)
774
+
775
+ # load mapper
776
+ keys_mapper = {
777
+ k[len("cond_stage_model.mapper.res") :]: v
778
+ for k, v in checkpoint.items()
779
+ if k.startswith("cond_stage_model.mapper")
780
+ }
781
+
782
+ MAPPING = {
783
+ "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
784
+ "attn.c_proj": ["attn1.to_out.0"],
785
+ "ln_1": ["norm1"],
786
+ "ln_2": ["norm3"],
787
+ "mlp.c_fc": ["ff.net.0.proj"],
788
+ "mlp.c_proj": ["ff.net.2"],
789
+ }
790
+
791
+ mapped_weights = {}
792
+ for key, value in keys_mapper.items():
793
+ prefix = key[: len("blocks.i")]
794
+ suffix = key.split(prefix)[-1].split(".")[-1]
795
+ name = key.split(prefix)[-1].split(suffix)[0][1:-1]
796
+ mapped_names = MAPPING[name]
797
+
798
+ num_splits = len(mapped_names)
799
+ for i, mapped_name in enumerate(mapped_names):
800
+ new_name = ".".join([prefix, mapped_name, suffix])
801
+ shape = value.shape[0] // num_splits
802
+ mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
803
+
804
+ model.mapper.load_state_dict(mapped_weights)
805
+
806
+ # load final layer norm
807
+ model.final_layer_norm.load_state_dict(
808
+ {
809
+ "bias": checkpoint["cond_stage_model.final_ln.bias"],
810
+ "weight": checkpoint["cond_stage_model.final_ln.weight"],
811
+ }
812
+ )
813
+
814
+ # load final proj
815
+ model.proj_out.load_state_dict(
816
+ {
817
+ "bias": checkpoint["proj_out.bias"],
818
+ "weight": checkpoint["proj_out.weight"],
819
+ }
820
+ )
821
+
822
+ # load uncond vector
823
+ model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
824
+ return model
825
+
826
+
827
+ def convert_open_clip_checkpoint(checkpoint):
828
+ text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
829
+
830
+ keys = list(checkpoint.keys())
831
+
832
+ text_model_dict = {}
833
+
834
+ if "cond_stage_model.model.text_projection" in checkpoint:
835
+ d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
836
+ else:
837
+ d_model = 1024
838
+
839
+ text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
840
+
841
+ for key in keys:
842
+ if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
843
+ continue
844
+ if key in textenc_conversion_map:
845
+ text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
846
+ if key.startswith("cond_stage_model.model.transformer."):
847
+ new_key = key[len("cond_stage_model.model.transformer.") :]
848
+ if new_key.endswith(".in_proj_weight"):
849
+ new_key = new_key[: -len(".in_proj_weight")]
850
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
851
+ text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
852
+ text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
853
+ text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
854
+ elif new_key.endswith(".in_proj_bias"):
855
+ new_key = new_key[: -len(".in_proj_bias")]
856
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
857
+ text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
858
+ text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
859
+ text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
860
+ else:
861
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
862
+
863
+ text_model_dict[new_key] = checkpoint[key]
864
+
865
+ text_model.load_state_dict(text_model_dict)
866
+
867
+ return text_model
868
+
869
+
870
+ def stable_unclip_image_encoder(original_config):
871
+ """
872
+ Returns the image processor and clip image encoder for the img2img unclip pipeline.
873
+
874
+ We currently know of two types of stable unclip models which separately use the clip and the openclip image
875
+ encoders.
876
+ """
877
+
878
+ image_embedder_config = original_config.model.params.embedder_config
879
+
880
+ sd_clip_image_embedder_class = image_embedder_config.target
881
+ sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
882
+
883
+ if sd_clip_image_embedder_class == "ClipImageEmbedder":
884
+ clip_model_name = image_embedder_config.params.model
885
+
886
+ if clip_model_name == "ViT-L/14":
887
+ feature_extractor = CLIPImageProcessor()
888
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
889
+ else:
890
+ raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
891
+
892
+ elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
893
+ feature_extractor = CLIPImageProcessor()
894
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
895
+ else:
896
+ raise NotImplementedError(
897
+ f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
898
+ )
899
+
900
+ return feature_extractor, image_encoder
901
+
902
+
903
+ def stable_unclip_image_noising_components(
904
+ original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
905
+ ):
906
+ """
907
+ Returns the noising components for the img2img and txt2img unclip pipelines.
908
+
909
+ Converts the stability noise augmentor into
910
+ 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
911
+ 2. a `DDPMScheduler` for holding the noise schedule
912
+
913
+ If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
914
+ """
915
+ noise_aug_config = original_config.model.params.noise_aug_config
916
+ noise_aug_class = noise_aug_config.target
917
+ noise_aug_class = noise_aug_class.split(".")[-1]
918
+
919
+ if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
920
+ noise_aug_config = noise_aug_config.params
921
+ embedding_dim = noise_aug_config.timestep_dim
922
+ max_noise_level = noise_aug_config.noise_schedule_config.timesteps
923
+ beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule
924
+
925
+ image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
926
+ image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)
927
+
928
+ if "clip_stats_path" in noise_aug_config:
929
+ if clip_stats_path is None:
930
+ raise ValueError("This stable unclip config requires a `clip_stats_path`")
931
+
932
+ clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
933
+ clip_mean = clip_mean[None, :]
934
+ clip_std = clip_std[None, :]
935
+
936
+ clip_stats_state_dict = {
937
+ "mean": clip_mean,
938
+ "std": clip_std,
939
+ }
940
+
941
+ image_normalizer.load_state_dict(clip_stats_state_dict)
942
+ else:
943
+ raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")
944
+
945
+ return image_normalizer, image_noising_scheduler
946
+
947
+
948
+ def convert_controlnet_checkpoint(
949
+ checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
950
+ ):
951
+ ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
952
+ ctrlnet_config["upcast_attention"] = upcast_attention
953
+
954
+ ctrlnet_config.pop("sample_size")
955
+
956
+ controlnet_model = ControlNetModel(**ctrlnet_config)
957
+
958
+ converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
959
+ checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
960
+ )
961
+
962
+ controlnet_model.load_state_dict(converted_ctrl_checkpoint)
963
+
964
+ return controlnet_model
animatediff/utils/convert_lora_safetensor_to_diffusers.py ADDED
@@ -0,0 +1,208 @@
1
+ # coding=utf-8
2
+ # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ Conversion script for the LoRA's safetensors checkpoints. """
17
+
18
+ import argparse
19
+
20
+ import torch
21
+ from safetensors.torch import load_file
22
+
23
+ from diffusers import StableDiffusionPipeline
24
+ import pdb
25
+
26
+ def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
27
+ # load base model
28
+ # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
29
+
30
+ # load LoRA weight from .safetensors
31
+ # state_dict = load_file(checkpoint_path)
32
+
33
+ visited = []
34
+
35
+ # directly update weight in diffusers model
36
+ for key in state_dict:
37
+ # It is suggested to print the key; it usually looks like the example below:
38
+ # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
39
+
40
+ # alpha has already been set beforehand, so just skip these entries
41
+ if ".alpha" in key or key in visited:
42
+ continue
43
+
44
+ if "text" in key:
45
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
46
+ curr_layer = pipeline.text_encoder
47
+ else:
48
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
49
+ curr_layer = pipeline.unet
50
+
51
+ # find the target layer
52
+ temp_name = layer_infos.pop(0)
53
+ while len(layer_infos) > -1:
54
+ try:
55
+ curr_layer = curr_layer.__getattr__(temp_name)
56
+ if len(layer_infos) > 0:
57
+ temp_name = layer_infos.pop(0)
58
+ elif len(layer_infos) == 0:
59
+ break
60
+ except Exception:
61
+ if len(temp_name) > 0:
62
+ temp_name += "_" + layer_infos.pop(0)
63
+ else:
64
+ temp_name = layer_infos.pop(0)
65
+
66
+ pair_keys = []
67
+ if "lora_down" in key:
68
+ pair_keys.append(key.replace("lora_down", "lora_up"))
69
+ pair_keys.append(key)
70
+ else:
71
+ pair_keys.append(key)
72
+ pair_keys.append(key.replace("lora_up", "lora_down"))
73
+
74
+ # update weight
75
+ if len(state_dict[pair_keys[0]].shape) == 4:
76
+ weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
77
+ weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
78
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
79
+ else:
80
+ weight_up = state_dict[pair_keys[0]].to(torch.float32)
81
+ weight_down = state_dict[pair_keys[1]].to(torch.float32)
82
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
83
+
84
+ # update visited list
85
+ for item in pair_keys:
86
+ visited.append(item)
87
+
88
+ return pipeline
89
+
90
+
91
+ def convert_lora_model_level(state_dict, unet, text_encoder=None, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
92
+ """convert lora in model level instead of pipeline leval
93
+ """
94
+
95
+ visited = []
96
+
97
+ # directly update weight in diffusers model
98
+ for key in state_dict:
99
+ # It is suggested to print the key; it usually looks like the example below:
100
+ # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
101
+
102
+ # alpha has already been set beforehand, so just skip these entries
103
+ if ".alpha" in key or key in visited:
104
+ continue
105
+
106
+ if "text" in key:
107
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
108
+ assert text_encoder is not None, (
109
+ 'text_encoder must be passed since lora contains text encoder layers')
110
+ curr_layer = text_encoder
111
+ else:
112
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
113
+ curr_layer = unet
114
+
115
+ # find the target layer
116
+ temp_name = layer_infos.pop(0)
117
+ while len(layer_infos) > -1:
118
+ try:
119
+ curr_layer = curr_layer.__getattr__(temp_name)
120
+ if len(layer_infos) > 0:
121
+ temp_name = layer_infos.pop(0)
122
+ elif len(layer_infos) == 0:
123
+ break
124
+ except Exception:
125
+ if len(temp_name) > 0:
126
+ temp_name += "_" + layer_infos.pop(0)
127
+ else:
128
+ temp_name = layer_infos.pop(0)
129
+
130
+ pair_keys = []
131
+ if "lora_down" in key:
132
+ pair_keys.append(key.replace("lora_down", "lora_up"))
133
+ pair_keys.append(key)
134
+ else:
135
+ pair_keys.append(key)
136
+ pair_keys.append(key.replace("lora_up", "lora_down"))
137
+
138
+ # update weight
139
+ # NOTE: handle LyCORIS/LoCon conv weights, may have bugs :(
140
+ if 'conv_in' in pair_keys[0]:
141
+ weight_up = state_dict[pair_keys[0]].to(torch.float32)
142
+ weight_down = state_dict[pair_keys[1]].to(torch.float32)
143
+ weight_up = weight_up.view(weight_up.size(0), -1)
144
+ weight_down = weight_down.view(weight_down.size(0), -1)
145
+ shape = [e for e in curr_layer.weight.data.shape]
146
+ shape[1] = 4
147
+ curr_layer.weight.data[:, :4, ...] += alpha * (weight_up @ weight_down).view(*shape)
148
+ elif 'conv' in pair_keys[0]:
149
+ weight_up = state_dict[pair_keys[0]].to(torch.float32)
150
+ weight_down = state_dict[pair_keys[1]].to(torch.float32)
151
+ weight_up = weight_up.view(weight_up.size(0), -1)
152
+ weight_down = weight_down.view(weight_down.size(0), -1)
153
+ shape = [e for e in curr_layer.weight.data.shape]
154
+ curr_layer.weight.data += alpha * (weight_up @ weight_down).view(*shape)
155
+ elif len(state_dict[pair_keys[0]].shape) == 4:
156
+ weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
157
+ weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
158
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
159
+ else:
160
+ weight_up = state_dict[pair_keys[0]].to(torch.float32)
161
+ weight_down = state_dict[pair_keys[1]].to(torch.float32)
162
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
163
+
164
+ # update visited list
165
+ for item in pair_keys:
166
+ visited.append(item)
167
+
168
+ return unet, text_encoder
169
+
170
+
171
+ if __name__ == "__main__":
172
+ parser = argparse.ArgumentParser()
173
+
174
+ parser.add_argument(
175
+ "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
176
+ )
177
+ parser.add_argument(
178
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
179
+ )
180
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
181
+ parser.add_argument(
182
+ "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
183
+ )
184
+ parser.add_argument(
185
+ "--lora_prefix_text_encoder",
186
+ default="lora_te",
187
+ type=str,
188
+ help="The prefix of text encoder weight in safetensors",
189
+ )
190
+ parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
191
+ parser.add_argument(
192
+ "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
193
+ )
194
+ parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
195
+
196
+ args = parser.parse_args()
197
+
198
+ base_model_path = args.base_model_path
199
+ checkpoint_path = args.checkpoint_path
200
+ dump_path = args.dump_path
201
+ lora_prefix_unet = args.lora_prefix_unet
202
+ lora_prefix_text_encoder = args.lora_prefix_text_encoder
203
+ alpha = args.alpha
204
+
205
+ pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
+ state_dict = load_file(checkpoint_path)
+ pipe = convert_lora(pipe, state_dict, lora_prefix_unet, lora_prefix_text_encoder, alpha)
206
+
207
+ pipe = pipe.to(args.device)
208
+ pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
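A minimal sketch of calling the model-level variant directly (the .safetensors path is an assumption; `unet` / `text_encoder` are assumed to be already-built diffusers / transformers modules):

    from safetensors.torch import load_file

    lora_state_dict = load_file("some_lora.safetensors")
    unet, text_encoder = convert_lora_model_level(
        lora_state_dict, unet, text_encoder=text_encoder, alpha=0.8)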
animatediff/utils/util.py ADDED
@@ -0,0 +1,255 @@
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+ from typing import Union, Optional
5
+
6
+ import torch
7
+ import torchvision
8
+ import torch.distributed as dist
9
+
10
+ from tqdm import tqdm
11
+ from einops import rearrange
12
+ import cv2
13
+ import math
14
+ import moviepy.editor as mpy
15
+ from PIL import Image
16
+
17
+ # We recommend using the following affinity scores (motion magnitude).
18
+ # You are also encouraged to construct your own score list.
19
+ # RANGE_LIST = [
20
+ # [1.0, 0.9, 0.85, 0.85, 0.85, 0.8], # 0 Small Motion
21
+ # [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75], # Moderate Motion
22
+ # [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5], # Large Motion
23
+ # # [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6], # Large Motion
24
+ # # [1.0, 0.65, 0.6], # candidate moderate
25
+ # # [1.0, 0.65, 0.6, 0.6, 0.6, 0.5, 0.5, 0.5, 0.5, 0.4], # candidate large
26
+ # [1.0 , 0.9 , 0.85, 0.85, 0.85, 0.8 , 0.8 , 0.8 , 0.8 , 0.8 , 0.8 , 0.8 , 0.85, 0.85, 0.9 , 1.0 ], # Loop
27
+ # [1.0 , 0.8 , 0.8 , 0.8 , 0.79, 0.78, 0.75, 0.75, 0.75, 0.75, 0.75, 0.78, 0.79, 0.8 , 0.8 , 1.0 ], # Loop
28
+ # [1.0 , 0.8 , 0.7 , 0.7 , 0.7 , 0.7 , 0.6 , 0.5 , 0.5 , 0.6 , 0.7 , 0.7 , 0.7 , 0.7 , 0.8 , 1.0 ], # Loop
29
+ # # [1.0], # Static
30
+ # # [0],
31
+ # # [0.6, 0.5, 0.5, 0.45, 0.45, 0.4], # Style Transfer Test
32
+ # # [0.4, 0.3, 0.3, 0.25, 0.25, 0.2], # Style Transfer
33
+ # [0.5, 0.2], # Style Transfer Large Motion
34
+ # [0.5, 0.4, 0.4, 0.4, 0.35, 0.35, 0.3, 0.25, 0.2], # Style Transfer Moderate Motion
35
+ # [0.5, 0.4, 0.4, 0.4, 0.35, 0.3], # Style Transfer Candidate Small Motion
36
+ # ]
37
+ RANGE_LIST = [
38
+ [0.5, 0.4, 0.4, 0.4, 0.35, 0.3], # Style Transfer Candidate Small Motion
39
+ [0.5, 0.4, 0.4, 0.4, 0.35, 0.35, 0.3, 0.25, 0.2], # Style Transfer Moderate Motion
40
+ [0.5, 0.2], # Style Transfer Large Motion
41
+ ]
42
+
43
+
44
+ def zero_rank_print(s):
45
+ if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
46
+
47
+ def save_videos_mp4(video: torch.Tensor, path: str, fps: int=8):
48
+ video = rearrange(video, "b c t h w -> t b c h w")
49
+ num_frames, batch_size, channels, height, width = video.shape
50
+ assert batch_size == 1,\
51
+ 'Only support batch size == 1'
52
+ video = video.squeeze(1)
53
+ video = rearrange(video, "t c h w -> t h w c")
54
+ def make_frame(t):
55
+ frame_tensor = video[int(t * fps)]
56
+ frame_np = (frame_tensor * 255).numpy().astype('uint8')
57
+ return frame_np
58
+ clip = mpy.VideoClip(make_frame, duration=num_frames / fps)
59
+ clip.write_videofile(path, fps=fps, codec='libx264')
60
+
61
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
62
+ videos = rearrange(videos, "b c t h w -> t b c h w")
63
+ outputs = []
64
+ for x in videos:
65
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
66
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
67
+ if rescale:
68
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
69
+ x = torch.clamp((x * 255), 0, 255).numpy().astype(np.uint8)
70
+ outputs.append(x)
71
+
72
+ os.makedirs(os.path.dirname(path), exist_ok=True)
73
+ imageio.mimsave(path, outputs, fps=fps)
74
+
75
+
76
+ # DDIM Inversion
77
+ @torch.no_grad()
78
+ def init_prompt(prompt, pipeline):
79
+ uncond_input = pipeline.tokenizer(
80
+ [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
81
+ return_tensors="pt"
82
+ )
83
+ uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
84
+ text_input = pipeline.tokenizer(
85
+ [prompt],
86
+ padding="max_length",
87
+ max_length=pipeline.tokenizer.model_max_length,
88
+ truncation=True,
89
+ return_tensors="pt",
90
+ )
91
+ text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
92
+ context = torch.cat([uncond_embeddings, text_embeddings])
93
+
94
+ return context
95
+
96
+
97
+ def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
98
+ sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
99
+ timestep, next_timestep = min(
100
+ timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
101
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
102
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
103
+ beta_prod_t = 1 - alpha_prod_t
104
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
105
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
106
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
107
+ return next_sample
108
+
109
+
110
+ def get_noise_pred_single(latents, t, context, unet):
111
+ noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
112
+ return noise_pred
113
+
114
+
115
+ @torch.no_grad()
116
+ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
117
+ context = init_prompt(prompt, pipeline)
118
+ uncond_embeddings, cond_embeddings = context.chunk(2)
119
+ all_latent = [latent]
120
+ latent = latent.clone().detach()
121
+ for i in tqdm(range(num_inv_steps)):
122
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
123
+ noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
124
+ latent = next_step(noise_pred, t, latent, ddim_scheduler)
125
+ all_latent.append(latent)
126
+ return all_latent
127
+
128
+
129
+ @torch.no_grad()
130
+ def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
131
+ ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
132
+ return ddim_latents
133
+
134
+ def prepare_mask_coef(video_length:int, cond_frame:int, sim_range:list=[0.2, 1.0]):
135
+
136
+ assert len(sim_range) == 2, \
137
+ 'sim_range should has the length of 2, including the min and max similarity'
138
+
139
+ assert video_length > 1, \
140
+ 'video_length should be greater than 1'
141
+
142
+ assert video_length > cond_frame,\
143
+ 'video_length should be greater than cond_frame'
144
+
145
+ diff = abs(sim_range[0] - sim_range[1]) / (video_length - 1)
146
+ coef = [1.0] * video_length
147
+ for f in range(video_length):
148
+ f_diff = diff * abs(cond_frame - f)
149
+ f_diff = 1 - f_diff
150
+ coef[f] *= f_diff
151
+
152
+ return coef
153
+
154
+ def prepare_mask_coef_by_score(video_shape: list, cond_frame_idx: list, sim_range: list = [0.2, 1.0],
155
+ statistic: list = [1, 100], coef_max: int = 0.98, score: Optional[torch.Tensor] = None):
156
+ '''
157
+ the shape of video_data is (b f c h w)
158
+ cond_frame_idx is a list, with length of batch_size
159
+ the shape of statistic is (f 2)
160
+ the shape of score is (b f)
161
+ the shape of coef is (b f)
162
+ '''
163
+ assert len(video_shape) == 2, \
164
+ f'the shape of video_shape should be (b f c h w), but now get {len(video_shape.shape)} channels'
165
+
166
+ batch_size, frame_num = video_shape[0], video_shape[1]
167
+
168
+ score = score.permute(0, 2, 1).squeeze(0)
169
+
170
+ # list -> b 1
171
+ cond_fram_mat = torch.tensor(cond_frame_idx).unsqueeze(-1)
172
+
173
+ statistic = torch.tensor(statistic)
174
+ # (f 2) -> (b f 2)
175
+ statistic = statistic.repeat(batch_size, 1, 1)
176
+
177
+ # shape of order (b f), shape of cond_mat (b f)
178
+ order = torch.arange(0, frame_num, 1)
179
+ order = order.repeat(batch_size, 1)
180
+ cond_mat = torch.ones((batch_size, frame_num)) * cond_fram_mat
181
+ order = abs(order - cond_mat)
182
+
183
+ statistic = statistic[:,order.to(torch.long)][0,:,:,:]
184
+
185
+ # score (b f) max_s (b f 1)
186
+ max_stats = torch.max(statistic, dim=2).values.to(dtype=score.dtype)
187
+ min_stats = torch.min(statistic, dim=2).values.to(dtype=score.dtype)
188
+
189
+ score[score > max_stats] = max_stats[score > max_stats] * 0.95
190
+ score[score < min_stats] = min_stats[score < min_stats]
191
+
192
+ eps = 1e-10
193
+ coef = 1 - abs((score / (max_stats + eps)) * (max(sim_range) - min(sim_range)))
194
+
195
+ indices = torch.arange(coef.shape[0]).unsqueeze(1)
196
+ coef[indices, cond_fram_mat] = 1.0
197
+
198
+ return coef
199
+
200
+
201
+ def prepare_mask_coef_by_statistics(video_length: int, cond_frame: int, sim_range: int,
202
+ coef: Optional[list] = None):
203
+ """
204
+ coef: User-defined coef. If passed, the `sim_range` index will be ignored. This is useful
205
+ for defining a custom style-transfer coef for different models.
206
+ """
207
+ assert video_length > 1, \
208
+ 'video_length should be greater than 1'
209
+
210
+ assert video_length > cond_frame,\
211
+ 'video_length should be greater than cond_frame'
212
+
213
+ # Recommend index: 13
214
+
215
+ # range_list = [
216
+ # # [0.8, 0.8, 0.7, 0.6],
217
+ # [1.0, 0.8, 0.7, 0.6],
218
+ # [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5],
219
+ # [1.0, 0.9, 0.85, 0.85, 0.85, 0.8], # 0
220
+ # [1.0, 0.9, 0.8, 0.7],
221
+ # [1.0, 0.8, 0.7, 0.6, 0.7, 0.6],
222
+ # [1.0, 0.9, 0.85],
223
+ # # [1.0, 0.9, 0.7, 0.5, 0.3, 0.2],
224
+ # # [1.0, 0.8, 0.6, 0.4],
225
+ # # [1.0, 0.65, 0.6], # 1
226
+ # [1.0, 0.6, 0.4], # 2
227
+ # [1.0, 0.2, 0.2],
228
+ # # [1.0, 0.8, 0.6, 0.6, 0.5, 0.5, 0.4],
229
+ # # [1.0, 0.9, 0.9, 0.9, 0.9, 0.8],
230
+ # # [1.0, 0.65, 0.6, 0.6, 0.5, 0.5, 0.4],
231
+ # # [1.0, 0.9, 0.9, 0.9, 0.7, 0.7, 0.6, 0.5, 0.4],
232
+ # [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75], # 4 style_transfer
233
+ # [1.0, 0.9, 0.9],
234
+ # [0.8, 0.7, 0.6],
235
+ # [0.8, 0.8, 0.8, 0.8, 0.7],
236
+ # [0.9, 0.6, 0.6, 0.6, 0.5, 0.4, 0.2],
237
+ # # [1.0, 0.91, 0.9, 0.89, 0.88, 0.87],
238
+ # # [1.0, 0.7, 0.65, 0.65, 0.65, 0.65, 0.6],
239
+ # # [1.0, 0.85, 0.9, 0.85, 0.9, 0.85],
240
+ # # [1.0, 0.8, 0.82, 0.84, 0.86, 0.88, 0.78, 0.82, 0.84],
241
+ # # [1.0],
242
+ # ]
243
+ range_list = RANGE_LIST
244
+
245
+ assert sim_range < len(range_list),\
246
+ f'sim_range type{sim_range} not implemented'
247
+
248
+ if coef is None:
249
+ coef = range_list[sim_range]
250
+ coef = coef + ([coef[-1]] * (video_length - len(coef)))
251
+
252
+ order = [abs(i - cond_frame) for i in range(video_length)]
253
+ coef = [coef[order[i]] for i in range(video_length)]
254
+
255
+ return coef
app-counterfeit-only.py ADDED
@@ -0,0 +1,441 @@
1
+ import json
2
+ import os
3
+ import os.path as osp
4
+ import random
5
+ from argparse import ArgumentParser
6
+ from datetime import datetime
7
+
8
+ import gradio as gr
9
+ import numpy as np
10
+ import openxlab
11
+ import torch
12
+ from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
13
+ from omegaconf import OmegaConf
14
+ from openxlab.model import download
15
+ from PIL import Image
16
+
17
+ from animatediff.pipelines import I2VPipeline
18
+ from animatediff.utils.util import RANGE_LIST, save_videos_grid
19
+
20
+ sample_idx = 0
21
+ scheduler_dict = {
22
+ "DDIM": DDIMScheduler,
23
+ "Euler": EulerDiscreteScheduler,
24
+ "PNDM": PNDMScheduler,
25
+ }
26
+
27
+ css = """
28
+ .toolbutton {
29
+ margin-bottom: 0em;
30
+ max-width: 2.5em;
31
+ min-width: 2.5em !important;
32
+ height: 2.5em;
33
+ }
34
+ """
35
+
36
+ parser = ArgumentParser()
37
+ parser.add_argument('--config', type=str, default='example/config/base.yaml')
38
+ parser.add_argument('--server-name', type=str, default='0.0.0.0')
39
+ parser.add_argument('--port', type=int, default=7860)
40
+ parser.add_argument('--share', action='store_true')
41
+ parser.add_argument('--local-debug', action='store_true')
42
+
43
+ parser.add_argument('--save-path', default='samples')
44
+
45
+ args = parser.parse_args()
46
+ LOCAL_DEBUG = args.local_debug
47
+
48
+
49
+ BASE_CONFIG = 'example/config/base.yaml'
50
+ STYLE_CONFIG_LIST = {
51
+ 'anime': './example/openxlab/2-animation.yaml',
52
+ }
53
+
54
+
55
+ # download models
56
+ PIA_PATH = './models/PIA'
57
+ VAE_PATH = './models/VAE'
58
+ DreamBooth_LoRA_PATH = './models/DreamBooth_LoRA'
59
+
60
+
61
+ if not LOCAL_DEBUG:
62
+ CACHE_PATH = '/home/xlab-app-center/.cache/model'
63
+
64
+ PIA_PATH = osp.join(CACHE_PATH, 'PIA')
65
+ VAE_PATH = osp.join(CACHE_PATH, 'VAE')
66
+ DreamBooth_LoRA_PATH = osp.join(CACHE_PATH, 'DreamBooth_LoRA')
67
+ STABLE_DIFFUSION_PATH = osp.join(CACHE_PATH, 'StableDiffusion')
68
+
69
+ IP_ADAPTER_PATH = osp.join(CACHE_PATH, 'IP_Adapter')
70
+
71
+ os.makedirs(PIA_PATH, exist_ok=True)
72
+ os.makedirs(VAE_PATH, exist_ok=True)
73
+ os.makedirs(DreamBooth_LoRA_PATH, exist_ok=True)
74
+ os.makedirs(STABLE_DIFFUSION_PATH, exist_ok=True)
75
+
76
+ openxlab.login(os.environ['OPENXLAB_AK'], os.environ['OPENXLAB_SK'])
77
+ download(model_repo='zhangyiming/PIA-pruned', model_name='PIA', output=PIA_PATH)
78
+ download(model_repo='zhangyiming/Counterfeit-V3.0',
79
+ model_name='Counterfeit-V3.0_fp32_pruned', output=DreamBooth_LoRA_PATH)
80
+ download(model_repo='zhangyiming/kl-f8-anime2_VAE',
81
+ model_name='kl-f8-anime2', output=VAE_PATH)
82
+
83
+ # ip_adapter
84
+ download(model_repo='zhangyiming/IP-Adapter',
85
+ model_name='clip_encoder', output=osp.join(IP_ADAPTER_PATH, 'image_encoder'))
86
+ download(model_repo='zhangyiming/IP-Adapter',
87
+ model_name='config', output=osp.join(IP_ADAPTER_PATH, 'image_encoder'))
88
+ download(model_repo='zhangyiming/IP-Adapter',
89
+ model_name='ip_adapter_sd15', output=IP_ADAPTER_PATH)
90
+
91
+ # unet
92
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Unet',
93
+ model_name='unet', output=osp.join(STABLE_DIFFUSION_PATH, 'unet'))
94
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Unet',
95
+ model_name='config', output=osp.join(STABLE_DIFFUSION_PATH, 'unet'))
96
+
97
+ # vae
98
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_VAE',
99
+ model_name='vae', output=osp.join(STABLE_DIFFUSION_PATH, 'vae'))
100
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_VAE',
101
+ model_name='config', output=osp.join(STABLE_DIFFUSION_PATH, 'vae'))
102
+
103
+ # text encoder
104
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_TextEncod',
105
+ model_name='text_encoder', output=osp.join(STABLE_DIFFUSION_PATH, 'text_encoder'))
106
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_TextEncod',
107
+ model_name='config', output=osp.join(STABLE_DIFFUSION_PATH, 'text_encoder'))
108
+
109
+ # tokenizer
110
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Tokenizer',
111
+ model_name='merge', output=osp.join(STABLE_DIFFUSION_PATH, 'tokenizer'))
112
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Tokenizer',
113
+ model_name='special_tokens_map', output=osp.join(STABLE_DIFFUSION_PATH, 'tokenizer'))
114
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Tokenizer',
115
+ model_name='tokenizer_config', output=osp.join(STABLE_DIFFUSION_PATH, 'tokenizer'))
116
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Tokenizer',
117
+ model_name='vocab', output=osp.join(STABLE_DIFFUSION_PATH, 'tokenizer'))
118
+
119
+ # scheduler
120
+ scheduler_dict = {
121
+ "_class_name": "PNDMScheduler",
122
+ "_diffusers_version": "0.6.0",
123
+ "beta_end": 0.012,
124
+ "beta_schedule": "scaled_linear",
125
+ "beta_start": 0.00085,
126
+ "num_train_timesteps": 1000,
127
+ "set_alpha_to_one": False,
128
+ "skip_prk_steps": True,
129
+ "steps_offset": 1,
130
+ "trained_betas": None,
131
+ "clip_sample": False
132
+ }
133
+ os.makedirs(osp.join(STABLE_DIFFUSION_PATH, 'scheduler'), exist_ok=True)
134
+ with open(osp.join(STABLE_DIFFUSION_PATH, 'scheduler', 'scheduler_config.json'), 'w') as file:
135
+ json.dump(scheduler_dict, file)
136
+
137
+ # model index
138
+ model_index_dict = {
139
+ "_class_name": "StableDiffusionPipeline",
140
+ "_diffusers_version": "0.6.0",
141
+ "feature_extractor": [
142
+ "transformers",
143
+ "CLIPImageProcessor"
144
+ ],
145
+ "safety_checker": [
146
+ "stable_diffusion",
147
+ "StableDiffusionSafetyChecker"
148
+ ],
149
+ "scheduler": [
150
+ "diffusers",
151
+ "PNDMScheduler"
152
+ ],
153
+ "text_encoder": [
154
+ "transformers",
155
+ "CLIPTextModel"
156
+ ],
157
+ "tokenizer": [
158
+ "transformers",
159
+ "CLIPTokenizer"
160
+ ],
161
+ "unet": [
162
+ "diffusers",
163
+ "UNet2DConditionModel"
164
+ ],
165
+ "vae": [
166
+ "diffusers",
167
+ "AutoencoderKL"
168
+ ]
169
+ }
170
+ with open(osp.join(STABLE_DIFFUSION_PATH, 'model_index.json'), 'w') as file:
171
+ json.dump(model_index_dict, file)
172
+
173
+ else:
174
+ PIA_PATH = './models/PIA'
175
+ VAE_PATH = './models/VAE'
176
+ DreamBooth_LoRA_PATH = './models/DreamBooth_LoRA'
177
+ STABLE_DIFFUSION_PATH = './models/StableDiffusion/sd15'
178
+
179
+
180
+ def preprocess_img(img_np, max_size: int = 512):
181
+
182
+ ori_image = Image.fromarray(img_np).convert('RGB')
183
+
184
+ width, height = ori_image.size
185
+
186
+ long_edge = max(width, height)
187
+ if long_edge > max_size:
188
+ scale_factor = max_size / long_edge
189
+ else:
190
+ scale_factor = 1
191
+ width = int(width * scale_factor)
192
+ height = int(height * scale_factor)
193
+ ori_image = ori_image.resize((width, height))
194
+
195
+ if (width % 8 != 0) or (height % 8 != 0):
196
+ in_width = (width // 8) * 8
197
+ in_height = (height // 8) * 8
198
+ else:
199
+ in_width = width
200
+ in_height = height
201
+ in_image = ori_image
202
+
203
+ in_image = ori_image.resize((in_width, in_height))
204
+ in_image_np = np.array(in_image)
205
+ return in_image_np, in_height, in_width
206
+
207
+
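`preprocess_img` scales the longer side down to at most 512 pixels and snaps both sides to multiples of 8, as required by the SD VAE/UNet. For example, with a dummy 1280x720 input:

    import numpy as np

    img = np.zeros((720, 1280, 3), dtype=np.uint8)   # H x W x C
    img_np, h, w = preprocess_img(img)
    # h == 288, w == 512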
208
+ class AnimateController:
209
+ def __init__(self):
210
+
211
+ # config dirs
212
+ self.basedir = os.getcwd()
213
+ self.savedir = os.path.join(
214
+ self.basedir, args.save_path, datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
215
+ self.savedir_sample = os.path.join(self.savedir, "sample")
216
+ os.makedirs(self.savedir, exist_ok=True)
217
+
218
+ self.inference_config = OmegaConf.load(args.config)
219
+ self.style_configs = {k: OmegaConf.load(
220
+ v) for k, v in STYLE_CONFIG_LIST.items()}
221
+
222
+ self.pipeline_dict = self.load_model_list()
223
+
224
+ def load_model_list(self):
225
+ pipeline_dict = dict()
226
+ for style, cfg in self.style_configs.items():
227
+ dreambooth_path = cfg.get('dreambooth', 'none')
228
+ if dreambooth_path and dreambooth_path.upper() != 'NONE':
229
+ dreambooth_path = osp.join(
230
+ DreamBooth_LoRA_PATH, dreambooth_path)
231
+ lora_path = cfg.get('lora', None)
232
+ if lora_path is not None:
233
+ lora_path = osp.join(DreamBooth_LoRA_PATH, lora_path)
234
+ lora_alpha = cfg.get('lora_alpha', 0.0)
235
+ vae_path = cfg.get('vae', None)
236
+ if vae_path is not None:
237
+ vae_path = osp.join(VAE_PATH, vae_path)
238
+
239
+ pipeline_dict[style] = I2VPipeline.build_pipeline(
240
+ self.inference_config,
241
+ STABLE_DIFFUSION_PATH,
242
+ unet_path=osp.join(PIA_PATH, 'pia.ckpt'),
243
+ dreambooth_path=dreambooth_path,
244
+ lora_path=lora_path,
245
+ lora_alpha=lora_alpha,
246
+ vae_path=vae_path,
247
+ ip_adapter_path='h94/IP-Adapter',
248
+ ip_adapter_scale=0.1)
249
+ return pipeline_dict
250
+
251
+ def fetch_default_n_prompt(self, style: str):
252
+ cfg = self.style_configs[style]
253
+ n_prompt = cfg.get('n_prompt', '')
254
+ ip_adapter_scale = cfg.get('real_ip_adapter_scale', 0)
255
+
256
+ gr.Info('Set default negative prompt and ip_adapter_scale.')
257
+ print('Set default negative prompt and ip_adapter_scale.')
258
+
259
+ return n_prompt, ip_adapter_scale
260
+
261
+ def animate(
262
+ self,
263
+ init_img,
264
+ motion_scale,
265
+ prompt_textbox,
266
+ negative_prompt_textbox,
267
+ sampler_dropdown,
268
+ sample_step_slider,
269
+ cfg_scale_slider,
270
+ seed_textbox,
271
+ ip_adapter_scale,
272
+ style,
273
+ progress=gr.Progress(),
274
+ ):
275
+
276
+ if seed_textbox != -1 and seed_textbox != "":
277
+ torch.manual_seed(int(seed_textbox))
278
+ else:
279
+ torch.seed()
280
+ seed = torch.initial_seed()
281
+
282
+ pipeline = self.pipeline_dict[style]
283
+ init_img, h, w = preprocess_img(init_img)
284
+
285
+ sample = pipeline(
286
+ image=init_img,
287
+ prompt=prompt_textbox,
288
+ negative_prompt=negative_prompt_textbox,
289
+ num_inference_steps=sample_step_slider,
290
+ guidance_scale=cfg_scale_slider,
291
+ width=w,
292
+ height=h,
293
+ video_length=16,
294
+ mask_sim_template_idx=motion_scale - 1,
295
+ ip_adapter_scale=ip_adapter_scale,
296
+ progress_fn=progress,
297
+ ).videos
298
+
299
+ save_sample_path = os.path.join(
300
+ self.savedir_sample, f"{sample_idx}.mp4")
301
+ save_videos_grid(sample, save_sample_path)
302
+
303
+ sample_config = {
304
+ "prompt": prompt_textbox,
305
+ "n_prompt": negative_prompt_textbox,
306
+ "sampler": sampler_dropdown,
307
+ "num_inference_steps": sample_step_slider,
308
+ "guidance_scale": cfg_scale_slider,
309
+ "width": w,
310
+ "height": h,
311
+ "seed": seed,
312
+ "motion": motion_scale,
313
+ }
314
+ json_str = json.dumps(sample_config, indent=4)
315
+ with open(os.path.join(self.savedir, "logs.json"), "a") as f:
316
+ f.write(json_str)
317
+ f.write("\n\n")
318
+
319
+ return save_sample_path
320
+
321
+
322
+ controller = AnimateController()
323
+
324
+
325
+ def ui():
326
+ with gr.Blocks(css=css) as demo:
327
+
328
+ gr.HTML(
329
+ "<div align='center'><font size='7'> <img src=\"file/pia.png\" style=\"height: 72px;\"/ > Your Personalized Image Animator</font></div>"
330
+ "<div align='center'><font size='7'>via Plug-and-Play Modules in Text-to-Image Models </font></div>"
331
+ )
332
+ with gr.Row():
333
+ gr.Markdown(
334
+ "<div align='center'><font size='5'><a href='https://pi-animator.github.io/'>Project Page</a> &ensp;" # noqa
335
+ "<a href='https://arxiv.org/abs/2312.13964/'>Paper</a> &ensp;"
336
+ "<a href='https://github.com/open-mmlab/PIA'>Code</a> &ensp;" # noqa
337
+ # "Try More Style: <a href='https://openxlab.org.cn/apps/detail/zhangyiming/PiaPia'>Click Here!</a> </font></div>" # noqa
338
+ "Try More Style: <a href='https://openxlab.org.cn/apps/detail/zhangyiming/PiaPia'>Click here! </a></font></div>" # noqa
339
+ )
340
+
341
+ with gr.Row(equal_height=False):
342
+ with gr.Column():
343
+ with gr.Row():
344
+ init_img = gr.Image(label='Input Image')
345
+
346
+ style_dropdown = gr.Dropdown(label='Style', choices=list(
347
+ STYLE_CONFIG_LIST.keys()), value=list(STYLE_CONFIG_LIST.keys())[0])
348
+
349
+ with gr.Row():
350
+ prompt_textbox = gr.Textbox(label="Prompt", lines=1)
351
+ gift_button = gr.Button(
352
+ value='🎁', elem_classes='toolbutton'
353
+ )
354
+
355
+ def append_gift(prompt):
356
+ rand = random.randint(0, 2)
357
+ if rand == 1:
358
+ prompt = prompt + 'wearing santa hats'
359
+ elif rand == 2:
360
+ prompt = prompt + 'lift a Christmas gift'
361
+ else:
362
+ prompt = prompt + 'in Christmas suit, lift a Christmas gift'
363
+ gr.Info('Merry Christmas! Add magic to your prompt!')
364
+ return prompt
365
+
366
+ gift_button.click(
367
+ fn=append_gift,
368
+ inputs=[prompt_textbox],
369
+ outputs=[prompt_textbox],
370
+ )
371
+
372
373
+
374
+ motion_scale_silder = gr.Slider(
375
+ label='Motion Scale (Larger value means larger motion but less identity consistency)', value=2, step=1, minimum=1, maximum=len(RANGE_LIST))
376
+ ip_adapter_scale = gr.Slider(
377
+ label='IP-Adapter Scale', value=controller.fetch_default_n_prompt(
378
+ list(STYLE_CONFIG_LIST.keys())[0])[1], minimum=0, maximum=1)
379
+
380
+ with gr.Accordion('Advanced Options', open=False):
381
+ negative_prompt_textbox = gr.Textbox(
382
+ value=controller.fetch_default_n_prompt(
383
+ list(STYLE_CONFIG_LIST.keys())[0])[0],
384
+ label="Negative prompt", lines=2)
385
+
386
+ with gr.Row():
387
+ sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(
388
+ scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
389
+ sample_step_slider = gr.Slider(
390
+ label="Sampling steps", value=20, minimum=10, maximum=100, step=1)
391
+
392
+ cfg_scale_slider = gr.Slider(
393
+ label="CFG Scale", value=7.5, minimum=0, maximum=20)
394
+
395
+ with gr.Row():
396
+ seed_textbox = gr.Textbox(label="Seed", value=-1)
397
+ seed_button = gr.Button(
398
+ value="\U0001F3B2", elem_classes="toolbutton")
399
+ seed_button.click(
400
+ fn=lambda *args: random.randint(1, int(1e8)),
401
+ outputs=[seed_textbox],
402
+ queue=False
403
+ )
404
+
405
+ generate_button = gr.Button(
406
+ value="Generate", variant='primary')
407
+
408
+ result_video = gr.Video(
409
+ label="Generated Animation", interactive=False)
410
+
411
+ style_dropdown.change(fn=controller.fetch_default_n_prompt,
412
+ inputs=[style_dropdown],
413
+ outputs=[negative_prompt_textbox, ip_adapter_scale], queue=False)
414
+
415
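+ # wire the Generate button to AnimateController.animate with every UI control as input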
+ generate_button.click(
416
+ fn=controller.animate,
417
+ inputs=[
418
+ init_img,
419
+ motion_scale_silder,
420
+ prompt_textbox,
421
+ negative_prompt_textbox,
422
+ sampler_dropdown,
423
+ sample_step_slider,
424
+ cfg_scale_slider,
425
+ seed_textbox,
426
+ ip_adapter_scale,
427
+ style_dropdown,
428
+ ],
429
+ outputs=[result_video]
430
+ )
431
+
432
+ return demo
433
+
434
+
435
+ if __name__ == "__main__":
436
+ demo = ui()
437
+ demo.queue(max_size=10)
438
+ demo.launch(server_name=args.server_name,
439
+ server_port=args.port, share=args.share,
440
+ max_threads=10,
441
+ allowed_paths=['pia.png'])
app-huggingface.py ADDED
@@ -0,0 +1,525 @@
1
+ import json
2
+ import os
3
+ import os.path as osp
4
+ import random
5
+ from argparse import ArgumentParser
6
+ from datetime import datetime
7
+
8
+ import gradio as gr
9
+ import numpy as np
10
+ import torch
11
+ from huggingface_hub import hf_hub_download
12
+ from omegaconf import OmegaConf
13
+ from PIL import Image
14
+
15
+ from animatediff.pipelines import I2VPipeline
16
+ from animatediff.utils.util import RANGE_LIST, save_videos_grid
17
+
18
+ sample_idx = 0
19
+
20
+ css = """
21
+ .toolbutton {
22
+ margin-bottom: 0em;
23
+ max-width: 2.5em;
24
+ min-width: 2.5em !important;
25
+ height: 2.5em;
26
+ }
27
+ """
28
+
29
+ parser = ArgumentParser()
30
+ parser.add_argument('--config', type=str, default='example/config/base.yaml')
31
+ parser.add_argument('--server-name', type=str, default='0.0.0.0')
32
+ parser.add_argument('--port', type=int, default=7860)
33
+ parser.add_argument('--share', action='store_true')
34
+ parser.add_argument('--local-debug', action='store_true')
35
+
36
+ parser.add_argument('--save-path', default='samples')
37
+
38
+ args = parser.parse_args()
39
+ LOCAL_DEBUG = args.local_debug
40
+
41
+
42
+ BASE_CONFIG = 'example/config/base.yaml'
43
+ STYLE_CONFIG_LIST = {
44
+ '3d_cartoon': './example/openxlab/3-3d.yaml',
45
+ 'realistic': './example/openxlab/1-realistic.yaml',
46
+ }
47
+
48
+
49
+ # download models
50
+ PIA_PATH = './models/PIA'
51
+ VAE_PATH = './models/VAE'
52
+ DreamBooth_LoRA_PATH = './models/DreamBooth_LoRA'
53
+
54
+
55
+ if not LOCAL_DEBUG:
56
+ CACHE_PATH = './models'
57
+
58
+ PIA_PATH = osp.join(CACHE_PATH, 'PIA')
59
+ VAE_PATH = osp.join(CACHE_PATH, 'VAE')
60
+ DreamBooth_LoRA_PATH = osp.join(CACHE_PATH, 'DreamBooth_LoRA')
61
+ STABLE_DIFFUSION_PATH = osp.join(CACHE_PATH, 'StableDiffusion')
62
+
63
+ os.makedirs(PIA_PATH, exist_ok=True)
64
+ os.makedirs(VAE_PATH, exist_ok=True)
65
+ os.makedirs(DreamBooth_LoRA_PATH, exist_ok=True)
66
+ os.makedirs(STABLE_DIFFUSION_PATH, exist_ok=True)
67
+
68
+ # `hf_hub_download` takes `repo_id`/`filename`/`local_dir`; the checkpoint name
+ # below is an assumption inferred from the `pia.ckpt` path used when building pipelines.
+ hf_hub_download(repo_id='leoxing/PIA-pruned',
69
+ filename='pia.ckpt', local_dir=PIA_PATH)
70
+ os.system('bash download_bashscripts/1-RealisticVision.sh')
71
+ os.system('bash download_bashscripts/2-RcnzCartoon.sh')
72
+ print(os.listdir(DreamBooth_LoRA_PATH))
73
+
74
+ # unet
75
+ unet_full_path = hf_hub_download(repo_id='runwayml/stable-diffusion-v1-5',
76
+ subfolder='unet', filename='diffusion_pytorch_model.bin',
77
+ cache_dir='models/StableDiffusion')
78
+ STABLE_DIFFUSION_PATH = '/'.join(unet_full_path.split('/')[:-2])
79
+ hf_hub_download(repo_id='runwayml/stable-diffusion-v1-5',
80
+ subfolder='unet', filename='config.json',
81
+ cache_dir='models/StableDiffusion')
82
+
83
+ # vae
84
+ hf_hub_download(repo_id='runwayml/stable-diffusion-v1-5',
85
+ subfolder='vae', filename='config.json',
86
+ cache_dir='models/StableDiffusion')
87
+ hf_hub_download(repo_id='runwayml/stable-diffusion-v1-5',
88
+ subfolder='vae', filename='diffusion_pytorch_model.bin',
89
+ cache_dir='models/StableDiffusion')
90
+
91
+ # text encoder
92
+ hf_hub_download(repo_id='runwayml/stable-diffusion-v1-5',
93
+ subfolder='text_encoder', filename='config.json',
94
+ cache_dir='models/StableDiffusion')
95
+ hf_hub_download(repo_id='runwayml/stable-diffusion-v1-5',
96
+ subfolder='text_encoder', filename='pytorch_model.bin',
97
+ cache_dir='models/StableDiffusion')
98
+
99
+ # tokenizer
100
+ hf_hub_download(repo_id='runwayml/stable-diffusion-v1-5',
101
+ subfolder='tokenizer', filename='merges.txt',
102
+ cache_dir='models/StableDiffusion')
103
+ hf_hub_download(repo_id='runwayml/stable-diffusion-v1-5',
104
+ subfolder='tokenizer', filename='special_tokens_map.json',
105
+ cache_dir='models/StableDiffusion')
106
+ hf_hub_download(repo_id='runwayml/stable-diffusion-v1-5',
107
+ subfolder='tokenizer', filename='tokenizer_config.json',
108
+ cache_dir='models/StableDiffusion')
109
+ hf_hub_download(repo_id='runwayml/stable-diffusion-v1-5',
110
+ subfolder='tokenizer', filename='vocab.json',
111
+ cache_dir='models/StableDiffusion')
112
+
113
+ # scheduler
114
+ hf_hub_download(repo_id='runwayml/stable-diffusion-v1-5',
115
+ subfolder='scheduler', filename='scheduler_config.json',
116
+ cache_dir='models/StableDiffusion')
117
+
118
+ # model index
119
+ hf_hub_download(repo_id='runwayml/stable-diffusion-v1-5', filename='model_index.json',
120
+ cache_dir='models/StableDiffusion')
121
+
122
+ else:
123
+ PIA_PATH = './models/PIA'
124
+ VAE_PATH = './models/VAE'
125
+ DreamBooth_LoRA_PATH = './models/DreamBooth_LoRA'
126
+ STABLE_DIFFUSION_PATH = './models/StableDiffusion/sd15'
127
+
128
+
129
+ def preprocess_img(img_np, max_size: int = 512):
130
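+ """Resize so the longer side is at most `max_size` and both sides are multiples of 8 (the resolution step the UNet/VAE expect)."""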
+
131
+ ori_image = Image.fromarray(img_np).convert('RGB')
132
+
133
+ width, height = ori_image.size
134
+
135
+ long_edge = max(width, height)
136
+ if long_edge > max_size:
137
+ scale_factor = max_size / long_edge
138
+ else:
139
+ scale_factor = 1
140
+ width = int(width * scale_factor)
141
+ height = int(height * scale_factor)
142
+ ori_image = ori_image.resize((width, height))
143
+
144
+ if (width % 8 != 0) or (height % 8 != 0):
145
+ in_width = (width // 8) * 8
146
+ in_height = (height // 8) * 8
147
+ else:
148
+ in_width = width
149
+ in_height = height
150
+ in_image = ori_image
151
+
152
+ in_image = ori_image.resize((in_width, in_height))
153
+ in_image_np = np.array(in_image)
154
+ return in_image_np, in_height, in_width
155
+
156
+
157
+ class AnimateController:
158
+ def __init__(self):
159
+
160
+ # config dirs
161
+ self.basedir = os.getcwd()
162
+ self.savedir = os.path.join(
163
+ self.basedir, args.save_path, datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
164
+ self.savedir_sample = os.path.join(self.savedir, "sample")
165
+ os.makedirs(self.savedir, exist_ok=True)
166
+
167
+ self.inference_config = OmegaConf.load(args.config)
168
+ self.style_configs = {k: OmegaConf.load(
169
+ v) for k, v in STYLE_CONFIG_LIST.items()}
170
+
171
+ self.pipeline_dict = self.load_model_list()
172
+
173
+ def load_model_list(self):
174
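+ # build one I2VPipeline per style config, resolving optional DreamBooth, LoRA and VAE weights under the local model folders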
+ pipeline_dict = dict()
175
+ for style, cfg in self.style_configs.items():
176
+ dreambooth_path = cfg.get('dreambooth', 'none')
177
+ if dreambooth_path and dreambooth_path.upper() != 'NONE':
178
+ dreambooth_path = osp.join(
179
+ DreamBooth_LoRA_PATH, dreambooth_path)
180
+ lora_path = cfg.get('lora', None)
181
+ if lora_path is not None:
182
+ lora_path = osp.join(DreamBooth_LoRA_PATH, lora_path)
183
+ lora_alpha = cfg.get('lora_alpha', 0.0)
184
+ vae_path = cfg.get('vae', None)
185
+ if vae_path is not None:
186
+ vae_path = osp.join(VAE_PATH, vae_path)
187
+
188
+ pipeline_dict[style] = I2VPipeline.build_pipeline(
189
+ self.inference_config,
190
+ STABLE_DIFFUSION_PATH,
191
+ unet_path=osp.join(PIA_PATH, 'pia.ckpt'),
192
+ dreambooth_path=dreambooth_path,
193
+ lora_path=lora_path,
194
+ lora_alpha=lora_alpha,
195
+ vae_path=vae_path,
196
+ ip_adapter_path='h94/IP-Adapter',
197
+ ip_adapter_scale=0.1)
198
+ return pipeline_dict
199
+
200
+ def fetch_default_n_prompt(self, style: str):
201
+ cfg = self.style_configs[style]
202
+ n_prompt = cfg.get('n_prompt', '')
203
+ ip_adapter_scale = cfg.get('real_ip_adapter_scale', 0)
204
+
205
+ gr.Info('Set default negative prompt and ip_adapter_scale.')
206
+ print('Set default negative prompt and ip_adapter_scale.')
207
+
208
+ return n_prompt, ip_adapter_scale
209
+
210
+ def animate(
211
+ self,
212
+ init_img,
213
+ motion_scale,
214
+ prompt_textbox,
215
+ negative_prompt_textbox,
216
+ sample_step_slider,
217
+ cfg_scale_slider,
218
+ seed_textbox,
219
+ ip_adapter_scale,
220
+ style,
221
+ progress=gr.Progress(),
222
+ ):
223
+
224
+ if seed_textbox != -1 and seed_textbox != "":
225
+ torch.manual_seed(int(seed_textbox))
226
+ else:
227
+ torch.seed()
228
+ seed = torch.initial_seed()
229
+
230
+ pipeline = self.pipeline_dict[style]
231
+ init_img, h, w = preprocess_img(init_img)
232
+ sample = pipeline(
233
+ image=init_img,
234
+ prompt=prompt_textbox,
235
+ negative_prompt=negative_prompt_textbox,
236
+ num_inference_steps=sample_step_slider,
237
+ guidance_scale=cfg_scale_slider,
238
+ width=w,
239
+ height=h,
240
+ video_length=16,
241
+ mask_sim_template_idx=motion_scale - 1,
242
+ ip_adapter_scale=ip_adapter_scale,
243
+ progress_fn=progress,
244
+ ).videos
245
+
246
+ save_sample_path = os.path.join(
247
+ self.savedir_sample, f"{sample_idx}.mp4")
248
+ save_videos_grid(sample, save_sample_path)
249
+
250
+ sample_config = {
251
+ "prompt": prompt_textbox,
252
+ "n_prompt": negative_prompt_textbox,
253
+ "num_inference_steps": sample_step_slider,
254
+ "guidance_scale": cfg_scale_slider,
255
+ "width": w,
256
+ "height": h,
257
+ "seed": seed,
258
+ "motion": motion_scale,
259
+ }
260
+ json_str = json.dumps(sample_config, indent=4)
261
+ with open(os.path.join(self.savedir, "logs.json"), "a") as f:
262
+ f.write(json_str)
263
+ f.write("\n\n")
264
+
265
+ return save_sample_path
266
+
267
+
268
+ controller = AnimateController()
269
+
270
+
271
+ def ui():
272
+ with gr.Blocks(css=css) as demo:
273
+
274
+ gr.HTML(
275
+ "<div align='center'><font size='7'> <img src=\"file/pia.png\" style=\"height: 72px;\"/ > Your Personalized Image Animator</font></div>"
276
+ "<div align='center'><font size='7'>via Plug-and-Play Modules in Text-to-Image Models </font></div>"
277
+ )
278
+ with gr.Row():
279
+ gr.Markdown(
280
+ "<div align='center'><font size='5'><a href='https://pi-animator.github.io/'>Project Page</a> &ensp;" # noqa
281
+ "<a href='https://arxiv.org/abs/2312.13964/'>Paper</a> &ensp;"
282
+ "<a href='https://github.com/open-mmlab/PIA'>Code</a> &ensp;" # noqa
283
+ "Try More Style: <a href='https://openxlab.org.cn/apps/detail/zhangyiming/PiaPia-AnimationStyle'>Click here! </a></font></div>" # noqa
284
+ )
285
+
286
+ with gr.Row(equal_height=False):
287
+ with gr.Column():
288
+ with gr.Row():
289
+ init_img = gr.Image(label='Input Image')
290
+
291
+ style_dropdown = gr.Dropdown(label='Style', choices=list(
292
+ STYLE_CONFIG_LIST.keys()), value=list(STYLE_CONFIG_LIST.keys())[0])
293
+
294
+ with gr.Row():
295
+ prompt_textbox = gr.Textbox(label="Prompt", lines=1)
296
+ gift_button = gr.Button(
297
+ value='🎁', elem_classes='toolbutton'
298
+ )
299
+
300
+ def append_gift(prompt):
301
+ rand = random.randint(0, 2)
302
+ if rand == 1:
303
+ prompt = prompt + 'wearing santa hats'
304
+ elif rand == 2:
305
+ prompt = prompt + 'lift a Christmas gift'
306
+ else:
307
+ prompt = prompt + 'in Christmas suit, lift a Christmas gift'
308
+ gr.Info('Merry Christmas! Add magic to your prompt!')
309
+ return prompt
310
+
311
+ gift_button.click(
312
+ fn=append_gift,
313
+ inputs=[prompt_textbox],
314
+ outputs=[prompt_textbox],
315
+ )
316
+
317
+ motion_scale_silder = gr.Slider(
318
+ label='Motion Scale (Larger value means larger motion but less identity consistency)',
319
+ value=1, step=1, minimum=1, maximum=len(RANGE_LIST))
320
+ ip_adapter_scale = gr.Slider(
321
+ label='IP-Adapter Scale', value=controller.fetch_default_n_prompt(
322
+ list(STYLE_CONFIG_LIST.keys())[0])[1], minimum=0, maximum=1)
323
+
324
+ with gr.Accordion('Advanced Options', open=False):
325
+ negative_prompt_textbox = gr.Textbox(
326
+ value=controller.fetch_default_n_prompt(
327
+ list(STYLE_CONFIG_LIST.keys())[0])[0],
328
+ label="Negative prompt", lines=2)
329
+
330
+ sample_step_slider = gr.Slider(
331
+ label="Sampling steps", value=20, minimum=10, maximum=100, step=1)
332
+
333
+ cfg_scale_slider = gr.Slider(
334
+ label="CFG Scale", value=7.5, minimum=0, maximum=20)
335
+
336
+ with gr.Row():
337
+ seed_textbox = gr.Textbox(label="Seed", value=-1)
338
+ seed_button = gr.Button(
339
+ value="\U0001F3B2", elem_classes="toolbutton")
340
+ seed_button.click(
341
+ fn=lambda *args: random.randint(1, int(1e8)),
342
+ outputs=[seed_textbox],
343
+ queue=False
344
+ )
345
+
346
+ generate_button = gr.Button(
347
+ value="Generate", variant='primary')
348
+
349
+ result_video = gr.Video(
350
+ label="Generated Animation", interactive=False)
351
+
352
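+ # switching style reloads that style's default negative prompt and IP-Adapter scale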
+ style_dropdown.change(fn=controller.fetch_default_n_prompt,
353
+ inputs=[style_dropdown],
354
+ outputs=[negative_prompt_textbox,
355
+ ip_adapter_scale],
356
+ queue=False)
357
+
358
+ generate_button.click(
359
+ fn=controller.animate,
360
+ inputs=[
361
+ init_img,
362
+ motion_scale_silder,
363
+ prompt_textbox,
364
+ negative_prompt_textbox,
365
+ sample_step_slider,
366
+ cfg_scale_slider,
367
+ seed_textbox,
368
+ ip_adapter_scale,
369
+ style_dropdown,
370
+ ],
371
+ outputs=[result_video]
372
+ )
373
+
374
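+ # helper that registers a gr.Examples gallery whose rows fill the image, video, prompt and motion widgets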
+ def create_example(input_list):
375
+ return gr.Examples(
376
+ examples=input_list,
377
+ inputs=[
378
+ init_img,
379
+ result_video,
380
+ prompt_textbox,
381
+ negative_prompt_textbox,
382
+ style_dropdown,
383
+ motion_scale_silder,
384
+ ],
385
+ )
386
+
387
+ gr.Markdown(
388
+ '### Merry Christmas!'
389
+ )
390
+ create_example(
391
+ [
392
+ [
393
+ '__assets__/image_animation/yiming/yiming.jpeg',
394
+ '__assets__/image_animation/yiming/yiming.mp4',
395
+ '1boy in Christmas suit, lift a Christmas gift',
396
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
397
+ '3d_cartoon',
398
+ 2,
399
+ ],
400
+ [
401
+ '__assets__/image_animation/yanhong/yanhong.png',
402
+ '__assets__/image_animation/yanhong/yanhong.mp4',
403
+ '1girl lift a Christmas gift',
404
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
405
+ '3d_cartoon',
406
+ 2,
407
+ ],
408
+ ],
409
+
410
+ )
411
+
412
+ with gr.Accordion('More Examples for Style Transfer', open=False):
413
+ create_example([
414
+ [
415
+
416
+ '__assets__/image_animation/style_transfer/anya/anya.jpg',
417
+ '__assets__/image_animation/style_transfer/anya/2.mp4',
418
+ '1girl open mouth ',
419
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
420
+ '3d_cartoon',
421
+ 3,
422
+ ],
423
+ [
424
+ '__assets__/image_animation/magnitude/genshin/genshin.jpg',
425
+ '__assets__/image_animation/magnitude/genshin/3.mp4',
426
+ 'cherry blossoms in the wind, raidenshogundef, yaemikodef, best quality, 4k',
427
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
428
+ '3d_cartoon',
429
+ 3,
430
+ ],
431
+
432
+ ])
433
+
434
+ with gr.Accordion('More Examples for Prompt Changing', open=False):
435
+ create_example(
436
+ [
437
+ [
438
+ '__assets__/image_animation/real/lighthouse.jpg',
439
+ '__assets__/image_animation/real/1.mp4',
440
+ 'lightning, lighthouse',
441
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
442
+ 'realistic',
443
+ 1,
444
+ ],
445
+ [
446
+ '__assets__/image_animation/real/lighthouse.jpg',
447
+ '__assets__/image_animation/real/2.mp4',
448
+ 'sun rising, lighthouse',
449
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
450
+ 'realistic',
451
+ 1,
452
+ ],
453
+ [
454
+ '__assets__/image_animation/real/lighthouse.jpg',
455
+ '__assets__/image_animation/real/3.mp4',
456
+ 'fireworks, lighthouse',
457
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
458
+ 'realistic',
459
+ 1,
460
+ ],
461
+ [
462
+ '__assets__/image_animation/rcnz/harry.png',
463
+ '__assets__/image_animation/rcnz/1.mp4',
464
+ '1boy smiling',
465
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
466
+ '3d_cartoon',
467
+ 2
468
+ ],
469
+ [
470
+ '__assets__/image_animation/rcnz/harry.png',
471
+ '__assets__/image_animation/rcnz/2.mp4',
472
+ '1boy playing magic fire',
473
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
474
+ '3d_cartoon',
475
+ 2
476
+ ],
477
+ [
478
+ '__assets__/image_animation/rcnz/harry.png',
479
+ '__assets__/image_animation/rcnz/3.mp4',
480
+ '1boy is waving hands',
481
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
482
+ '3d_cartoon',
483
+ 2
484
+ ]
485
+ ])
486
+
487
+ with gr.Accordion('Examples for Motion Magnitude', open=False):
488
+ create_example(
489
+ [
490
+ [
491
+ '__assets__/image_animation/magnitude/labrador.png',
492
+ '__assets__/image_animation/magnitude/1.mp4',
493
+ 'cherry blossoms in the wind, raidenshogundef, yaemikodef, best quality, 4k',
494
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
495
+ '3d_cartoon',
496
+ 1,
497
+ ],
498
+ [
499
+ '__assets__/image_animation/magnitude/labrador.png',
500
+ '__assets__/image_animation/magnitude/2.mp4',
501
+ 'cherry blossoms in the wind, raidenshogundef, yaemikodef, best quality, 4k',
502
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
503
+ '3d_cartoon',
504
+ 2,
505
+ ],
506
+ [
507
+ '__assets__/image_animation/magnitude/labrador.png',
508
+ '__assets__/image_animation/magnitude/3.mp4',
509
+ 'cherry blossoms in the wind, raidenshogundef, yaemikodef, best quality, 4k',
510
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
511
+ '3d_cartoon',
512
+ 3,
513
+ ]
514
+ ])
515
+
516
+ return demo
517
+
518
+
519
+ if __name__ == "__main__":
520
+ demo = ui()
521
+ demo.queue(max_size=10)
522
+ demo.launch(server_name=args.server_name,
523
+ server_port=args.port, share=args.share,
524
+ max_threads=40,
525
+ allowed_paths=['pia.png'])
app.py ADDED
@@ -0,0 +1,567 @@
1
+ import json
2
+ import os
3
+ import os.path as osp
4
+ import random
5
+ from argparse import ArgumentParser
6
+ from datetime import datetime
7
+
8
+ import gradio as gr
9
+ import numpy as np
10
+ import openxlab
11
+ import torch
12
+ from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
13
+ from omegaconf import OmegaConf
14
+ from openxlab.model import download
15
+ from PIL import Image
16
+
17
+ from animatediff.pipelines import I2VPipeline
18
+ from animatediff.utils.util import RANGE_LIST, save_videos_grid
19
+
20
+ sample_idx = 0
21
+ # scheduler_dict = {
22
+ # "DDIM": DDIMScheduler,
23
+ # "Euler": EulerDiscreteScheduler,
24
+ # "PNDM": PNDMScheduler,
25
+ # }
26
+
27
+ css = """
28
+ .toolbutton {
29
+ margin-bottom: 0em;
30
+ max-width: 2.5em;
31
+ min-width: 2.5em !important;
32
+ height: 2.5em;
33
+ }
34
+ """
35
+
36
+ parser = ArgumentParser()
37
+ parser.add_argument('--config', type=str, default='example/config/base.yaml')
38
+ parser.add_argument('--server-name', type=str, default='0.0.0.0')
39
+ parser.add_argument('--port', type=int, default=7860)
40
+ parser.add_argument('--share', action='store_true')
41
+ parser.add_argument('--local-debug', action='store_true')
42
+
43
+ parser.add_argument('--save-path', default='samples')
44
+
45
+ args = parser.parse_args()
46
+ LOCAL_DEBUG = args.local_debug
47
+
48
+
49
+ BASE_CONFIG = 'example/config/base.yaml'
50
+ STYLE_CONFIG_LIST = {
51
+ '3d_cartoon': './example/openxlab/3-3d.yaml',
52
+ 'realistic': './example/openxlab/1-realistic.yaml',
53
+ }
54
+
55
+
56
+ # download models
57
+ PIA_PATH = './models/PIA'
58
+ VAE_PATH = './models/VAE'
59
+ DreamBooth_LoRA_PATH = './models/DreamBooth_LoRA'
60
+
61
+
62
+ if not LOCAL_DEBUG:
63
+ CACHE_PATH = '/home/xlab-app-center/.cache/model'
64
+
65
+ PIA_PATH = osp.join(CACHE_PATH, 'PIA')
66
+ VAE_PATH = osp.join(CACHE_PATH, 'VAE')
67
+ DreamBooth_LoRA_PATH = osp.join(CACHE_PATH, 'DreamBooth_LoRA')
68
+ STABLE_DIFFUSION_PATH = osp.join(CACHE_PATH, 'StableDiffusion')
69
+
70
+ os.makedirs(PIA_PATH, exist_ok=True)
71
+ os.makedirs(VAE_PATH, exist_ok=True)
72
+ os.makedirs(DreamBooth_LoRA_PATH, exist_ok=True)
73
+ os.makedirs(STABLE_DIFFUSION_PATH, exist_ok=True)
74
+
75
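+ # log in to OpenXLab with credentials from the environment, then fetch the PIA, DreamBooth and SD 1.5 component weights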
+ openxlab.login(os.environ['OPENXLAB_AK'], os.environ['OPENXLAB_SK'])
76
+ download(model_repo='zhangyiming/PIA-pruned', model_name='PIA', output=PIA_PATH)
77
+ download(model_repo='zhangyiming/RCNZ_Cartoon_3d',
78
+ model_name='rcnz-cartoon-3d', output=DreamBooth_LoRA_PATH)
79
+ download(model_repo='zhangyiming/realisticVisionV51_v51VAE',
80
+ model_name='realisticVisionV51_v51VAE', output=DreamBooth_LoRA_PATH)
81
+ print(os.listdir(DreamBooth_LoRA_PATH))
82
+ # unet
83
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Unet',
84
+ model_name='unet', output=osp.join(STABLE_DIFFUSION_PATH, 'unet'))
85
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Unet',
86
+ model_name='config', output=osp.join(STABLE_DIFFUSION_PATH, 'unet'))
87
+
88
+ # vae
89
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_VAE',
90
+ model_name='vae', output=osp.join(STABLE_DIFFUSION_PATH, 'vae'))
91
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_VAE',
92
+ model_name='config', output=osp.join(STABLE_DIFFUSION_PATH, 'vae'))
93
+
94
+ # text encoder
95
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_TextEncod',
96
+ model_name='text_encoder', output=osp.join(STABLE_DIFFUSION_PATH, 'text_encoder'))
97
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_TextEncod',
98
+ model_name='config', output=osp.join(STABLE_DIFFUSION_PATH, 'text_encoder'))
99
+
100
+ # tokenizer
101
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Tokenizer',
102
+ model_name='merge', output=osp.join(STABLE_DIFFUSION_PATH, 'tokenizer'))
103
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Tokenizer',
104
+ model_name='special_tokens_map', output=osp.join(STABLE_DIFFUSION_PATH, 'tokenizer'))
105
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Tokenizer',
106
+ model_name='tokenizer_config', output=osp.join(STABLE_DIFFUSION_PATH, 'tokenizer'))
107
+ download(model_repo='zhangyiming/runwayml_stable-diffusion-v1-5_Tokenizer',
108
+ model_name='vocab', output=osp.join(STABLE_DIFFUSION_PATH, 'tokenizer'))
109
+
110
+ # scheduler
111
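+ # the scheduler config and model_index.json are not downloaded, so minimal SD 1.5-style configs are written locally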
+ scheduler_dict = {
112
+ "_class_name": "PNDMScheduler",
113
+ "_diffusers_version": "0.6.0",
114
+ "beta_end": 0.012,
115
+ "beta_schedule": "scaled_linear",
116
+ "beta_start": 0.00085,
117
+ "num_train_timesteps": 1000,
118
+ "set_alpha_to_one": False,
119
+ "skip_prk_steps": True,
120
+ "steps_offset": 1,
121
+ "trained_betas": None,
122
+ "clip_sample": False
123
+ }
124
+ os.makedirs(osp.join(STABLE_DIFFUSION_PATH, 'scheduler'), exist_ok=True)
125
+ with open(osp.join(STABLE_DIFFUSION_PATH, 'scheduler', 'scheduler_config.json'), 'w') as file:
126
+ json.dump(scheduler_dict, file)
127
+
128
+ # model index
129
+ model_index_dict = {
130
+ "_class_name": "StableDiffusionPipeline",
131
+ "_diffusers_version": "0.6.0",
132
+ "feature_extractor": [
133
+ "transformers",
134
+ "CLIPImageProcessor"
135
+ ],
136
+ "safety_checker": [
137
+ "stable_diffusion",
138
+ "StableDiffusionSafetyChecker"
139
+ ],
140
+ "scheduler": [
141
+ "diffusers",
142
+ "PNDMScheduler"
143
+ ],
144
+ "text_encoder": [
145
+ "transformers",
146
+ "CLIPTextModel"
147
+ ],
148
+ "tokenizer": [
149
+ "transformers",
150
+ "CLIPTokenizer"
151
+ ],
152
+ "unet": [
153
+ "diffusers",
154
+ "UNet2DConditionModel"
155
+ ],
156
+ "vae": [
157
+ "diffusers",
158
+ "AutoencoderKL"
159
+ ]
160
+ }
161
+ with open(osp.join(STABLE_DIFFUSION_PATH, 'model_index.json'), 'w') as file:
162
+ json.dump(model_index_dict, file)
163
+
164
+ else:
165
+ PIA_PATH = './models/PIA'
166
+ VAE_PATH = './models/VAE'
167
+ DreamBooth_LoRA_PATH = './models/DreamBooth_LoRA'
168
+ STABLE_DIFFUSION_PATH = './models/StableDiffusion/sd15'
169
+
170
+
171
+ def preprocess_img(img_np, max_size: int = 512):
172
+
173
+ ori_image = Image.fromarray(img_np).convert('RGB')
174
+
175
+ width, height = ori_image.size
176
+
177
+ long_edge = max(width, height)
178
+ if long_edge > max_size:
179
+ scale_factor = max_size / long_edge
180
+ else:
181
+ scale_factor = 1
182
+ width = int(width * scale_factor)
183
+ height = int(height * scale_factor)
184
+ ori_image = ori_image.resize((width, height))
185
+
186
+ if (width % 8 != 0) or (height % 8 != 0):
187
+ in_width = (width // 8) * 8
188
+ in_height = (height // 8) * 8
189
+ else:
190
+ in_width = width
191
+ in_height = height
192
+ in_image = ori_image
193
+
194
+ in_image = ori_image.resize((in_width, in_height))
195
+ in_image_np = np.array(in_image)
196
+ return in_image_np, in_height, in_width
197
+
198
+
199
+ class AnimateController:
200
+ def __init__(self):
201
+
202
+ # config dirs
203
+ self.basedir = os.getcwd()
204
+ self.savedir = os.path.join(
205
+ self.basedir, args.save_path, datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
206
+ self.savedir_sample = os.path.join(self.savedir, "sample")
207
+ os.makedirs(self.savedir, exist_ok=True)
208
+
209
+ self.inference_config = OmegaConf.load(args.config)
210
+ self.style_configs = {k: OmegaConf.load(
211
+ v) for k, v in STYLE_CONFIG_LIST.items()}
212
+
213
+ self.pipeline_dict = self.load_model_list()
214
+
215
+ def load_model_list(self):
216
+ pipeline_dict = dict()
217
+ for style, cfg in self.style_configs.items():
218
+ dreambooth_path = cfg.get('dreambooth', 'none')
219
+ if dreambooth_path and dreambooth_path.upper() != 'NONE':
220
+ dreambooth_path = osp.join(
221
+ DreamBooth_LoRA_PATH, dreambooth_path)
222
+ lora_path = cfg.get('lora', None)
223
+ if lora_path is not None:
224
+ lora_path = osp.join(DreamBooth_LoRA_PATH, lora_path)
225
+ lora_alpha = cfg.get('lora_alpha', 0.0)
226
+ vae_path = cfg.get('vae', None)
227
+ if vae_path is not None:
228
+ vae_path = osp.join(VAE_PATH, vae_path)
229
+
230
+ pipeline_dict[style] = I2VPipeline.build_pipeline(
231
+ self.inference_config,
232
+ STABLE_DIFFUSION_PATH,
233
+ unet_path=osp.join(PIA_PATH, 'pia.ckpt'),
234
+ dreambooth_path=dreambooth_path,
235
+ lora_path=lora_path,
236
+ lora_alpha=lora_alpha,
237
+ vae_path=vae_path,
238
+ ip_adapter_path='h94/IP-Adapter',
239
+ ip_adapter_scale=0.1)
240
+ return pipeline_dict
241
+
242
+ def fetch_default_n_prompt(self, style: str):
243
+ cfg = self.style_configs[style]
244
+ n_prompt = cfg.get('n_prompt', '')
245
+ ip_adapter_scale = cfg.get('real_ip_adapter_scale', 0)
246
+
247
+ gr.Info('Set default negative prompt and ip_adapter_scale.')
248
+ print('Set default negative prompt and ip_adapter_scale.')
249
+
250
+ return n_prompt, ip_adapter_scale
251
+
252
+ def animate(
253
+ self,
254
+ init_img,
255
+ motion_scale,
256
+ prompt_textbox,
257
+ negative_prompt_textbox,
258
+ sample_step_slider,
259
+ cfg_scale_slider,
260
+ seed_textbox,
261
+ ip_adapter_scale,
262
+ style,
263
+ progress=gr.Progress(),
264
+ ):
265
+
266
+ if seed_textbox != -1 and seed_textbox != "":
267
+ torch.manual_seed(int(seed_textbox))
268
+ else:
269
+ torch.seed()
270
+ seed = torch.initial_seed()
271
+
272
+ pipeline = self.pipeline_dict[style]
273
+ init_img, h, w = preprocess_img(init_img)
274
+ sample = pipeline(
275
+ image=init_img,
276
+ prompt=prompt_textbox,
277
+ negative_prompt=negative_prompt_textbox,
278
+ num_inference_steps=sample_step_slider,
279
+ guidance_scale=cfg_scale_slider,
280
+ width=w,
281
+ height=h,
282
+ video_length=16,
283
+ mask_sim_template_idx=motion_scale - 1,
284
+ ip_adapter_scale=ip_adapter_scale,
285
+ progress_fn=progress,
286
+ ).videos
287
+
288
+ save_sample_path = os.path.join(
289
+ self.savedir_sample, f"{sample_idx}.mp4")
290
+ save_videos_grid(sample, save_sample_path)
291
+
292
+ sample_config = {
293
+ "prompt": prompt_textbox,
294
+ "n_prompt": negative_prompt_textbox,
295
+ "num_inference_steps": sample_step_slider,
296
+ "guidance_scale": cfg_scale_slider,
297
+ "width": w,
298
+ "height": h,
299
+ "seed": seed,
300
+ "motion": motion_scale,
301
+ }
302
+ json_str = json.dumps(sample_config, indent=4)
303
+ with open(os.path.join(self.savedir, "logs.json"), "a") as f:
304
+ f.write(json_str)
305
+ f.write("\n\n")
306
+
307
+ return save_sample_path
308
+
309
+
310
+ controller = AnimateController()
311
+
312
+
313
+ def ui():
314
+ with gr.Blocks(css=css) as demo:
315
+
316
+ gr.HTML(
317
+ "<div align='center'><font size='7'> <img src=\"file/pia.png\" style=\"height: 72px;\"/ > Your Personalized Image Animator</font></div>"
318
+ "<div align='center'><font size='7'>via Plug-and-Play Modules in Text-to-Image Models </font></div>"
319
+ )
320
+ with gr.Row():
321
+ gr.Markdown(
322
+ "<div align='center'><font size='5'><a href='https://pi-animator.github.io/'>Project Page</a> &ensp;" # noqa
323
+ "<a href='https://arxiv.org/abs/2312.13964/'>Paper</a> &ensp;"
324
+ "<a href='https://github.com/open-mmlab/PIA'>Code</a> &ensp;" # noqa
325
+ "Try More Style: <a href='https://openxlab.org.cn/apps/detail/zhangyiming/PiaPia-AnimationStyle'>Click here! </a></font></div>" # noqa
326
+ )
327
+
328
+ with gr.Row(equal_height=False):
329
+ with gr.Column():
330
+ with gr.Row():
331
+ init_img = gr.Image(label='Input Image')
332
+
333
+ style_dropdown = gr.Dropdown(label='Style', choices=list(
334
+ STYLE_CONFIG_LIST.keys()), value=list(STYLE_CONFIG_LIST.keys())[0])
335
+
336
+ with gr.Row():
337
+ prompt_textbox = gr.Textbox(label="Prompt", lines=1)
338
+ gift_button = gr.Button(
339
+ value='🎁', elem_classes='toolbutton'
340
+ )
341
+
342
+ def append_gift(prompt):
343
+ rand = random.randint(0, 2)
344
+ if rand == 1:
345
+ prompt = prompt + 'wearing santa hats'
346
+ elif rand == 2:
347
+ prompt = prompt + 'lift a Christmas gift'
348
+ else:
349
+ prompt = prompt + 'in Christmas suit, lift a Christmas gift'
350
+ gr.Info('Merry Christmas! Add magic to your prompt!')
351
+ return prompt
352
+
353
+ gift_button.click(
354
+ fn=append_gift,
355
+ inputs=[prompt_textbox],
356
+ outputs=[prompt_textbox],
357
+ )
358
+
359
+ motion_scale_silder = gr.Slider(
360
+ label='Motion Scale (Larger value means larger motion but less identity consistency)',
361
+ value=1, step=1, minimum=1, maximum=len(RANGE_LIST))
362
+ ip_adapter_scale = gr.Slider(
363
+ label='IP-Adapter Scale', value=controller.fetch_default_n_prompt(
364
+ list(STYLE_CONFIG_LIST.keys())[0])[1], minimum=0, maximum=1)
365
+
366
+ with gr.Accordion('Advanced Options', open=False):
367
+ negative_prompt_textbox = gr.Textbox(
368
+ value=controller.fetch_default_n_prompt(
369
+ list(STYLE_CONFIG_LIST.keys())[0])[0],
370
+ label="Negative prompt", lines=2)
371
+
372
+ sample_step_slider = gr.Slider(
373
+ label="Sampling steps", value=20, minimum=10, maximum=100, step=1)
374
+
375
+ cfg_scale_slider = gr.Slider(
376
+ label="CFG Scale", value=7.5, minimum=0, maximum=20)
377
+
378
+ with gr.Row():
379
+ seed_textbox = gr.Textbox(label="Seed", value=-1)
380
+ seed_button = gr.Button(
381
+ value="\U0001F3B2", elem_classes="toolbutton")
382
+ seed_button.click(
383
+ fn=lambda *args: random.randint(1, int(1e8)),
384
+ outputs=[seed_textbox],
385
+ queue=False
386
+ )
387
+
388
+ generate_button = gr.Button(
389
+ value="Generate", variant='primary')
390
+
391
+ result_video = gr.Video(
392
+ label="Generated Animation", interactive=False)
393
+
394
+ style_dropdown.change(fn=controller.fetch_default_n_prompt,
395
+ inputs=[style_dropdown],
396
+ outputs=[negative_prompt_textbox,
397
+ ip_adapter_scale],
398
+ queue=False)
399
+
400
+ generate_button.click(
401
+ fn=controller.animate,
402
+ inputs=[
403
+ init_img,
404
+ motion_scale_silder,
405
+ prompt_textbox,
406
+ negative_prompt_textbox,
407
+ sample_step_slider,
408
+ cfg_scale_slider,
409
+ seed_textbox,
410
+ ip_adapter_scale,
411
+ style_dropdown,
412
+ ],
413
+ outputs=[result_video]
414
+ )
415
+
416
+ def create_example(input_list):
417
+ return gr.Examples(
418
+ examples=input_list,
419
+ inputs=[
420
+ init_img,
421
+ result_video,
422
+ prompt_textbox,
423
+ negative_prompt_textbox,
424
+ style_dropdown,
425
+ motion_scale_silder,
426
+ ],
427
+ )
428
+
429
+ gr.Markdown(
430
+ '### Merry Christmas!'
431
+ )
432
+ create_example(
433
+ [
434
+ [
435
+ '__assets__/image_animation/yiming/yiming.jpeg',
436
+ '__assets__/image_animation/yiming/yiming.mp4',
437
+ '1boy in Christmas suit, lift a Christmas gift',
438
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
439
+ '3d_cartoon',
440
+ 2,
441
+ ],
442
+ [
443
+ '__assets__/image_animation/yanhong/yanhong.png',
444
+ '__assets__/image_animation/yanhong/yanhong.mp4',
445
+ '1girl lift a Christmas gift',
446
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
447
+ '3d_cartoon',
448
+ 2,
449
+ ],
450
+ ],
451
+
452
+ )
453
+
454
+ with gr.Accordion('More Examples for Style Transfer', open=False):
455
+ create_example([
456
+ [
457
+
458
+ '__assets__/image_animation/style_transfer/anya/anya.jpg',
459
+ '__assets__/image_animation/style_transfer/anya/2.mp4',
460
+ '1girl open mouth ',
461
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
462
+ '3d_cartoon',
463
+ 3,
464
+ ],
465
+ [
466
+ '__assets__/image_animation/magnitude/genshin/genshin.jpg',
467
+ '__assets__/image_animation/magnitude/genshin/3.mp4',
468
+ 'cherry blossoms in the wind, raidenshogundef, yaemikodef, best quality, 4k',
469
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
470
+ '3d_cartoon',
471
+ 3,
472
+ ],
473
+
474
+ ])
475
+
476
+ with gr.Accordion('More Examples for Prompt Changing', open=False):
477
+ create_example(
478
+ [
479
+ [
480
+ '__assets__/image_animation/real/lighthouse.jpg',
481
+ '__assets__/image_animation/real/1.mp4',
482
+ 'lightning, lighthouse',
483
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
484
+ 'realistic',
485
+ 1,
486
+ ],
487
+ [
488
+ '__assets__/image_animation/real/lighthouse.jpg',
489
+ '__assets__/image_animation/real/2.mp4',
490
+ 'sun rising, lighthouse',
491
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
492
+ 'realistic',
493
+ 1,
494
+ ],
495
+ [
496
+ '__assets__/image_animation/real/lighthouse.jpg',
497
+ '__assets__/image_animation/real/3.mp4',
498
+ 'fireworks, lighthouse',
499
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
500
+ 'realistic',
501
+ 1,
502
+ ],
503
+ [
504
+ '__assets__/image_animation/rcnz/harry.png',
505
+ '__assets__/image_animation/rcnz/1.mp4',
506
+ '1boy smiling',
507
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
508
+ '3d_cartoon',
509
+ 2
510
+ ],
511
+ [
512
+ '__assets__/image_animation/rcnz/harry.png',
513
+ '__assets__/image_animation/rcnz/2.mp4',
514
+ '1boy playing magic fire',
515
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
516
+ '3d_cartoon',
517
+ 2
518
+ ],
519
+ [
520
+ '__assets__/image_animation/rcnz/harry.png',
521
+ '__assets__/image_animation/rcnz/3.mp4',
522
+ '1boy is waving hands',
523
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
524
+ '3d_cartoon',
525
+ 2
526
+ ]
527
+ ])
528
+
529
+ with gr.Accordion('Examples for Motion Magnitude', open=False):
530
+ create_example(
531
+ [
532
+ [
533
+ '__assets__/image_animation/magnitude/labrador.png',
534
+ '__assets__/image_animation/magnitude/1.mp4',
535
+ 'cherry blossoms in the wind, raidenshogundef, yaemikodef, best quality, 4k',
536
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
537
+ '3d_cartoon',
538
+ 1,
539
+ ],
540
+ [
541
+ '__assets__/image_animation/magnitude/labrador.png',
542
+ '__assets__/image_animation/magnitude/2.mp4',
543
+ 'cherry blossoms in the wind, raidenshogundef, yaemikodef, best quality, 4k',
544
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
545
+ '3d_cartoon',
546
+ 2,
547
+ ],
548
+ [
549
+ '__assets__/image_animation/magnitude/labrador.png',
550
+ '__assets__/image_animation/magnitude/3.mp4',
551
+ 'cherry blossoms in the wind, raidenshogundef, yaemikodef, best quality, 4k',
552
+ 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg',
553
+ '3d_cartoon',
554
+ 3,
555
+ ]
556
+ ])
557
+
558
+ return demo
559
+
560
+
561
+ if __name__ == "__main__":
562
+ demo = ui()
563
+ demo.queue(max_size=10)
564
+ demo.launch(server_name=args.server_name,
565
+ server_port=args.port, share=args.share,
566
+ max_threads=40,
567
+ allowed_paths=['pia.png'])
benchmark.py ADDED
@@ -0,0 +1,47 @@
1
+ import os
2
+ import argparse
3
+ from time import sleep
4
+ import subprocess
5
+ from concurrent.futures import ThreadPoolExecutor
6
+
7
+ """
8
+ Examples:
9
+ - Test AD-benchmark:
10
+ python benchmark.py --script=inference_ad.py --yaml_dir=configs/ad/
11
+
12
+ - Test Indomain:
13
+ python benchmark.py --script=inference.py --yaml_dir=configs/indomain/
14
+
15
+ - Test:
16
+ python benchmark.py --script=inference.py --yaml_dir=configs/indomain/myprompt --spot=True
17
+
18
+ - Test AnimateBench:
19
+ python benchmark.py --script=inference_new.py --yaml_dir=AnimateBench/config/
20
+
21
+ """
22
+
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument('--yaml_dir', type=str, default='configs/indomain/myprompt_simple/')
25
+ parser.add_argument('--node', type=str, default=None)
26
+ parser.add_argument('--script', type=str, default='inference.py')
27
+ parser.add_argument('--dreambooth', nargs='+', default=['toon', 'maj', 'real', 'rc', 'ly'])
28
+ # parser.add_argument('--dreambooth', nargs='+', default=['toon'])
29
+ # note: argparse's type=bool treats any non-empty string as True, so only pass --spot=True when the spot quota is wanted
+ parser.add_argument('--spot', type=bool, default=False)
30
+ args = parser.parse_args()
31
+
32
+ def run_srun_command(command):
33
+ subprocess.run(command, shell=True)
34
+
35
+ executor = ThreadPoolExecutor()
36
+
37
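+ # submit one srun job per DreamBooth model; each job runs the chosen script with the matching YAML config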
+ for db in args.dreambooth:
38
+ if not args.spot:
39
+ command = f"srun -p mm_lol --gres=gpu:1 "
40
+ else:
41
+ command = f"srun -p mm_lol --gres=gpu:1 --quota=spot "
42
+ if args.node is not None:
43
+ command = command + f'-w {args.node} '
44
+ command = command + f"python {args.script} --config={os.path.join(args.yaml_dir, db + '.yaml')}"
45
+
46
+ executor.submit(run_srun_command, command)
47
+ sleep(1)
configs/indomain/base.yaml ADDED
@@ -0,0 +1,14 @@
1
+ generate:
2
+ model_path: "outputs/training32-2023-11-01T09-50-05/checkpoints/checkpoint81000.ckpt"
3
+
4
+ validation_data:
5
+ mask_sim_range: 0
6
+ cond_frame: 0
7
+
8
+ noise_scheduler_kwargs:
9
+ num_train_timesteps: 1000
10
+ beta_start: 0.00085
11
+ beta_end: 0.012
12
+ beta_schedule: "linear"
13
+ steps_offset: 1
14
+ clip_sample: false
configs/indomain/real.yaml ADDED
@@ -0,0 +1,45 @@
1
+ base: 'configs/indomain/base.yaml'
2
+ prompts:
3
+ - ['A city street filled with neon shining lights and fog.', 'A young woman in dark clothes is smiling.']
4
+ - ['A dramatic black sky with overcast moving clouds.', 'The wind and waves are lapping at the lighthouse on the cliff']
5
+
6
+
7
+ n_prompt:
8
+ - 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg'
9
+ validation_data:
10
+ input_name: 'init_image'
11
+ validation_input_path: 'benchmark_prompt/real/'
12
+ save_path: 'benchmark_prompt_result/real/'
13
+ num_inference_steps: 25
14
+ guidance_scale: 7.5
15
+ img_mask: ''
16
+
17
+ pretrained_model_path: "/mnt/petrelfs/zhangyiming/project/Image2Video-AnimateDiff/models/StableDiffusion/"
18
+ unet_additional_kwargs:
19
+ use_motion_module : true
20
+ motion_module_resolutions : [ 1,2,4,8 ]
21
+ unet_use_cross_frame_attention : false
22
+ unet_use_temporal_attention : false
23
+
24
+ motion_module_type: Vanilla
25
+ motion_module_kwargs:
26
+ num_attention_heads : 8
27
+ num_transformer_block : 1
28
+ attention_block_types : [ "Temporal_Self", "Temporal_Self" ]
29
+ temporal_position_encoding : true
30
+ temporal_position_encoding_max_len : 32
31
+ temporal_attention_dim_div : 1
32
+ zero_initialize : true
33
+
34
+
35
+ generate:
36
+ use_image: true
37
+ use_video: false
38
+ sample_size: 512
39
+ video_length: 16
40
+ global_seed: [2022, 2023]
41
+ use_lora: false
42
+ use_db: true
43
+ lora_path: "models/DreamBooth_LoRA/ink_lora.safetensors"
44
+ db_path: "models/DreamBooth_LoRA/real.safetensors"
45
+ lora_alpha: 0.8
configs/inference/inference.yaml ADDED
@@ -0,0 +1,26 @@
1
+ unet_additional_kwargs:
2
+ unet_use_cross_frame_attention: false
3
+ unet_use_temporal_attention: false
4
+ use_motion_module: true
5
+ motion_module_resolutions:
6
+ - 1
7
+ - 2
8
+ - 4
9
+ - 8
10
+ motion_module_mid_block: false
11
+ motion_module_decoder_only: false
12
+ motion_module_type: Vanilla
13
+ motion_module_kwargs:
14
+ num_attention_heads: 8
15
+ num_transformer_block: 1
16
+ attention_block_types:
17
+ - Temporal_Self
18
+ - Temporal_Self
19
+ temporal_position_encoding: true
20
+ temporal_position_encoding_max_len: 32
21
+ temporal_attention_dim_div: 1
22
+
23
+ noise_scheduler_kwargs:
24
+ beta_start: 0.00085
25
+ beta_end: 0.012
26
+ beta_schedule: "linear"
configs/prompts/1-ToonYou.yaml ADDED
@@ -0,0 +1,22 @@
1
+ ToonYou:
2
+ base: ""
3
+ path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors"
4
+ motion_module:
5
+ - "models/Motion_Module/mm_sd_v14.ckpt"
6
+ - "models/Motion_Module/mm_sd_v15.ckpt"
7
+
8
+ seed: [10788741199826055526, 6520604954829636163, 6519455744612555650, 16372571278361863751]
9
+ steps: 25
10
+ guidance_scale: 7.5
11
+
12
+ prompt:
13
+ - "best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress"
14
+ - "masterpiece, best quality, 1girl, solo, cherry blossoms, hanami, pink flower, white flower, spring season, wisteria, petals, flower, plum blossoms, outdoors, falling petals, white hair, black eyes,"
15
+ - "best quality, masterpiece, 1boy, formal, abstract, looking at viewer, masculine, marble pattern"
16
+ - "best quality, masterpiece, 1girl, cloudy sky, dandelion, contrapposto, alternate hairstyle,"
17
+
18
+ n_prompt:
19
+ - ""
20
+ - "badhandv4,easynegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, teeth"
21
+ - ""
22
+ - ""
configs/prompts/1.yaml ADDED
@@ -0,0 +1,20 @@
1
+ FilmVelvia:
2
+ base: "models/DreamBooth_LoRA/majicmixRealistic_v5.safetensors"
3
+ path: "models/DreamBooth_LoRA/FilmVelvia2.safetensors"
4
+ motion_module:
5
+ - "models/Motion_Module/mm_sd_v14.ckpt"
6
+ - "models/Motion_Module/mm_sd_v15.ckpt"
7
+
8
+ seed: [358675358833372813, 3519455280971923743, 11684545350557985081, 8696855302100399877]
9
+ steps: 25
10
+ guidance_scale: 7.5
11
+ lora_alpha: 0.6
12
+
13
+ prompt:
14
+ -
15
+
16
+ n_prompt:
17
+ - "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
18
+ - "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
19
+ - "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
20
+ - "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
configs/prompts/2-Lyriel.yaml ADDED
1
+ Lyriel:
2
+ base: ""
3
+ path: "models/DreamBooth_LoRA/lyriel_v16.safetensors"
4
+ motion_module:
5
+ - "models/Motion_Module/mm_sd_v14.ckpt"
6
+ - "models/Motion_Module/mm_sd_v15.ckpt"
7
+
8
+ seed: [10917152860782582783, 6399018107401806238, 15875751942533906793, 6653196880059936551]
9
+ steps: 25
10
+ guidance_scale: 7.5
11
+
12
+ prompt:
13
+ - "dark shot, epic realistic, portrait of halo, sunglasses, blue eyes, tartan scarf, white hair by atey ghailan, by greg rutkowski, by greg tocchini, by james gilleard, by joe fenton, by kaethe butcher, gradient yellow, black, brown and magenta color scheme, grunge aesthetic!!! graffiti tag wall background, art by greg rutkowski and artgerm, soft cinematic light, adobe lightroom, photolab, hdr, intricate, highly detailed, depth of field, faded, neutral colors, hdr, muted colors, hyperdetailed, artstation, cinematic, warm lights, dramatic light, intricate details, complex background, rutkowski, teal and orange"
14
+ - "A forbidden castle high up in the mountains, pixel art, intricate details2, hdr, intricate details, hyperdetailed5, natural skin texture, hyperrealism, soft light, sharp, game art, key visual, surreal"
15
+ - "dark theme, medieval portrait of a man sharp features, grim, cold stare, dark colors, Volumetric lighting, baroque oil painting by Greg Rutkowski, Artgerm, WLOP, Alphonse Mucha dynamic lighting hyperdetailed intricately detailed, hdr, muted colors, complex background, hyperrealism, hyperdetailed, amandine van ray"
16
+ - "As I have gone alone in there and with my treasures bold, I can keep my secret where and hint of riches new and old. Begin it where warm waters halt and take it in a canyon down, not far but too far to walk, put in below the home of brown."
17
+
18
+ n_prompt:
19
+ - "3d, cartoon, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, young, loli, elf, 3d, illustration"
20
+ - "3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular"
21
+ - "dof, grayscale, black and white, bw, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular,badhandsv5-neg, By bad artist -neg 1, monochrome"
22
+ - "holding an item, cowboy, hat, cartoon, 3d, disfigured, bad art, deformed,extra limbs,close up,b&w, wierd colors, blurry, duplicate, morbid, mutilated, [out of frame], extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, out of frame, ugly, extra limbs, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, Photoshop, video game, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, 3d render"
configs/prompts/3-RcnzCartoon.yaml ADDED
@@ -0,0 +1,22 @@
+ RcnzCartoon:
+   base: ""
+   path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors"
+   motion_module:
+     - "models/Motion_Module/mm_sd_v14.ckpt"
+     - "models/Motion_Module/mm_sd_v15.ckpt"
+
+   seed: [16931037867122267877, 2094308009433392066, 4292543217695451092, 15572665120852309890]
+   steps: 25
+   guidance_scale: 7.5
+
+   prompt:
+     - "Jane Eyre with headphones, natural skin texture,4mm,k textures, soft cinematic light, adobe lightroom, photolab, hdr, intricate, elegant, highly detailed, sharp focus, cinematic look, soothing tones, insane details, intricate details, hyperdetailed, low contrast, soft cinematic light, dim colors, exposure blend, hdr, faded"
+     - "close up Portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal [rust], elegant, sharp focus, photo by greg rutkowski, soft lighting, vibrant colors, masterpiece, streets, detailed face"
+     - "absurdres, photorealistic, masterpiece, a 30 year old man with gold framed, aviator reading glasses and a black hooded jacket and a beard, professional photo, a character portrait, altermodern, detailed eyes, detailed lips, detailed face, grey eyes"
+     - "a golden labrador, warm vibrant colours, natural lighting, dappled lighting, diffused lighting, absurdres, highres,k, uhd, hdr, rtx, unreal, octane render, RAW photo, photorealistic, global illumination, subsurface scattering"
+
+   n_prompt:
+     - "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation"
+     - "nude, cross eyed, tongue, open mouth, inside, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, red eyes, muscular"
+     - "easynegative, cartoon, anime, sketches, necklace, earrings worst quality, low quality, normal quality, bad anatomy, bad hands, shiny skin, error, missing fingers, extra digit, fewer digits, jpeg artifacts, signature, watermark, username, blurry, chubby, anorectic, bad eyes, old, wrinkled skin, red skin, photograph By bad artist -neg, big eyes, muscular face,"
+     - "beard, EasyNegative, lowres, chromatic aberration, depth of field, motion blur, blurry, bokeh, bad quality, worst quality, multiple arms, badhand"
configs/prompts/4-MajicMix.yaml ADDED
@@ -0,0 +1,22 @@
+ MajicMix:
+   base: ""
+   path: "models/DreamBooth_LoRA/majicmixRealistic_v5Preview.safetensors"
+   motion_module:
+     - "models/Motion_Module/mm_sd_v14.ckpt"
+     - "models/Motion_Module/mm_sd_v15.ckpt"
+
+   seed: [1572448948722921032, 1099474677988590681, 6488833139725635347, 18339859844376517918]
+   steps: 25
+   guidance_scale: 7.5
+
+   prompt:
+     - "1girl, offshoulder, light smile, shiny skin best quality, masterpiece, photorealistic"
+     - "best quality, masterpiece, photorealistic, 1boy, 50 years old beard, dramatic lighting"
+     - "best quality, masterpiece, photorealistic, 1girl, light smile, shirt with collars, waist up, dramatic lighting, from below"
+     - "male, man, beard, bodybuilder, skinhead,cold face, tough guy, cowboyshot, tattoo, french windows, luxury hotel masterpiece, best quality, photorealistic"
+
+   n_prompt:
+     - "ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, watermark, moles"
+     - "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome"
+     - "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome"
+     - "nude, nsfw, ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, monochrome, grayscale watermark, moles, people"
configs/prompts/5-RealisticVision.yaml ADDED
@@ -0,0 +1,22 @@
+ RealisticVision:
+   base: ""
+   path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+   motion_module:
+     - "models/Motion_Module/mm_sd_v14.ckpt"
+     - "models/Motion_Module/mm_sd_v15.ckpt"
+
+   seed: [5658137986800322009, 12099779162349365895, 10499524853910852697, 16768009035333711932]
+   steps: 25
+   guidance_scale: 7.5
+
+   prompt:
+     - "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+     - "close up photo of a rabbit, forest, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot"
+     - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+     - "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain"
+
+   n_prompt:
+     - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
+     - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
+     - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+     - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, art, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
configs/prompts/6-Tusun.yaml ADDED
@@ -0,0 +1,20 @@
+ Tusun:
+   base: "models/DreamBooth_LoRA/moonfilm_reality20.safetensors"
+   path: "models/DreamBooth_LoRA/TUSUN.safetensors"
+   motion_module:
+     - "models/Motion_Module/mm_sd_v14.ckpt"
+     - "models/Motion_Module/mm_sd_v15.ckpt"
+
+   seed: [10154078483724687116, 2664393535095473805, 4231566096207622938, 1713349740448094493]
+   steps: 25
+   guidance_scale: 7.5
+   lora_alpha: 0.6
+
+   prompt:
+     - "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing"
+     - "cute tusun with a blurry background, black background, simple background, signature, face, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing"
+     - "cut tusuncub walking in the snow, blurry, looking at viewer, depth of field, blurry background, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing"
+     - "character design, cyberpunk tusun kitten wearing astronaut suit, sci-fic, realistic eye color and details, fluffy, big head, science fiction, communist ideology, Cyborg, fantasy, intense angle, soft lighting, photograph, 4k, hyper detailed, portrait wallpaper, realistic, photo-realistic, DSLR, 24 Megapixels, Full Frame, vibrant details, octane render, finely detail, best quality, incredibly absurdres, robotic parts, rim light, vibrant details, luxurious cyberpunk, hyperrealistic, cable electric wires, microchip, full body"
+
+   n_prompt:
+     - "worst quality, low quality, deformed, distorted, disfigured, bad eyes, bad anatomy, disconnected limbs, wrong body proportions, low quality, worst quality, text, watermark, signatre, logo, illustration, painting, cartoons, ugly, easy_negative"
configs/prompts/7-FilmVelvia.yaml ADDED
@@ -0,0 +1,23 @@
+ FilmVelvia:
+   base: "models/DreamBooth_LoRA/majicmixRealistic_v5.safetensors"
+   path: "models/DreamBooth_LoRA/FilmVelvia2.safetensors"
+   motion_module:
+     - "models/Motion_Module/mm_sd_v14.ckpt"
+     - "models/Motion_Module/mm_sd_v15.ckpt"
+
+   seed: [358675358833372813, 3519455280971923743, 11684545350557985081, 8696855302100399877]
+   steps: 25
+   guidance_scale: 7.5
+   lora_alpha: 0.6
+
+   prompt:
+     - "a woman standing on the side of a road at night,girl, long hair, motor vehicle, car, looking at viewer, ground vehicle, night, hands in pockets, blurry background, coat, black hair, parted lips, bokeh, jacket, brown hair, outdoors, red lips, upper body, artist name"
+     - ", dark shot,0mm, portrait quality of a arab man worker,boy, wasteland that stands out vividly against the background of the desert, barren landscape, closeup, moles skin, soft light, sharp, exposure blend, medium shot, bokeh, hdr, high contrast, cinematic, teal and orange5, muted colors, dim colors, soothing tones, low saturation, hyperdetailed, noir"
+     - "fashion photography portrait of 1girl, offshoulder, fluffy short hair, soft light, rim light, beautiful shadow, low key, photorealistic, raw photo, natural skin texture, realistic eye and face details, hyperrealism, ultra high res, 4K, Best quality, masterpiece, necklace, cleavage, in the dark"
+     - "In this lighthearted portrait, a woman is dressed as a fierce warrior, armed with an arsenal of paintbrushes and palette knives. Her war paint is composed of thick, vibrant strokes of color, and her armor is made of paint tubes and paint-splattered canvases. She stands victoriously atop a mountain of conquered blank canvases, with a beautiful, colorful landscape behind her, symbolizing the power of art and creativity. bust Portrait, close-up, Bright and transparent scene lighting, "
+
+   n_prompt:
+     - "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
+     - "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
+     - "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
+     - "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
configs/prompts/8-GhibliBackground.yaml ADDED
@@ -0,0 +1,20 @@
+ GhibliBackground:
+   base: "models/DreamBooth_LoRA/CounterfeitV30_25.safetensors"
+   path: "models/DreamBooth_LoRA/lora_Ghibli_n3.safetensors"
+   motion_module:
+     - "models/Motion_Module/mm_sd_v14.ckpt"
+     - "models/Motion_Module/mm_sd_v15.ckpt"
+
+   seed: [8775748474469046618, 5893874876080607656, 11911465742147695752, 12437784838692000640]
+   steps: 25
+   guidance_scale: 7.5
+   lora_alpha: 1.0
+
+   prompt:
+     - "best quality,single build,architecture, blue_sky, building,cloudy_sky, day, fantasy, fence, field, house, build,architecture,landscape, moss, outdoors, overgrown, path, river, road, rock, scenery, sky, sword, tower, tree, waterfall"
+     - "black_border, building, city, day, fantasy, ice, landscape, letterboxed, mountain, ocean, outdoors, planet, scenery, ship, snow, snowing, water, watercraft, waterfall, winter"
+     - ",mysterious sea area, fantasy,build,concept"
+     - "Tomb Raider,Scenography,Old building"
+
+   n_prompt:
+     - "easynegative,bad_construction,bad_structure,bad_wail,bad_windows,blurry,cloned_window,cropped,deformed,disfigured,error,extra_windows,extra_chimney,extra_door,extra_structure,extra_frame,fewer_digits,fused_structure,gross_proportions,jpeg_artifacts,long_roof,low_quality,structure_limbs,missing_windows,missing_doors,missing_roofs,mutated_structure,mutation,normal_quality,out_of_frame,owres,poorly_drawn_structure,poorly_drawn_house,signature,text,too_many_windows,ugly,username,uta,watermark,worst_quality"
configs/training/image_finetune.yaml ADDED
@@ -0,0 +1,48 @@
+ image_finetune: true
+
+ output_dir: "outputs"
+ pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5"
+
+ noise_scheduler_kwargs:
+   num_train_timesteps: 1000
+   beta_start: 0.00085
+   beta_end: 0.012
+   beta_schedule: "scaled_linear"
+   steps_offset: 1
+   clip_sample: false
+
+ train_data:
+   csv_path: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv"
+   video_folder: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val"
+   sample_size: 256
+
+ validation_data:
+   prompts:
+     - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons."
+     - "A drone view of celebration with Christma tree and fireworks, starry sky - background."
+     - "Robot dancing in times square."
+     - "Pacific coast, carmel by the sea ocean and waves."
+   num_inference_steps: 25
+   guidance_scale: 8.
+
+ trainable_modules:
+   - "."
+
+ unet_checkpoint_path: ""
+
+ learning_rate: 1.e-5
+ train_batch_size: 50
+
+ max_train_epoch: -1
+ max_train_steps: 100
+ checkpointing_epochs: -1
+ checkpointing_steps: 60
+
+ validation_steps: 5000
+ validation_steps_tuple: [2, 50]
+
+ global_seed: 42
+ mixed_precision_training: true
+ enable_xformers_memory_efficient_attention: True
+
+ is_debug: False
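The `noise_scheduler_kwargs` block mirrors the constructor arguments of a diffusers scheduler. As a hedged illustration, outside the committed files, the snippet below loads this config with OmegaConf and builds a DDIMScheduler from those kwargs; whether the repo's trainer actually uses DDIM or another scheduler class is a detail of the training code, not of this config.

```python
# Minimal sketch (not part of this commit): consuming the training config
# above.  Assumes `omegaconf` and `diffusers` are installed; DDIMScheduler is
# shown because it accepts every key in `noise_scheduler_kwargs`, but the
# repo's trainer may construct a different scheduler class.
from omegaconf import OmegaConf
from diffusers import DDIMScheduler

config = OmegaConf.load("configs/training/image_finetune.yaml")

noise_scheduler = DDIMScheduler(
    **OmegaConf.to_container(config.noise_scheduler_kwargs)
)
print(noise_scheduler.config.beta_schedule)   # "scaled_linear"
print(config.output_dir, config.train_batch_size)
```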