from typing import List

import numpy as np
import torch
import torch.nn.functional as F
import tops
from sg3_torch_utils.ops import conv2d_resample
from sg3_torch_utils.ops import upfirdn2d
from sg3_torch_utils.ops import bias_act
from sg3_torch_utils.ops.fma import fma


class FullyConnectedLayer(torch.nn.Module):

    def __init__(self,
                 in_features,            # Number of input features.
                 out_features,           # Number of output features.
                 bias=True,              # Apply additive bias before the activation function?
                 activation='linear',    # Activation function: 'relu', 'lrelu', etc.
                 lr_multiplier=1,        # Learning rate multiplier.
                 bias_init=0,            # Initial value for the additive bias.
                 ):
        super().__init__()
        self.repr = dict(
            in_features=in_features, out_features=out_features, bias=bias,
            activation=activation, lr_multiplier=lr_multiplier, bias_init=bias_init)
        self.activation = activation
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
        self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, x):
        w = self.weight * self.weight_gain
        b = self.bias
        if b is not None and self.bias_gain != 1:
            b = b * self.bias_gain
        x = F.linear(x, w)
        x = bias_act.bias_act(x, b, act=self.activation)
        return x

    def extra_repr(self) -> str:
        return ", ".join([f"{key}={item}" for key, item in self.repr.items()])


class Conv2d(torch.nn.Module):

    def __init__(self,
                 in_channels,                    # Number of input channels.
                 out_channels,                   # Number of output channels.
                 kernel_size=3,                  # Convolution kernel size.
                 up=1,                           # Integer upsampling factor.
                 down=1,                         # Integer downsampling factor.
                 activation='lrelu',             # Activation function: 'relu', 'lrelu', etc.
                 resample_filter=[1, 3, 3, 1],   # Low-pass filter to apply when resampling activations.
                 conv_clamp=None,                # Clamp the output of convolution layers to +-X, None = disable clamping.
                 bias=True,
                 norm=False,
                 lr_multiplier=1,
                 bias_init=0,
                 w_dim=None,
                 gain=1,
                 ):
        super().__init__()
        assert norm in [True, False]
        if norm:
            # Parameter-free instance normalization; num_features is not needed since affine=False.
            self.norm = torch.nn.InstanceNorm2d(None)
        self.up = up
        self.down = down
        self.activation = activation
        self.conv_clamp = conv_clamp if conv_clamp is None else conv_clamp * gain
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.padding = kernel_size // 2
        self.repr = dict(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=kernel_size, up=up, down=down,
            activation=activation, resample_filter=resample_filter, conv_clamp=conv_clamp, bias=bias,
        )
        if self.up == 1 and self.down == 1:
            self.resample_filter = None
        else:
            self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.act_gain = bias_act.activation_funcs[activation].def_gain * gain
        self.weight_gain = lr_multiplier / np.sqrt(in_channels * (kernel_size ** 2))
        self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]) + bias_init) if bias else None
        self.bias_gain = lr_multiplier
        if w_dim is not None:
            self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
            self.affine_beta = FullyConnectedLayer(w_dim, in_channels, bias_init=0)

    def forward(self, x, w=None, s=None):
        tops.assert_shape(x, [None, self.weight.shape[1], None, None])
        # Per-channel modulation: either from a precomputed style vector s, or from the
        # latent w through the affine layers constructed when w_dim is given.
        if s is not None:
            s = s[..., :self.in_channels * 2]
            gamma, beta = s.view(-1, self.in_channels * 2, 1, 1).chunk(2, dim=1)
            x = fma(x, gamma, beta)
        elif hasattr(self, "affine"):
            gamma = self.affine(w).view(-1, self.in_channels, 1, 1)
            beta = self.affine_beta(w).view(-1, self.in_channels, 1, 1)
            x = fma(x, gamma, beta)
        w = self.weight * self.weight_gain
        # Removing flip_weight is not safe.
        x = conv2d_resample.conv2d_resample(x, w, self.resample_filter, self.up,
                                            self.down, self.padding, flip_weight=self.up == 1)
        if hasattr(self, "norm"):
            x = self.norm(x)
        b = self.bias * self.bias_gain if self.bias is not None else None
        x = bias_act.bias_act(x, b, act=self.activation, gain=self.act_gain, clamp=self.conv_clamp)
        return x

    def extra_repr(self) -> str:
        return ", ".join([f"{key}={item}" for key, item in self.repr.items()])


class Block(torch.nn.Module):

    def __init__(self,
                 in_channels,      # Number of input channels, 0 = first block.
                 out_channels,     # Number of output channels.
                 conv_clamp=None,  # Clamp the output of convolution layers to +-X, None = disable clamping.
                 up=1,
                 down=1,
                 **layer_kwargs,   # Arguments for SynthesisLayer.
                 ):
        super().__init__()
        self.in_channels = in_channels
        self.down = down
        # Two sequential conv layers: conv0 optionally downsamples, conv1 optionally upsamples.
        self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs)
        self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, **layer_kwargs)

    def forward(self, x, **layer_kwargs):
        x = self.conv0(x, **layer_kwargs)
        x = self.conv1(x, **layer_kwargs)
        return x


class ResidualBlock(torch.nn.Module):

    def __init__(self,
                 in_channels,                # Number of input channels, 0 = first block.
                 out_channels,               # Number of output channels.
                 conv_clamp=None,            # Clamp the output of convolution layers to +-X, None = disable clamping.
                 up=1,
                 down=1,
                 gain_out=np.sqrt(0.5),
                 fix_residual: bool = False,
                 **layer_kwargs,             # Arguments for conv layer.
                 ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.down = down
        self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs)
        self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, gain=gain_out, **layer_kwargs)
        # 1x1 skip connection, resampled to match the main branch and scaled by gain_out.
        self.skip = Conv2d(
            in_channels, out_channels, kernel_size=1, bias=False, up=up, down=down,
            activation="linear" if fix_residual else "lrelu",
            gain=gain_out
        )
        self.gain_out = gain_out

    def forward(self, x, w=None, s=None, **layer_kwargs):
        y = self.skip(x)
        # s is expected to be an iterator over per-layer style vectors; consume one per conv.
        s_ = next(s) if s is not None else None
        x = self.conv0(x, w, s=s_, **layer_kwargs)
        s_ = next(s) if s is not None else None
        x = self.conv1(x, w, s=s_, **layer_kwargs)
        x = y + x
        return x


class MinibatchStdLayer(torch.nn.Module):

    def __init__(self, group_size, num_channels=1):
        super().__init__()
        self.group_size = group_size
        self.num_channels = num_channels

    def forward(self, x):
        N, C, H, W = x.shape
        with tops.suppress_tracer_warnings():  # as_tensor results are registered as constants
            G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N
        F = self.num_channels
        c = C // F
        # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
        y = x.reshape(G, -1, F, c, H, W)
        y = y - y.mean(dim=0)        # [GnFcHW] Subtract mean over group.
        y = y.square().mean(dim=0)   # [nFcHW]  Calc variance over group.
        y = (y + 1e-8).sqrt()        # [nFcHW]  Calc stddev over group.
        y = y.mean(dim=[2, 3, 4])    # [nF]     Take average over channels and pixels.
        y = y.reshape(-1, F, 1, 1)   # [nF11]   Add missing dimensions.
        y = y.repeat(G, 1, H, W)     # [NFHW]   Replicate over group and pixels.
        x = torch.cat([x, y], dim=1) # [NCHW]   Append to input as new channels.
        return x
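
# Shape note (added for clarity): for an input of shape (N, C, H, W) the layer returns
# (N, C + num_channels, H, W); the appended channels hold the per-group feature stddev,
# replicated over all pixels, so the discriminator can see minibatch-level statistics.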


class DiscriminatorEpilogue(torch.nn.Module):

    def __init__(self,
                 in_channels,             # Number of input channels.
                 resolution: List[int],   # Resolution of this block.
                 mbstd_group_size=4,      # Group size for the minibatch standard deviation layer, None = entire minibatch.
                 mbstd_num_channels=1,    # Number of features for the minibatch standard deviation layer, 0 = disable.
                 activation='lrelu',      # Activation function: 'relu', 'lrelu', etc.
                 conv_clamp=None,         # Clamp the output of convolution layers to +-X, None = disable clamping.
                 ):
        super().__init__()
        self.in_channels = in_channels
        self.resolution = resolution
        self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size,
                                       num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
        self.conv = Conv2d(
            in_channels + mbstd_num_channels, in_channels,
            kernel_size=3, activation=activation, conv_clamp=conv_clamp)
        self.fc = FullyConnectedLayer(in_channels * resolution[0] * resolution[1], in_channels, activation=activation)
        self.out = FullyConnectedLayer(in_channels, 1)

    def forward(self, x):
        tops.assert_shape(x, [None, self.in_channels, *self.resolution])  # [NCHW]
        # Main layers.
        if self.mbstd is not None:
            x = self.mbstd(x)
        x = self.conv(x)
        x = self.fc(x.flatten(1))
        x = self.out(x)
        return x
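

# Minimal smoke test, added as a sketch. Assumptions: the sg3_torch_utils ops fall back to
# their reference (pure PyTorch) implementations on CPU, and tops.assert_shape /
# tops.suppress_tracer_warnings are available as used elsewhere in this module.
if __name__ == "__main__":
    torch.manual_seed(0)
    epilogue = DiscriminatorEpilogue(in_channels=32, resolution=[4, 4])
    scores = epilogue(torch.randn(8, 32, 4, 4))
    print(scores.shape)   # expected: torch.Size([8, 1])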