hlky HF staff commited on
Commit
dbaddac
·
verified ·
1 Parent(s): 04619f3

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +3 -3
handler.py CHANGED
@@ -49,12 +49,12 @@ class EndpointHandler:
49
  latents_mean = (
50
  torch.tensor(self.vae.config.latents_mean)
51
  .view(1, self.vae.config.z_dim, 1, 1, 1)
52
- .to(latents.device, latents.dtype)
53
  )
54
  latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
55
  1, self.vae.config.z_dim, 1, 1, 1
56
- ).to(latents.device, latents.dtype)
57
- latents = latents / latents_std + latents_mean
58
 
59
  with torch.no_grad():
60
  frames = cast(torch.Tensor, self.vae.decode(tensor, return_dict=False)[0])
 
49
  latents_mean = (
50
  torch.tensor(self.vae.config.latents_mean)
51
  .view(1, self.vae.config.z_dim, 1, 1, 1)
52
+ .to(tensor.device, tensor.dtype)
53
  )
54
  latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
55
  1, self.vae.config.z_dim, 1, 1, 1
56
+ ).to(tensor.device, tensor.dtype)
57
+ tensor = tensor / latents_std + latents_mean
58
 
59
  with torch.no_grad():
60
  frames = cast(torch.Tensor, self.vae.decode(tensor, return_dict=False)[0])