jbilcke-hf (HF Staff) committed verified commit b9c531e (1 parent: 03fd240)

Upload 7 files
taehv/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Ollin Boer Bohan
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
taehv/README.md ADDED
@@ -0,0 +1,63 @@
+ # 🥮 Tiny AutoEncoder for Hunyuan Video & Wan 2.1
+
+ ## What is TAEHV?
+
+ TAEHV is a Tiny AutoEncoder for Hunyuan Video (& Wan 2.1). TAEHV can decode latents into videos more cheaply (in time & memory) than the full-size VAEs, at the cost of slightly lower quality.
+
+ Here's a comparison of the output & memory usage of the full Hunyuan VAE vs. TAEHV:
+
+ <table>
+ <tr><th><tt>pipe.vae</tt></th><th>Full Hunyuan VAE</th><th>TAEHV</th></tr>
+ <tr>
+ <td>Decoded Video<br/><sup>(converted to GIF)</sup></td>
+ <td><img src="https://github.com/user-attachments/assets/b9ee3405-c210-4410-95ac-639a4ed09c50"/></td>
+ <td><img src="https://github.com/user-attachments/assets/3fe3cb6a-30e5-46fe-9458-f0a39e454b86"/></td>
+ </tr>
+ <tr>
+ <td>Runtime<br/><sup>(in fp16, on GH200)</sup></td>
+ <td><strong>~2-3s</strong> for decoding 61 frames of (512, 320) video</td>
+ <td><strong>~0.5s</strong> for decoding 61 frames of (512, 320) video</td>
+ </tr>
+ <tr>
+ <td>Memory<br/><sup>(in fp16, on GH200)</sup></td>
+ <td><strong>~6-9GB Peak Memory Usage</strong><br/><img src="https://github.com/user-attachments/assets/d7837271-c748-4eef-ab37-eda6cc1e6a69"/></td>
+ <td><strong>&lt;0.5GB Peak Memory Usage</strong><br/><img src="https://github.com/user-attachments/assets/c71e2ef5-12f1-431f-b193-29d9a5ee6343"/></td>
+ </tr>
+ </table>
+
+ See the [profiling notebook](./examples/TAEHV_Profiling.ipynb) for details on this comparison, or the [example notebook](./examples/TAEHV_T2I_Demo.ipynb) for a simpler demo.
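+
+ For orientation, here's a minimal decoding sketch (it assumes `taehv.py` and `taehv.pth` are in your working directory, a CUDA GPU is available, and `latents` is an NTCHW tensor with C=16):
+
+ ```python
+ import torch
+ from taehv import TAEHV
+
+ taehv = TAEHV(checkpoint_path="taehv.pth").to("cuda", torch.float16)
+
+ # e.g. latents of shape [1, 16, 16, 40, 64] (N, T, C, H, W) decode to
+ # 61 frames of 512x320 RGB video with values in ~[0, 1]
+ with torch.no_grad():
+     video = taehv.decode_video(latents.to("cuda", torch.float16))
+ ```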
+
+ ## How do I use TAEHV with Wan 2.1?
+
+ Since Wan 2.1 uses the same input / output shapes as the Hunyuan VAE, you can also use TAEHV for Wan 2.1 decoding with the `taew2_1.pth` weights (see the [Wan 2.1 example notebook](./examples/TAEW2.1_T2I_Demo.ipynb)).
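+
+ A minimal sketch of the swap (assumptions: a CUDA GPU, fp16 weights, and latents coming out of the pipeline as an NCTHW tensor; the exact integration is shown in the notebook):
+
+ ```python
+ import torch
+ from taehv import TAEHV
+
+ taew = TAEHV(checkpoint_path="taew2_1.pth").to("cuda", torch.float16)
+
+ def decode_latents(latents):  # hypothetical helper; latents assumed NCTHW
+     with torch.no_grad():
+         return taew.decode_video(latents.transpose(1, 2))  # NCTHW -> NTCHW
+ ```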
+
+ ## How do I use TAEHV with CogVideoX?
+
+ Try the `taecvx.pth` weights (see the [example notebook](./examples/TAECVX_T2I_Demo.ipynb)).
+
+ ## How do I use TAEHV with Open-Sora 1.3?
+
+ Try the `taeos1_3.pth` weights.
+
+ ## How can I reduce the TAEHV decoding cost further?
+
+ You can disable temporal or spatial upscaling to get even cheaper decoding.
+
+ ```python
+ TAEHV(decoder_time_upscale=(False, False), decoder_space_upscale=(True, True, True))
+ ```
+
+ ![Image](https://github.com/user-attachments/assets/c517e37b-e53b-4d7d-b282-fbbbce10ade7)
+
+ ```python
+ TAEHV(decoder_time_upscale=(False, False), decoder_space_upscale=(False, False, False))
+ ```
+
+ ![Image](https://github.com/user-attachments/assets/62223493-8cad-427b-b13c-fa9919d3fd7b)
+
+ If you have a powerful GPU or are decoding at a reduced resolution, you can also set `parallel=True` in `TAEHV.decode_video` to decode all frames at once, which is faster but requires more memory.
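+
+ For example (same `latents` as in the sketch above):
+
+ ```python
+ # decode everything at once (fast, higher peak memory)...
+ video = taehv.decode_video(latents, parallel=True)
+ # ...or frame by frame (slower, but ~constant memory)
+ video = taehv.decode_video(latents, parallel=False)
+ ```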
+
+ ## Limitations
+
+ TAEHV is still pretty experimental (specifically, it's a hacky finetune of [TAEM1](https://github.com/madebyollin/taem1) using a fairly limited dataset :) and I haven't tested it much yet. Please report quality / performance issues as you discover them.
taehv/taecvx.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ec0ab2077a044d294f05a5731c5ca5174fd54b082552598ad7c5b1800159f423
+ size 22680146
taehv/taehv.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3866076f74b50087d5ba4d145191480ad51a08a316c8e95a7233284945cfc26e
+ size 22679486
taehv/taehv.py ADDED
@@ -0,0 +1,286 @@
+ #!/usr/bin/env python3
+ """
+ Tiny AutoEncoder for Hunyuan Video
+ (DNN for encoding / decoding videos to Hunyuan Video's latent space)
+ """
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from tqdm.auto import tqdm
+ from collections import namedtuple
+
+ DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
+ TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
+
+ def conv(n_in, n_out, **kwargs):
+     return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
+
+ class Clamp(nn.Module):
+     def forward(self, x):
+         return torch.tanh(x / 3) * 3
+
+ class MemBlock(nn.Module):
+     def __init__(self, n_in, n_out):
+         super().__init__()
+         self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True), conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out))
+         self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
+         self.act = nn.ReLU(inplace=True)
+     def forward(self, x, past):
+         return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
+
+ class TPool(nn.Module):
+     def __init__(self, n_f, stride):
+         super().__init__()
+         self.stride = stride
+         self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)
+     def forward(self, x):
+         _NT, C, H, W = x.shape
+         return self.conv(x.reshape(-1, self.stride * C, H, W))
+
+ class TGrow(nn.Module):
+     def __init__(self, n_f, stride):
+         super().__init__()
+         self.stride = stride
+         self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)
+     def forward(self, x):
+         _NT, C, H, W = x.shape
+         x = self.conv(x)
+         return x.reshape(-1, C, H, W)
+
+ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
+     """
+     Apply a sequential model with memblocks to the given input.
+     Args:
+     - model: nn.Sequential of blocks to apply
+     - x: input data, of dimensions NTCHW
+     - parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
+         if False, each timestep will be processed sequentially (slow but uses O(1) memory)
+     - show_progress_bar: if True, enables tqdm progressbar display
+
+     Returns NTCHW tensor of output data.
+     """
+     assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
+     N, T, C, H, W = x.shape
+     if parallel:
+         x = x.reshape(N * T, C, H, W)
+         # parallel over input timesteps, iterate over blocks
+         for b in tqdm(model, disable=not show_progress_bar):
+             if isinstance(b, MemBlock):
+                 NT, C, H, W = x.shape
+                 T = NT // N
+                 _x = x.reshape(N, T, C, H, W)
+                 # shift along T so each timestep sees the previous timestep as memory (zeros at t=0)
+                 mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape)
+                 x = b(x, mem)
+             else:
+                 x = b(x)
+         NT, C, H, W = x.shape
+         T = NT // N
+         x = x.view(N, T, C, H, W)
+     else:
+         # TODO(oboerbohan): at least on macos this still gradually uses more memory during decode...
+         # need to fix :(
+         out = []
+         # iterate over input timesteps and also iterate over blocks.
+         # because of the cursed TPool/TGrow blocks, this is not a nested loop,
+         # it's actually a ***graph traversal*** problem! so let's make a queue
+         work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
+         # in addition to manually managing our queue, we also need to manually manage our progressbar.
+         # we'll update it for every source node that we consume.
+         progress_bar = tqdm(range(T), disable=not show_progress_bar)
+         # we'll also need a separate addressable memory per node as well
+         mem = [None] * len(model)
+         while work_queue:
+             xt, i = work_queue.pop(0)
+             if i == 0:
+                 # new source node consumed
+                 progress_bar.update(1)
+             if i == len(model):
+                 # reached the end of the graph, append result to output list
+                 out.append(xt)
+             else:
+                 # fetch the block to process
+                 b = model[i]
+                 if isinstance(b, MemBlock):
+                     # mem blocks are simple since we're visiting the graph in causal order
+                     if mem[i] is None:
+                         xt_new = b(xt, xt * 0)
+                         mem[i] = xt
+                     else:
+                         xt_new = b(xt, mem[i])
+                         mem[i].copy_(xt)  # inplace might reduce mysterious pytorch memory allocations? doesn't help though
+                     # add successor to work queue
+                     work_queue.insert(0, TWorkItem(xt_new, i + 1))
+                 elif isinstance(b, TPool):
+                     # pool blocks are miserable
+                     if mem[i] is None:
+                         mem[i] = []  # pool memory is itself a queue of inputs to pool
+                     mem[i].append(xt)
+                     if len(mem[i]) > b.stride:
+                         # pool mem is in invalid state, we should have pooled before this
+                         raise ValueError("???")
+                     elif len(mem[i]) < b.stride:
+                         # pool mem is not yet full, go back to processing the work queue
+                         pass
+                     else:
+                         # pool mem is ready, run the pool block
+                         N, C, H, W = xt.shape
+                         xt = b(torch.cat(mem[i], 1).view(N * b.stride, C, H, W))
+                         # reset the pool mem
+                         mem[i] = []
+                         # add successor to work queue
+                         work_queue.insert(0, TWorkItem(xt, i + 1))
+                 elif isinstance(b, TGrow):
+                     xt = b(xt)
+                     NT, C, H, W = xt.shape
+                     # each tgrow has multiple successor nodes
+                     for xt_next in reversed(xt.view(N, b.stride * C, H, W).chunk(b.stride, 1)):
+                         # add successor to work queue
+                         work_queue.insert(0, TWorkItem(xt_next, i + 1))
+                 else:
+                     # normal block with no funny business
+                     xt = b(xt)
+                     # add successor to work queue
+                     work_queue.insert(0, TWorkItem(xt, i + 1))
+         progress_bar.close()
+         x = torch.stack(out, 1)
+     return x
+
+ class TAEHV(nn.Module):
+     latent_channels = 16
+     image_channels = 3
+     def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
+         """Initialize pretrained TAEHV from the given checkpoint.
+
+         Args:
+             checkpoint_path: path to the weight file to load. taehv.pth for Hunyuan, taew2_1.pth for Wan 2.1.
+             decoder_time_upscale: whether temporal upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
+             decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
+         """
+         super().__init__()
+         self.encoder = nn.Sequential(
+             conv(TAEHV.image_channels, 64), nn.ReLU(inplace=True),
+             TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
+             TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
+             TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
+             conv(64, TAEHV.latent_channels),
+         )
+         n_f = [256, 128, 64, 64]
+         # temporal upsampling yields 2**sum(decoder_time_upscale) frames per latent frame;
+         # the extra warmup frames at the start get trimmed off in decode_video
+         self.frames_to_trim = 2 ** sum(decoder_time_upscale) - 1
+         self.decoder = nn.Sequential(
+             Clamp(), conv(TAEHV.latent_channels, n_f[0]), nn.ReLU(inplace=True),
+             MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
+             MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
+             MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
+             nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
+         )
+         if checkpoint_path is not None:
+             self.load_state_dict(self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True)))
+
+     def patch_tgrow_layers(self, sd):
+         """Patch TGrow layers to use a smaller kernel if needed.
+
+         Args:
+             sd: state dict to patch
+         """
+         new_sd = self.state_dict()
+         for i, layer in enumerate(self.decoder):
+             if isinstance(layer, TGrow):
+                 key = f"decoder.{i}.conv.weight"
+                 if sd[key].shape[0] > new_sd[key].shape[0]:
+                     # take the last-timestep output channels
+                     sd[key] = sd[key][-new_sd[key].shape[0]:]
+         return sd
+
+     def encode_video(self, x, parallel=True, show_progress_bar=True):
+         """Encode a sequence of frames.
+
+         Args:
+             x: input NTCHW RGB (C=3) tensor with values in [0, 1].
+             parallel: if True, all frames will be processed at once.
+                 (this is faster but may require more memory).
+                 if False, frames will be processed sequentially.
+         Returns NTCHW latent tensor with ~Gaussian values.
+         """
+         return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)
+
+     def decode_video(self, x, parallel=True, show_progress_bar=True):
+         """Decode a sequence of frames.
+
+         Args:
+             x: input NTCHW latent (C=16) tensor with ~Gaussian values.
+             parallel: if True, all frames will be processed at once.
+                 (this is faster but may require more memory).
+                 if False, frames will be processed sequentially.
+         Returns NTCHW RGB tensor with ~[0, 1] values.
+         """
+         x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
+         return x[:, self.frames_to_trim:]
+
+     def forward(self, x):
+         # roundtrip: encode the input video, then decode it back to RGB
+         return self.decode_video(self.encode_video(x))
+
+ @torch.no_grad()
+ def main():
+     """Run TAEHV roundtrip reconstruction on the given video paths."""
+     import os
+     import sys
+     import cv2  # no highly esteemed deed is commemorated here
+
+     class VideoTensorReader:
+         def __init__(self, video_file_path):
+             self.cap = cv2.VideoCapture(video_file_path)
+             assert self.cap.isOpened(), f"Could not load {video_file_path}"
+             self.fps = self.cap.get(cv2.CAP_PROP_FPS)
+         def __iter__(self):
+             return self
+         def __next__(self):
+             ret, frame = self.cap.read()
+             if not ret:
+                 self.cap.release()
+                 raise StopIteration  # end of video or error
+             return torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1)  # BGR HWC -> RGB CHW
+
+     class VideoTensorWriter:
+         def __init__(self, video_file_path, width_height, fps=30):
+             self.writer = cv2.VideoWriter(video_file_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, width_height)
+             assert self.writer.isOpened(), f"Could not create writer for {video_file_path}"
+         def write(self, frame_tensor):
+             assert frame_tensor.ndim == 3 and frame_tensor.shape[0] == 3, f"{frame_tensor.shape}??"
+             self.writer.write(cv2.cvtColor(frame_tensor.permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR))  # RGB CHW -> BGR HWC
+         def __del__(self):
+             if hasattr(self, 'writer'): self.writer.release()
+
+     dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+     dtype = torch.float16
+     checkpoint_path = os.getenv("TAEHV_CHECKPOINT_PATH", "taehv.pth")
+     checkpoint_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
+     print(f"Using device \033[31m{dev}\033[0m, dtype \033[32m{dtype}\033[0m, checkpoint \033[34m{checkpoint_name}\033[0m ({checkpoint_path})")
+     taehv = TAEHV(checkpoint_path=checkpoint_path).to(dev, dtype)
+     for video_path in sys.argv[1:]:
+         print(f"Processing {video_path}...")
+         video_in = VideoTensorReader(video_path)
+         video = torch.stack(list(video_in), 0)[None]
+         # convert to a device tensor with values in [0, 1]
+         vid_dev = video.to(dev, dtype).div_(255.0)
+         if video.numel() < 100_000_000:
+             print(f" {video_path} seems small enough, will process all frames in parallel")
+             vid_enc = taehv.encode_video(vid_dev)
+             print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
+             vid_dec = taehv.decode_video(vid_enc)
+             print(f" Decoded {video_path} -> {vid_dec.shape}")
+         else:
+             print(f" {video_path} seems large, will process each frame sequentially")
+             vid_enc = taehv.encode_video(vid_dev, parallel=False)
+             print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
+             vid_dec = taehv.decode_video(vid_enc, parallel=False)
+             print(f" Decoded {video_path} -> {vid_dec.shape}")
+         video_out_path = video_path + f".reconstructed_by_{checkpoint_name}.mp4"
+         video_out = VideoTensorWriter(video_out_path, (vid_dec.shape[-1], vid_dec.shape[-2]), fps=int(round(video_in.fps)))
+         for frame in vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]:
+             video_out.write(frame)
+         print(f" Saved to {video_out_path}")
+
+ if __name__ == "__main__":
+     main()
taehv/taeos1_3.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d68af87d0e216c3545bdf78dbdba86066594902a7d659e227b11fee5bc2d46a7
+ size 22678486
taehv/taew2_1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f986092baa4a124035dcf44ef59cd30e0bef91b2aa4d6a6e8e152e55e82ca30c
+ size 22679486