File size: 4,657 Bytes
131da64 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import abc
import torch
import torch.nn as nn
# Flags required to enable jit fusion kernels.
# These calls run at import time, so importing this module changes JIT
# settings for the whole process.
# NOTE(review): they go through private `torch._C` hooks; judging by the
# names, the first two turn the profiling mode/executor off and the last two
# allow fusion on CPU and GPU -- confirm against the installed torch version.
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
def get_noise(config, dtype=torch.float32):
    """Build the noise schedule selected by ``config.noise.type``.

    Args:
        config: Object exposing ``config.noise.type``. The 'geometric' and
            'linear' schedules additionally read ``config.noise.sigma_min``
            and ``config.noise.sigma_max``.
        dtype: Tensor dtype forwarded to the 'linear' schedule only.

    Returns:
        The schedule instance matching the requested type.

    Raises:
        ValueError: If ``config.noise.type`` names no known schedule.
    """
    noise_cfg = config.noise
    kind = noise_cfg.type

    # Guard-clause dispatch: the first matching schedule name wins.
    if kind == 'geometric':
        return GeometricNoise(noise_cfg.sigma_min, noise_cfg.sigma_max)
    if kind == 'loglinear':
        return LogLinearNoise()
    if kind == 'cosine':
        return CosineNoise()
    if kind == 'cosinesqr':
        return CosineSqrNoise()
    if kind == 'linear':
        return Linear(noise_cfg.sigma_min, noise_cfg.sigma_max, dtype)

    raise ValueError(f'{config.noise.type} is not a valid noise')
def binary_discretization(z):
    """Return ``sign(z)`` in value while passing gradients through ``z / ||z||``.

    The result is the L2-normalised input (norm taken over the last
    dimension) plus a detached correction that moves the value onto
    ``torch.sign(z)``. Only the normalised term takes part in
    backpropagation.

    Args:
        z: Input tensor; the norm is taken over its last dimension.

    Returns:
        Tensor of the same shape as ``z``.
    """
    normalised = z / torch.norm(z, dim=-1, keepdim=True)
    signs = z.sign()
    # The correction is detached, so gradients only see `normalised`.
    correction = (signs - normalised).detach()
    return normalised + correction
class Noise(abc.ABC, nn.Module):
    """Abstract base class for noise schedules.

    Subclasses implement ``rate_noise`` and ``total_noise``. Calling the
    module (``forward``) returns both values for a timestep.
    """

    def forward(self, t):
        """Return the pair ``(total_noise(t), rate_noise(t))``.

        Args:
            t: Timestep(s), passed unchanged to both methods.

        Returns:
            Tuple ``(total, rate)``.
        """
        # Assume time goes from 0 to 1
        return self.total_noise(t), self.rate_noise(t)

    @abc.abstractmethod
    def rate_noise(self, t):
        """Rate of change of noise, i.e. g(t)."""

    @abc.abstractmethod
    def total_noise(self, t):
        # Raw docstring: `\int` in a normal string literal is an invalid
        # escape sequence and raises SyntaxWarning on Python 3.12+.
        r"""Total noise, i.e. \int_0^t g(t) dt + g(0)."""
class CosineNoise(Noise):
    """Cosine schedule: total noise is ``-log(eps + (1 - eps) * cos(pi*t/2))``.

    Args:
        eps: Small offset added to the cosine term, so the logarithm and the
            rate's denominator do not reach zero where the cosine vanishes.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        """Return the time derivative of ``total_noise`` at ``t``."""
        keep = 1 - self.eps
        angle = t * torch.pi / 2
        weighted_sin = keep * torch.sin(angle)
        weighted_cos = keep * torch.cos(angle)
        return torch.pi / 2 * weighted_sin / (weighted_cos + self.eps)

    def total_noise(self, t):
        """Return ``-log(eps + (1 - eps) * cos(pi * t / 2))``."""
        cosine = torch.cos(t * torch.pi / 2)
        inside_log = self.eps + (1 - self.eps) * cosine
        return -torch.log(inside_log)
class CosineSqrNoise(Noise):
    """Squared-cosine schedule.

    Total noise is ``-log(eps + (1 - eps) * cos(pi*t/2) ** 2)``.

    Args:
        eps: Small offset added to the squared-cosine term, so the logarithm
            and the rate's denominator do not reach zero where it vanishes.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        """Return the time derivative of ``total_noise`` at ``t``."""
        keep = 1 - self.eps
        weighted_cos_sq = keep * (torch.cos(t * torch.pi / 2) ** 2)
        # sin(pi * t) is 2 * sin(pi*t/2) * cos(pi*t/2), from differentiating cos^2.
        weighted_sin = keep * torch.sin(t * torch.pi)
        return torch.pi / 2 * weighted_sin / (weighted_cos_sq + self.eps)

    def total_noise(self, t):
        """Return ``-log(eps + (1 - eps) * cos(pi * t / 2) ** 2)``."""
        cos_sq = torch.cos(t * torch.pi / 2) ** 2
        inside_log = self.eps + (1 - self.eps) * cos_sq
        return -torch.log(inside_log)
class Linear(Noise):
    """Linear schedule: total noise is ``sigma_min + t * (sigma_max - sigma_min)``.

    Args:
        sigma_min: Total noise at ``t = 0``.
        sigma_max: Total noise at ``t = 1``.
        dtype: Dtype of the tensors that store the two bounds.
    """

    def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
        super().__init__()
        # Stored as plain tensor attributes (not registered buffers).
        self.sigma_min, self.sigma_max = (
            torch.tensor(bound, dtype=dtype)
            for bound in (sigma_min, sigma_max)
        )

    def rate_noise(self, t):
        """Return the constant slope ``sigma_max - sigma_min`` (``t`` is unused)."""
        slope = self.sigma_max - self.sigma_min
        return slope

    def total_noise(self, t):
        """Return the total noise at ``t`` by linear interpolation."""
        span = self.sigma_max - self.sigma_min
        return self.sigma_min + t * span

    def importance_sampling_transformation(self, t):
        """Warp ``t`` so that ``log(1 - exp(-sigma))`` is linear in the input.

        ``log1p(-exp(-sigma))`` is interpolated linearly between its values
        at ``sigma_min`` and ``sigma_max``, mapped back to a noise level, and
        that level is converted to a time by inverting ``total_noise``.
        """
        log_term_end = torch.log1p(-torch.exp(-self.sigma_max))
        log_term_start = torch.log1p(-torch.exp(-self.sigma_min))
        blended = t * log_term_end + (1 - t) * log_term_start
        sigma_t = -torch.log1p(-torch.exp(blended))
        span = self.sigma_max - self.sigma_min
        return (sigma_t - self.sigma_min) / span
class GeometricNoise(Noise):
    """Geometric schedule: total noise is ``sigma_min**(1 - t) * sigma_max**t``.

    Args:
        sigma_min: Total noise at ``t = 0``.
        sigma_max: Total noise at ``t = 1``.
    """

    def __init__(self, sigma_min=1e-3, sigma_max=1):
        super().__init__()
        # Multiplying by 1.0 yields a floating-point tensor even for int bounds.
        self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])

    def rate_noise(self, t):
        """Return the time derivative of ``total_noise`` at ``t``."""
        low, high = self.sigmas[0], self.sigmas[1]
        interpolated = low ** (1 - t) * high ** t
        log_ratio = high.log() - low.log()
        return interpolated * log_ratio

    def total_noise(self, t):
        """Return the geometric interpolation between the two bounds at ``t``."""
        low, high = self.sigmas[0], self.sigmas[1]
        return low ** (1 - t) * high ** t
# NOTE(review): project-local helper imported mid-file; its implementation
# is not visible here -- presumably it reports whether torch_xla can be used.
from decoupled_utils import is_torch_xla_available
# Evaluated once at import time; LogLinearNoise.total_noise reads this flag
# to choose between its float64 `log` path and its `log1p` path.
is_xla_available = is_torch_xla_available()
class LogLinearNoise(Noise):
    """Log-linear noise schedule.

    Total noise is ``-log(1 - (1 - eps) * t)``, so
    ``1 - exp(-total_noise(t))`` equals ``(1 - eps) * t`` and runs from 0 to
    roughly 1 as ``t`` goes from 0 to 1.

    Args:
        eps: Small constant that keeps the logarithm's argument above zero at
            ``t = 1``; it is also added to ``sigma_min``.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        t_end = torch.tensor(1.0, dtype=torch.float32)
        t_start = torch.tensor(0.0, dtype=torch.float32)
        self.sigma_max = self.total_noise(t_end)
        self.sigma_min = self.eps + self.total_noise(t_start)

    def rate_noise(self, t):
        """Return the time derivative of ``total_noise`` at ``t``."""
        keep = 1 - self.eps
        return keep / (1 - keep * t)

    def total_noise(self, t):
        """Return ``-log(1 - (1 - eps) * t)``."""
        if not is_xla_available:
            return -torch.log1p(-(1 - self.eps) * t)
        # XLA breaks here with large batch sizes...
        # so compute in float64 with a plain log and cast back afterwards.
        t_wide = t.to(torch.float64)
        return -torch.log(1 + (-(1 - self.eps) * t_wide)).to(t.dtype)

    def importance_sampling_transformation(self, t):
        """Warp ``t`` so that ``log(1 - exp(-sigma))`` is linear in the input.

        ``log1p(-exp(-sigma))`` is interpolated linearly between its values
        at ``sigma_min`` and ``sigma_max``, mapped back to a noise level, and
        that level is converted to a time by inverting ``total_noise``.
        """
        log_term_end = torch.log1p(-torch.exp(-self.sigma_max))
        log_term_start = torch.log1p(-torch.exp(-self.sigma_min))
        blended = t * log_term_end + (1 - t) * log_term_start
        sigma_t = -torch.log1p(-torch.exp(blended))
        return -torch.expm1(-sigma_t) / (1 - self.eps)
|