RamAnanth1 commited on
Commit
b6a5660
Β·
1 Parent(s): d8df719

Add tensor_to_mp4

Browse files
Files changed (1) hide show
  1. lvdm/utils/saving_utils.py +18 -0
lvdm/utils/saving_utils.py CHANGED
@@ -15,6 +15,24 @@ from torch import Tensor
15
  from torchvision.transforms.functional import to_tensor
16
 
17
  # ----------------------------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def savenp2sheet(imgs, savepath, nrow=None):
19
  """ save multiple imgs (in numpy array type) to a img sheet.
20
  img sheet is one row.
 
15
  from torchvision.transforms.functional import to_tensor
16
 
17
  # ----------------------------------------------------------------------------------------------
18
+ def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None):
19
+ """
20
+ video: torch.Tensor, b,c,t,h,w, 0-1
21
+ if -1~1, enable rescale=True
22
+ """
23
+ n = video.shape[0]
24
+ video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
25
+ nrow = int(np.sqrt(n)) if nrow is None else nrow
26
+ frame_grids = [torchvision.utils.make_grid(framesheet, nrow=nrow) for framesheet in video] # [3, grid_h, grid_w]
27
+ grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [T, 3, grid_h, grid_w]
28
+ grid = torch.clamp(grid.float(), -1., 1.)
29
+ if rescale:
30
+ grid = (grid + 1.0) / 2.0
31
+ grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3]
32
+ #print(f'Save video to {savepath}')
33
+ torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'})
34
+
35
+ # ----------------------------------------------------------------------------------------------
36
  def savenp2sheet(imgs, savepath, nrow=None):
37
  """ save multiple imgs (in numpy array type) to a img sheet.
38
  img sheet is one row.