Upload 7 files

- taehv/LICENSE +21 -0
- taehv/README.md +63 -0
- taehv/taecvx.pth +3 -0
- taehv/taehv.pth +3 -0
- taehv/taehv.py +286 -0
- taehv/taeos1_3.pth +3 -0
- taehv/taew2_1.pth +3 -0

taehv/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Ollin Boer Bohan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

taehv/README.md
ADDED
@@ -0,0 +1,63 @@
# 🥮 Tiny AutoEncoder for Hunyuan Video & Wan 2.1

## What is TAEHV?

TAEHV is a Tiny AutoEncoder for Hunyuan Video (& Wan 2.1). TAEHV can decode latents into videos more cheaply (in time & memory) than the full-size VAEs, at the cost of slightly lower quality.

Here's a comparison of the output & memory usage of the full Hunyuan VAE vs. TAEHV:

<table>
<tr><th><tt>pipe.vae</tt></th><th>Full Hunyuan VAE</th><th>TAEHV</th></tr>
<tr>
<td>Decoded Video<br/><sup>(converted to GIF)</sup></td>
<td><img src="https://github.com/user-attachments/assets/b9ee3405-c210-4410-95ac-639a4ed09c50"/></td>
<td><img src="https://github.com/user-attachments/assets/3fe3cb6a-30e5-46fe-9458-f0a39e454b86"/></td>
</tr>
<tr>
<td>Runtime<br/><sup>(in fp16, on GH200)</sup></td>
<td><strong>~2-3s</strong> for decoding 61 frames of (512, 320) video</td>
<td><strong>~0.5s</strong> for decoding 61 frames of (512, 320) video</td>
</tr>
<tr>
<td>Memory<br/><sup>(in fp16, on GH200)</sup></td>
<td><strong>~6-9GB Peak Memory Usage</strong><br/><img src="https://github.com/user-attachments/assets/d7837271-c748-4eef-ab37-eda6cc1e6a69"/></td>
<td><strong>&lt;0.5GB Peak Memory Usage</strong><br/><img src="https://github.com/user-attachments/assets/c71e2ef5-12f1-431f-b193-29d9a5ee6343"/></td>
</tr>
</table>

See the [profiling notebook](./examples/TAEHV_Profiling.ipynb) for details on this comparison, or the [example notebook](./examples/TAEHV_T2I_Demo.ipynb) for a simpler demo.
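
For quick standalone use, here's a minimal sketch (assuming `taehv.py` and `taehv.pth` are in your working directory; tensors are NTCHW, as described in the `taehv.py` docstrings):

```python
import torch
from taehv import TAEHV

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if dev.type == "cuda" else torch.float32
taehv = TAEHV(checkpoint_path="taehv.pth").to(dev, dtype)

video = torch.rand(1, 16, 3, 320, 512, device=dev, dtype=dtype)  # NTCHW RGB in [0, 1]
latents = taehv.encode_video(video)    # -> (1, 4, 16, 40, 64): 4x temporal, 8x spatial compression
decoded = taehv.decode_video(latents)  # -> (1, 13, 3, 320, 512) RGB in ~[0, 1]
```

(Decoding trims the first `2**(number of enabled temporal upscales) - 1` frames, so a roundtrip returns slightly fewer frames than it was given.)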

## How do I use TAEHV with Wan 2.1?

Since Wan 2.1 uses the same input / output shapes as the Hunyuan VAE, you can also use TAEHV for Wan 2.1 decoding with the `taew2_1.pth` weights (see the [Wan 2.1 example notebook](./examples/TAEW2.1_T2I_Demo.ipynb)).
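
If you want to drop TAEHV in as `pipe.vae` (as in the comparison table above), you'll need a small adapter, since diffusers pipelines pass NCTHW latents and call `vae.decode(...)`. The sketch below is a hypothetical illustration only (real pipelines also read `vae.config` fields like `scaling_factor`), so treat the bundled example notebooks as the working reference:

```python
import torch

class TAEHVPipelineAdapter(torch.nn.Module):
    """Hypothetical shim; assumes the pipeline only calls vae.decode(latents, return_dict=False)[0]."""
    def __init__(self, taehv):
        super().__init__()
        self.taehv = taehv
    def decode(self, latents, return_dict=False, **kwargs):
        # NCTHW -> NTCHW, decode to [0, 1] RGB, rescale to the [-1, 1] range pipelines expect, -> NCTHW
        sample = self.taehv.decode_video(latents.transpose(1, 2)).mul_(2).sub_(1).transpose(1, 2)
        return (sample,)
```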

## How do I use TAEHV with CogVideoX?

Try the `taecvx.pth` weights (see the [example notebook](./examples/TAECVX_T2I_Demo.ipynb)).

## How do I use TAEHV with Open-Sora 1.3?

Try the `taeos1_3.pth` weights.
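
All of these variants load through the same `TAEHV` class; only the checkpoint differs. For example (continuing the sketch above, and assuming the weight files have been downloaded locally):

```python
taew = TAEHV(checkpoint_path="taew2_1.pth")    # Wan 2.1
taecvx = TAEHV(checkpoint_path="taecvx.pth")   # CogVideoX
taeos = TAEHV(checkpoint_path="taeos1_3.pth")  # Open-Sora 1.3
```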

## How can I reduce the TAEHV decoding cost further?

You can disable temporal or spatial upscaling to get even cheaper decoding.

```python
TAEHV(decoder_time_upscale=(False, False), decoder_space_upscale=(True, True, True))
```

![taehv_4x_lower_time](https://github.com/user-attachments/assets/a4a37978-01d8-4984-a9ba-4042ba35b48f)

```python
TAEHV(decoder_time_upscale=(False, False), decoder_space_upscale=(False, False, False))
```

![taehv_32x_lower_spacetime](https://github.com/user-attachments/assets/5e4dabfd-5d4d-4d8b-b340-7edcbbcf3c3e)

If you have a powerful GPU or are decoding at a reduced resolution, you can also set `parallel=True` in `TAEHV.decode_video` to decode all frames at once (which is faster but requires more memory).
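
For example (a sketch of both modes; note that `parallel=True` is already the default in `taehv.py`):

```python
frames = taehv.decode_video(latents, parallel=True)   # all frames at once: fastest, O(T) memory
frames = taehv.decode_video(latents, parallel=False)  # frame by frame: slower, ~O(1) memory
```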

## Limitations

TAEHV is still pretty experimental (specifically, it's a hacky finetune of [TAEM1](https://github.com/madebyollin/taem1) :) using a fairly limited dataset) and I haven't tested it much yet. Please report quality / performance issues as you discover them.

taehv/taecvx.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ec0ab2077a044d294f05a5731c5ca5174fd54b082552598ad7c5b1800159f423
size 22680146

taehv/taehv.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3866076f74b50087d5ba4d145191480ad51a08a316c8e95a7233284945cfc26e
size 22679486

taehv/taehv.py
ADDED
@@ -0,0 +1,286 @@
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Hunyuan Video
(DNN for encoding / decoding videos to Hunyuan Video's latent space)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple

DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))

def conv(n_in, n_out, **kwargs):
    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)

class Clamp(nn.Module):
    def forward(self, x):
        return torch.tanh(x / 3) * 3

class MemBlock(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True), conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out))
        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        self.act = nn.ReLU(inplace=True)
    def forward(self, x, past):
        return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))

class TPool(nn.Module):
    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)
    def forward(self, x):
        _NT, C, H, W = x.shape
        return self.conv(x.reshape(-1, self.stride * C, H, W))

class TGrow(nn.Module):
    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)
    def forward(self, x):
        _NT, C, H, W = x.shape
        x = self.conv(x)
        return x.reshape(-1, C, H, W)

def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
    """
    Apply a sequential model with memblocks to the given input.
    Args:
    - model: nn.Sequential of blocks to apply
    - x: input data, of dimensions NTCHW
    - parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
        if False, each timestep will be processed sequentially (slow but uses O(1) memory)
    - show_progress_bar: if True, enables tqdm progressbar display

    Returns NTCHW tensor of output data.
    """
    assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
    N, T, C, H, W = x.shape
    if parallel:
        x = x.reshape(N * T, C, H, W)
        # parallel over input timesteps, iterate over blocks
        for b in tqdm(model, disable=not show_progress_bar):
            if isinstance(b, MemBlock):
                NT, C, H, W = x.shape
                T = NT // N
                _x = x.reshape(N, T, C, H, W)
                # each timestep's "past" is the previous timestep's input (zeros for the first timestep)
                mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape)
                x = b(x, mem)
            else:
                x = b(x)
        NT, C, H, W = x.shape
        T = NT // N
        x = x.view(N, T, C, H, W)
    else:
        # TODO(oboerbohan): at least on macos this still gradually uses more memory during decode...
        # need to fix :(
        out = []
        # iterate over input timesteps and also iterate over blocks.
        # because of the cursed TPool/TGrow blocks, this is not a nested loop,
        # it's actually a ***graph traversal*** problem! so let's make a queue
        work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
        # in addition to manually managing our queue, we also need to manually manage our progressbar.
        # we'll update it for every source node that we consume.
        progress_bar = tqdm(range(T), disable=not show_progress_bar)
        # we'll also need a separate addressable memory per node as well
        mem = [None] * len(model)
        while work_queue:
            xt, i = work_queue.pop(0)
            if i == 0:
                # new source node consumed
                progress_bar.update(1)
            if i == len(model):
                # reached end of the graph, append result to output list
                out.append(xt)
            else:
                # fetch the block to process
                b = model[i]
                if isinstance(b, MemBlock):
                    # mem blocks are simple since we're visiting the graph in causal order
                    if mem[i] is None:
                        xt_new = b(xt, xt * 0)
                        mem[i] = xt
                    else:
                        xt_new = b(xt, mem[i])
                        mem[i].copy_(xt)  # inplace might reduce mysterious pytorch memory allocations? doesn't help though
                    # add successor to work queue
                    work_queue.insert(0, TWorkItem(xt_new, i + 1))
                elif isinstance(b, TPool):
                    # pool blocks are miserable
                    if mem[i] is None:
                        mem[i] = []  # pool memory is itself a queue of inputs to pool
                    mem[i].append(xt)
                    if len(mem[i]) > b.stride:
                        # pool mem is in an invalid state, we should have pooled before this
                        raise ValueError(f"TPool memory overfilled ({len(mem[i])} entries > stride {b.stride})")
                    elif len(mem[i]) < b.stride:
                        # pool mem is not yet full, go back to processing the work queue
                        pass
                    else:
                        # pool mem is ready, run the pool block
                        N, C, H, W = xt.shape
                        xt = b(torch.cat(mem[i], 1).view(N * b.stride, C, H, W))
                        # reset the pool mem
                        mem[i] = []
                        # add successor to work queue
                        work_queue.insert(0, TWorkItem(xt, i + 1))
                elif isinstance(b, TGrow):
                    xt = b(xt)
                    NT, C, H, W = xt.shape
                    # each tgrow has multiple successor nodes
                    for xt_next in reversed(xt.view(N, b.stride * C, H, W).chunk(b.stride, 1)):
                        # add successor to work queue
                        work_queue.insert(0, TWorkItem(xt_next, i + 1))
                else:
                    # normal block with no funny business
                    xt = b(xt)
                    # add successor to work queue
                    work_queue.insert(0, TWorkItem(xt, i + 1))
        progress_bar.close()
        x = torch.stack(out, 1)
    return x

class TAEHV(nn.Module):
    latent_channels = 16
    image_channels = 3
    def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
        """Initialize pretrained TAEHV from the given checkpoint.

        Args:
            checkpoint_path: path to weight file to load. taehv.pth for Hunyuan, taew2_1.pth for Wan 2.1.
            decoder_time_upscale: whether temporal upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
            decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
        """
        super().__init__()
        self.encoder = nn.Sequential(
            conv(TAEHV.image_channels, 64), nn.ReLU(inplace=True),
            TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            conv(64, TAEHV.latent_channels),
        )
        n_f = [256, 128, 64, 64]
        self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
        self.decoder = nn.Sequential(
            Clamp(), conv(TAEHV.latent_channels, n_f[0]), nn.ReLU(inplace=True),
            MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
            MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
            MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
            nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
        )
        if checkpoint_path is not None:
            self.load_state_dict(self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True)))

    def patch_tgrow_layers(self, sd):
        """Patch TGrow layers to use a smaller kernel if needed.

        Args:
            sd: state dict to patch
        """
        new_sd = self.state_dict()
        for i, layer in enumerate(self.decoder):
            if isinstance(layer, TGrow):
                key = f"decoder.{i}.conv.weight"
                if sd[key].shape[0] > new_sd[key].shape[0]:
                    # take the last-timestep output channels
                    sd[key] = sd[key][-new_sd[key].shape[0]:]
        return sd

    def encode_video(self, x, parallel=True, show_progress_bar=True):
        """Encode a sequence of frames.

        Args:
            x: input NTCHW RGB (C=3) tensor with values in [0, 1].
            parallel: if True, all frames will be processed at once.
                (this is faster but may require more memory).
                if False, frames will be processed sequentially.
        Returns NTCHW latent tensor with ~Gaussian values.
        """
        return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)

    def decode_video(self, x, parallel=True, show_progress_bar=True):
        """Decode a sequence of frames.

        Args:
            x: input NTCHW latent (C=16) tensor with ~Gaussian values.
            parallel: if True, all frames will be processed at once.
                (this is faster but may require more memory).
                if False, frames will be processed sequentially.
        Returns NTCHW RGB tensor with ~[0, 1] values.
        """
        x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
        # trim the extra leading frames introduced by temporal upsampling
        return x[:, self.frames_to_trim:]

    def forward(self, x):
        # roundtrip reconstruction: encode the input video to latents, then decode it back
        return self.decode_video(self.encode_video(x))

@torch.no_grad()
def main():
    """Run TAEHV roundtrip reconstruction on the given video paths."""
    import os
    import sys
    import cv2  # no highly esteemed deed is commemorated here

    class VideoTensorReader:
        def __init__(self, video_file_path):
            self.cap = cv2.VideoCapture(video_file_path)
            assert self.cap.isOpened(), f"Could not load {video_file_path}"
            self.fps = self.cap.get(cv2.CAP_PROP_FPS)
        def __iter__(self):
            return self
        def __next__(self):
            ret, frame = self.cap.read()
            if not ret:
                self.cap.release()
                raise StopIteration  # End of video or error
            return torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1)  # BGR HWC -> RGB CHW

    class VideoTensorWriter:
        def __init__(self, video_file_path, width_height, fps=30):
            self.writer = cv2.VideoWriter(video_file_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, width_height)
            assert self.writer.isOpened(), f"Could not create writer for {video_file_path}"
        def write(self, frame_tensor):
            assert frame_tensor.ndim == 3 and frame_tensor.shape[0] == 3, f"{frame_tensor.shape}??"
            self.writer.write(cv2.cvtColor(frame_tensor.permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR))  # RGB CHW -> BGR HWC
        def __del__(self):
            if hasattr(self, 'writer'):
                self.writer.release()

    dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    dtype = torch.float16
    checkpoint_path = os.getenv("TAEHV_CHECKPOINT_PATH", "taehv.pth")
    checkpoint_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
    print(f"Using device \033[31m{dev}\033[0m, dtype \033[32m{dtype}\033[0m, checkpoint \033[34m{checkpoint_name}\033[0m ({checkpoint_path})")
    taehv = TAEHV(checkpoint_path=checkpoint_path).to(dev, dtype)
    for video_path in sys.argv[1:]:
        print(f"Processing {video_path}...")
        video_in = VideoTensorReader(video_path)
        video = torch.stack(list(video_in), 0)[None]
        # convert to a device tensor with values in [0, 1]
        vid_dev = video.to(dev, dtype).div_(255.0)
        if video.numel() < 100_000_000:
            print(f" {video_path} seems small enough, will process all frames in parallel")
            vid_enc = taehv.encode_video(vid_dev)
            print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
            vid_dec = taehv.decode_video(vid_enc)
            print(f" Decoded {video_path} -> {vid_dec.shape}")
        else:
            print(f" {video_path} seems large, will process each frame sequentially")
            vid_enc = taehv.encode_video(vid_dev, parallel=False)
            print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
            vid_dec = taehv.decode_video(vid_enc, parallel=False)
            print(f" Decoded {video_path} -> {vid_dec.shape}")
        video_out_path = video_path + f".reconstructed_by_{checkpoint_name}.mp4"
        video_out = VideoTensorWriter(video_out_path, (vid_dec.shape[-1], vid_dec.shape[-2]), fps=int(round(video_in.fps)))
        for frame in vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]:
            video_out.write(frame)
        print(f" Saved to {video_out_path}")

if __name__ == "__main__":
    main()

taehv/taeos1_3.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d68af87d0e216c3545bdf78dbdba86066594902a7d659e227b11fee5bc2d46a7
size 22678486

taehv/taew2_1.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f986092baa4a124035dcf44ef59cd30e0bef91b2aa4d6a6e8e152e55e82ca30c
size 22679486