Spaces:
Runtime error
Runtime error
Commit
Β·
b6a5660
1
Parent(s):
d8df719
Add tensor_to_mp4
Browse files- lvdm/utils/saving_utils.py +18 -0
lvdm/utils/saving_utils.py
CHANGED
@@ -15,6 +15,24 @@ from torch import Tensor
|
|
15 |
from torchvision.transforms.functional import to_tensor
|
16 |
|
17 |
# ----------------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
def savenp2sheet(imgs, savepath, nrow=None):
|
19 |
""" save multiple imgs (in numpy array type) to a img sheet.
|
20 |
img sheet is one row.
|
|
|
15 |
from torchvision.transforms.functional import to_tensor
|
16 |
|
17 |
# ----------------------------------------------------------------------------------------------
|
18 |
+
def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None):
|
19 |
+
"""
|
20 |
+
video: torch.Tensor, b,c,t,h,w, 0-1
|
21 |
+
if -1~1, enable rescale=True
|
22 |
+
"""
|
23 |
+
n = video.shape[0]
|
24 |
+
video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
|
25 |
+
nrow = int(np.sqrt(n)) if nrow is None else nrow
|
26 |
+
frame_grids = [torchvision.utils.make_grid(framesheet, nrow=nrow) for framesheet in video] # [3, grid_h, grid_w]
|
27 |
+
grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [T, 3, grid_h, grid_w]
|
28 |
+
grid = torch.clamp(grid.float(), -1., 1.)
|
29 |
+
if rescale:
|
30 |
+
grid = (grid + 1.0) / 2.0
|
31 |
+
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3]
|
32 |
+
#print(f'Save video to {savepath}')
|
33 |
+
torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'})
|
34 |
+
|
35 |
+
# ----------------------------------------------------------------------------------------------
|
36 |
def savenp2sheet(imgs, savepath, nrow=None):
|
37 |
""" save multiple imgs (in numpy array type) to a img sheet.
|
38 |
img sheet is one row.
|