Update handler.py
Browse files- handler.py +3 -3
handler.py
CHANGED
@@ -49,12 +49,12 @@ class EndpointHandler:
|
|
49 |
latents_mean = (
|
50 |
torch.tensor(self.vae.config.latents_mean)
|
51 |
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
52 |
-
.to(
|
53 |
)
|
54 |
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
|
55 |
1, self.vae.config.z_dim, 1, 1, 1
|
56 |
-
).to(
|
57 |
-
|
58 |
|
59 |
with torch.no_grad():
|
60 |
frames = cast(torch.Tensor, self.vae.decode(tensor, return_dict=False)[0])
|
|
|
49 |
latents_mean = (
|
50 |
torch.tensor(self.vae.config.latents_mean)
|
51 |
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
52 |
+
.to(tensor.device, tensor.dtype)
|
53 |
)
|
54 |
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
|
55 |
1, self.vae.config.z_dim, 1, 1, 1
|
56 |
+
).to(tensor.device, tensor.dtype)
|
57 |
+
tensor = tensor / latents_std + latents_mean
|
58 |
|
59 |
with torch.no_grad():
|
60 |
frames = cast(torch.Tensor, self.vae.decode(tensor, return_dict=False)[0])
|