# deep_privacy2: dp2/layers/sg2_layers.py
from typing import List
import numpy as np
import torch
import tops
import torch.nn.functional as F
from sg3_torch_utils.ops import conv2d_resample
from sg3_torch_utils.ops import upfirdn2d
from sg3_torch_utils.ops import bias_act
from sg3_torch_utils.ops.fma import fma
class FullyConnectedLayer(torch.nn.Module):
def __init__(self,
in_features, # Number of input features.
out_features, # Number of output features.
bias=True, # Apply additive bias before the activation function?
activation='linear', # Activation function: 'relu', 'lrelu', etc.
lr_multiplier=1, # Learning rate multiplier.
bias_init=0, # Initial value for the additive bias.
):
super().__init__()
self.repr = dict(
in_features=in_features, out_features=out_features, bias=bias,
activation=activation, lr_multiplier=lr_multiplier, bias_init=bias_init)
self.activation = activation
        # Equalized learning rate: the weight is initialized at scale 1/lr_multiplier and
        # rescaled at runtime by weight_gain = lr_multiplier / sqrt(in_features).
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
self.weight_gain = lr_multiplier / np.sqrt(in_features)
self.bias_gain = lr_multiplier
self.in_features = in_features
self.out_features = out_features
def forward(self, x):
w = self.weight * self.weight_gain
b = self.bias
if b is not None and self.bias_gain != 1:
b = b * self.bias_gain
x = F.linear(x, w)
x = bias_act.bias_act(x, b, act=self.activation)
return x
def extra_repr(self) -> str:
return ", ".join([f"{key}={item}" for key, item in self.repr.items()])
class Conv2d(torch.nn.Module):
def __init__(self,
in_channels, # Number of input channels.
out_channels, # Number of output channels.
kernel_size=3, # Convolution kernel size.
up=1, # Integer upsampling factor.
        down=1, # Integer downsampling factor.
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
bias=True,
norm=False,
lr_multiplier=1,
bias_init=0,
w_dim=None,
gain=1,
):
super().__init__()
        assert norm in [True, False]
        if norm:
            # Parameter-free instance normalization applied to the conv output.
            self.norm = torch.nn.InstanceNorm2d(out_channels, affine=False)
self.up = up
self.down = down
self.activation = activation
self.conv_clamp = conv_clamp if conv_clamp is None else conv_clamp * gain
self.out_channels = out_channels
self.in_channels = in_channels
self.padding = kernel_size // 2
self.repr = dict(
in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, up=up, down=down,
activation=activation, resample_filter=resample_filter, conv_clamp=conv_clamp, bias=bias,
)
if self.up == 1 and self.down == 1:
self.resample_filter = None
else:
self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
self.act_gain = bias_act.activation_funcs[activation].def_gain * gain
self.weight_gain = lr_multiplier / np.sqrt(in_channels * (kernel_size ** 2))
self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]))
self.bias = torch.nn.Parameter(torch.zeros([out_channels]) + bias_init) if bias else None
self.bias_gain = lr_multiplier
if w_dim is not None:
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
self.affine_beta = FullyConnectedLayer(w_dim, in_channels, bias_init=0)
def forward(self, x, w=None, s=None):
tops.assert_shape(x, [None, self.weight.shape[1], None, None])
        if s is not None:
            # Pre-computed per-layer styles: the first 2*in_channels features hold (gamma, beta).
            s = s[..., :self.in_channels * 2]
            gamma, beta = s.view(-1, self.in_channels * 2, 1, 1).chunk(2, dim=1)
            x = fma(x, gamma, beta)
        elif hasattr(self, "affine"):
            # Styles computed from the latent w (only present when w_dim was given to __init__).
            gamma = self.affine(w).view(-1, self.in_channels, 1, 1)
            beta = self.affine_beta(w).view(-1, self.in_channels, 1, 1)
            x = fma(x, gamma, beta)
w = self.weight * self.weight_gain
        # flip_weight must stay tied to self.up == 1; removing it is not safe.
x = conv2d_resample.conv2d_resample(x, w, self.resample_filter, self.up,
self.down, self.padding, flip_weight=self.up == 1)
if hasattr(self, "norm"):
x = self.norm(x)
b = self.bias * self.bias_gain if self.bias is not None else None
x = bias_act.bias_act(x, b, act=self.activation, gain=self.act_gain, clamp=self.conv_clamp)
return x
def extra_repr(self) -> str:
return ", ".join([f"{key}={item}" for key, item in self.repr.items()])
class Block(torch.nn.Module):
def __init__(self,
        in_channels, # Number of input channels.
out_channels, # Number of output channels.
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
up=1,
down=1,
        **layer_kwargs, # Arguments for Conv2d.
):
super().__init__()
self.in_channels = in_channels
self.down = down
self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs)
self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, **layer_kwargs)
def forward(self, x, **layer_kwargs):
x = self.conv0(x, **layer_kwargs)
x = self.conv1(x, **layer_kwargs)
return x
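
# Usage sketch (illustrative; shapes are example assumptions). Block chains two
# Conv2d layers: `down` is applied in conv0 and `up` in conv1:
#
#   block = Block(in_channels=64, out_channels=128, down=2)
#   x = torch.randn(2, 64, 32, 32)
#   y = block(x)   # -> [2, 128, 16, 16]
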
class ResidualBlock(torch.nn.Module):
def __init__(self,
        in_channels, # Number of input channels.
out_channels, # Number of output channels.
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
up=1,
down=1,
gain_out=np.sqrt(0.5),
fix_residual: bool = False,
**layer_kwargs, # Arguments for conv layer.
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.down = down
self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs)
self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, gain=gain_out, **layer_kwargs)
self.skip = Conv2d(
in_channels, out_channels, kernel_size=1, bias=False, up=up, down=down,
activation="linear" if fix_residual else "lrelu",
gain=gain_out
)
self.gain_out = gain_out
def forward(self, x, w=None, s=None, **layer_kwargs):
y = self.skip(x)
s_ = next(s) if s is not None else None
x = self.conv0(x, w, s=s_, **layer_kwargs)
s_ = next(s) if s is not None else None
x = self.conv1(x, w, s=s_, **layer_kwargs)
x = y + x
return x
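
# Usage sketch (illustrative; shapes are example assumptions). The skip branch
# shares the up/down factors of the main branch. When `s` is given it must be an
# iterator yielding one style tensor per conv (>= 2*in_channels features each):
#
#   res = ResidualBlock(in_channels=64, out_channels=64, down=2)
#   x = torch.randn(2, 64, 32, 32)
#   y = res(x)     # -> [2, 64, 16, 16]
#
#   styles = iter([torch.randn(2, 128), torch.randn(2, 128)])
#   y = res(x, s=styles)
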
class MinibatchStdLayer(torch.nn.Module):
def __init__(self, group_size, num_channels=1):
super().__init__()
self.group_size = group_size
self.num_channels = num_channels
def forward(self, x):
N, C, H, W = x.shape
with tops.suppress_tracer_warnings(): # as_tensor results are registered as constants
G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N
F = self.num_channels
c = C // F
# [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
y = x.reshape(G, -1, F, c, H, W)
y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
y = y.mean(dim=[2, 3, 4]) # [nF] Take average over channels and pixels.
y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
return x
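
# Usage sketch (illustrative; shapes are example assumptions). The layer appends
# num_channels stddev maps, so channels grow from C to C + num_channels; the batch
# size should be divisible by group_size:
#
#   mbstd = MinibatchStdLayer(group_size=4, num_channels=1)
#   x = torch.randn(8, 64, 8, 8)
#   y = mbstd(x)   # -> [8, 65, 8, 8]
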
class DiscriminatorEpilogue(torch.nn.Module):
def __init__(self,
in_channels, # Number of input channels.
resolution: List[int], # Resolution of this block.
mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
):
super().__init__()
self.in_channels = in_channels
self.resolution = resolution
self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size,
num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
self.conv = Conv2d(
in_channels + mbstd_num_channels, in_channels,
kernel_size=3, activation=activation, conv_clamp=conv_clamp)
self.fc = FullyConnectedLayer(in_channels * resolution[0] * resolution[1], in_channels, activation=activation)
self.out = FullyConnectedLayer(in_channels, 1)
def forward(self, x):
tops.assert_shape(x, [None, self.in_channels, *self.resolution]) # [NCHW]
# Main layers.
if self.mbstd is not None:
x = self.mbstd(x)
x = self.conv(x)
x = self.fc(x.flatten(1))
x = self.out(x)
return x
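
# Usage sketch (illustrative; shapes are example assumptions). The epilogue maps a
# [N, C, H, W] feature map to one logit per sample; `resolution` must match the
# incoming spatial size exactly:
#
#   epilogue = DiscriminatorEpilogue(in_channels=64, resolution=[4, 4])
#   x = torch.randn(8, 64, 4, 4)
#   logits = epilogue(x)   # -> [8, 1]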