import math

from dataclasses import dataclass
from typing import Union, Tuple, Literal

import torch as T
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm

from utils import load_ckpt
from utils.interp import print_colored
from utils import si_module, get_activation

# Adapted from https://github.com/facebookresearch/AudioDec

def Conv1d1x1(in_channels, out_channels, bias=True):
    """1x1 convolution: a per-frame linear projection across channels."""
    return nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=bias)

class NonCausalConv1d(nn.Module):
    """1D noncausal convolution w/ 2-sided padding."""

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=-1,
            dilation=1,
            groups=1,
            bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        if padding < 0:
            # "Same"-style symmetric padding so stride-1 convs preserve length.
            padding = (kernel_size - 1) // 2 * dilation
        self.dilation = dilation
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Float tensor variable with the shape (B, C', T').
        """
        x = self.conv(x)
        return x

class NonCausalConvTranspose1d(nn.Module):
    """1D noncausal transpose convolution."""

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=-1,
            output_padding=-1,
            groups=1,
            bias=True,
    ):
        super().__init__()
        # Default padding gives exact stride-fold upsampling for the
        # kernel_size = 2 * stride configuration used in this file.
        if padding < 0:
            padding = (stride + 1) // 2
        if output_padding < 0:
            output_padding = 1 if stride % 2 else 0
        self.deconv = nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Float tensor variable with the shape (B, C', T').
        """
        x = self.deconv(x)
        return x

class CausalConv1d(NonCausalConv1d):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            dilation=1,
            groups=1,
            bias=True,
            pad_buffer=None,
    ):
        super(CausalConv1d, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.stride = stride
        self.pad_length = (kernel_size - 1) * dilation
        # Streaming state: caches the last pad_length input frames between
        # inference() calls (mirrors CausalConvTranspose1d below; inference()
        # and reset_buffer() are required by CausalResUnit and ResNetStack).
        if pad_buffer is None:
            pad_buffer = T.zeros(1, in_channels, self.pad_length)
        self.register_buffer("pad_buffer", pad_buffer)

    def forward(self, x):
        pad = nn.ConstantPad1d((self.pad_length, 0), 0.0)
        x = pad(x)
        return self.conv(x)

    def inference(self, x):
        # Streaming forward: left-pad with cached context instead of zeros,
        # then cache the trailing frames for the next chunk.
        x = T.cat((self.pad_buffer, x), -1)
        self.pad_buffer = x[:, :, -self.pad_length:]
        return self.conv(x)

    def reset_buffer(self):
        self.pad_buffer.zero_()
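
# Example (sketch): at stride 1, CausalConv1d preserves length and only looks backward.
# With kernel_size=7 and dilation=3, pad_length = (7 - 1) * 3 = 18, so
#   conv = CausalConv1d(1, 1, kernel_size=7, dilation=3)
#   y = conv(T.randn(1, 1, 100))
# gives y of shape (1, 1, 100), where y[..., t] depends only on x[..., :t + 1].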

class CausalConvTranspose1d(NonCausalConvTranspose1d):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            bias=True,
            pad_buffer=None,
    ):
        super(CausalConvTranspose1d, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            output_padding=0,
            bias=bias,
        )
        self.stride = stride
        self.pad_length = (math.ceil(kernel_size / stride) - 1)
        if pad_buffer is None:
            pad_buffer = T.zeros(1, in_channels, self.pad_length)
        self.register_buffer("pad_buffer", pad_buffer)

    def forward(self, x):
        pad = nn.ReplicationPad1d((self.pad_length, 0))
        x = pad(x)
        return self.deconv(x)[:, :, self.stride : -self.stride]

    def inference(self, x):
        x = T.cat((self.pad_buffer, x), -1)
        self.pad_buffer = x[:, :, -self.pad_length:]
        return self.deconv(x)[:, :, self.stride : -self.stride]

    def reset_buffer(self):
        self.pad_buffer.zero_()
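
# Note: the [:, :, stride:-stride] trim drops transposed-conv edge artifacts;
# inference() prepends pad_buffer (ceil(kernel_size / stride) - 1 cached input
# frames) in place of the replication pad, so chunked decoding is intended to
# match a single forward() pass over the concatenated signal.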

class NonCausalResUnit(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=7,
            dilation=1,
            bias=False,
    ):
        super().__init__()
        self.activation = nn.ELU()
        self.conv1 = NonCausalConv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=dilation,
            bias=bias,
        )
        self.conv2 = Conv1d1x1(out_channels, out_channels, bias)

    def forward(self, x):
        y = self.conv1(self.activation(x))
        y = self.conv2(self.activation(y))
        return x + y

class CausalResUnit(NonCausalResUnit):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=7,
            dilation=1,
            bias=False,
    ):
        super(CausalResUnit, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            bias=bias,
        )
        self.conv1 = CausalConv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=dilation,
            bias=bias,
        )

    def inference(self, x):
        y = self.conv1.inference(self.activation(x))
        y = self.conv2(self.activation(y))
        return x + y

class ResNetBlock(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 kernel_size=7,
                 dilations=(1, 3, 9),
                 bias=True,
                 mode='encoder',
                 ):
        super().__init__()
        assert mode in ('encoder', 'decoder'), f"Mode ({mode}) is not supported!"

        self.mode = mode
        self.stride = stride

        ConvUnit = CausalConv1d if mode == 'encoder' else CausalConvTranspose1d
        res_channels = in_channels if mode == 'encoder' else out_channels

        res_units = [CausalResUnit(
            res_channels,
            res_channels,
            kernel_size=kernel_size,
            dilation=dilation,
        ) for dilation in dilations]

        if in_channels == out_channels:
            if mode == 'encoder':
                self.pool = nn.AvgPool1d(kernel_size=stride, stride=stride)
            if mode == 'decoder':
                self.upsample = nn.Upsample(scale_factor=stride, mode='nearest')
            conv_unit = nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias=bias,
            ) if in_channels != out_channels else nn.Identity()
        else:
            conv_unit = ConvUnit(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(2 * stride),
                stride=stride,
                bias=bias,
            )

        if mode == 'encoder':
            if in_channels == out_channels:
                self.res_block = nn.Sequential(*res_units, self.pool, conv_unit)
            else:
                self.res_block = nn.Sequential(*res_units, conv_unit)
        elif mode == 'decoder':
            if in_channels == out_channels:
                self.res_block = nn.Sequential(self.upsample, conv_unit, *res_units)
            else:
                self.res_block = nn.Sequential(conv_unit, *res_units)

    def forward(self, x):
        out = x
        for unit in self.res_block:
            out = unit(out)
        return out

    def inference(self, x):
        # Streaming path: assumes every unit is causal (the in_channels != out_channels
        # branch above); AvgPool1d/Upsample/Identity units have no inference() method.
        for unit in self.res_block:
            x = unit.inference(x)
        return x
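
# Shape sketch: an encoder ResNetBlock(32, 64, stride=4) maps (B, 32, T) through
# three dilated CausalResUnits at (B, 32, T), then a strided CausalConv1d with
# kernel_size = 2 * stride = 8 down to (B, 64, T // 4) (for T divisible by 4).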

@si_module  # assumed (imported from utils): makes Config a dataclass, exposes it
            # as self.c, and lets a Config instance be called to build the module
class ResNetStack(nn.Module):
    """
    ResNet encoder or decoder stack. Channel ratios and strides are listed
    in order from the data/io layer toward the middle of the model.
    """
    class Config:
        input_channels: int = 1
        output_channels: int = 1
        encode_channels: int = 32
        decode_channel_multiplier: int = 1
        latent_dim: int = None
        kernel_size: int = 7
        bias: bool = True
        channel_ratios: Tuple[int, ...] = (2, 4, 8, 16)
        strides: Tuple[int, ...] = (3, 4, 5, 5)
        mode: Literal['encoder', 'decoder'] = 'encoder'

    def __init__(self, c: Config):
        super().__init__()
        assert c.mode in ('encoder', 'decoder'), f"Mode ({c.mode}) is not supported!"
        self.mode = c.mode

        assert len(c.channel_ratios) == len(c.strides)
        channel_ratios = (1,) + c.channel_ratios
        strides = c.strides
        self.middle_channels = c.encode_channels * channel_ratios[-1]

        if c.mode == 'decoder':
            channel_ratios = tuple(reversed(channel_ratios))
            strides = tuple(reversed(strides))

        self.multiplier = c.decode_channel_multiplier if c.mode == 'decoder' else 1
        res_blocks = [ResNetBlock(
            c.encode_channels * channel_ratios[s_idx] * self.multiplier,
            c.encode_channels * channel_ratios[s_idx + 1] * self.multiplier,
            stride,
            kernel_size=c.kernel_size,
            bias=c.bias,
            mode=c.mode,
        ) for s_idx, stride in enumerate(strides)]

        data_conv = CausalConv1d(
            in_channels=c.input_channels if c.mode == 'encoder' else c.encode_channels * self.multiplier,
            out_channels=c.encode_channels if c.mode == 'encoder' else c.output_channels,
            kernel_size=c.kernel_size,
            stride=1,
            bias=False,
        )

        if c.mode == 'encoder':
            self.res_stack = nn.Sequential(data_conv, *res_blocks)
        elif c.mode == 'decoder':
            self.res_stack = nn.Sequential(*res_blocks, data_conv)

        if c.latent_dim is not None:
            if c.mode == 'encoder':
                self.latent_proj = Conv1d1x1(self.middle_channels, c.latent_dim, bias=c.bias)
            else:
                self.latent_proj = Conv1d1x1(c.latent_dim, self.middle_channels, bias=c.bias)

        if self.multiplier != 1:
            self.multiplier_proj = Conv1d1x1(self.middle_channels, self.middle_channels * self.multiplier, bias=c.bias)

    def forward(self, x, return_feats=False):
        if self.c.latent_dim is not None and self.mode == 'decoder':
            x = self.latent_proj(x)
        if self.multiplier != 1:
            x = self.multiplier_proj(x)

        feats = []
        for block in self.res_stack:
            x = block(x)
            if return_feats:
                feats.append(x)

        if self.c.latent_dim is not None and self.mode == 'encoder':
            x = self.latent_proj(x)
            if return_feats:
                feats.append(x)

        if return_feats:
            return feats
        return x

    def inference(self, x):
        for block in self.res_stack:
            x = block.inference(x)
        return x

    def reset_buffer(self):
        def _reset_buffer(m):
            if isinstance(m, (CausalConv1d, CausalConvTranspose1d)):
                m.reset_buffer()
        self.apply(_reset_buffer)

    def reset_parameters(self):
        def _reset_parameters(m):
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
                m.weight.data.normal_(0.0, 0.01)
        self.apply(_reset_parameters)

    def apply_weight_norm(self):
        def _apply_weight_norm(m):
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
                weight_norm(m)
        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        def _remove_weight_norm(m):
            try:
                # Parametrized weight norm (applied above) is removed via
                # parametrize, not the legacy nn.utils.remove_weight_norm.
                nn.utils.parametrize.remove_parametrizations(m, "weight")
            except ValueError:  # this module didn't have weight norm
                return
        self.apply(_remove_weight_norm)
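
# Usage sketch (assumes the si_module Config wiring from utils): the default encoder
# downsamples by prod(strides) = 3 * 4 * 5 * 5 = 300 and widens channels by
# channel_ratios[-1] = 16x:
#   enc = ResNetStack.Config(mode='encoder')()
#   y = enc(T.randn(1, 1, 3000))  # -> (1, 32 * 16, 3000 // 300) = (1, 512, 10)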

@si_module
class GaussianZ(nn.Module):
    class Config:
        dim: int
        latent_dim: int
        bias: bool = False
        use_weight_norm: bool = False

    def __init__(self, c: Config):
        super().__init__()
        self.proj_in = nn.Linear(c.dim, c.latent_dim * 2, bias=c.bias)
        self.proj_out = nn.Linear(c.latent_dim, c.dim, bias=c.bias)

        if c.use_weight_norm:
            self.proj_in = weight_norm(self.proj_in)
            self.proj_out = weight_norm(self.proj_out)

    def reparam(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
        std = T.exp(logvar / 2)
        eps = T.randn_like(std)
        return mu + eps * std

    def kl_divergence(self, mu, logvar):
        return T.mean(-0.5 * T.sum(
            1 + logvar - mu.pow(2) - logvar.exp(),
            dim=(1, 2))
        )

    def repr_from_latent(self, latent: Union[dict, T.Tensor]):
        if isinstance(latent, T.Tensor):
            z = latent
        else:
            z = self.reparam(latent['mu'], latent['logvar'])
        l = self.proj_out(z)
        return l

    def forward(self, x: T.Tensor) -> Tuple[T.Tensor, dict]:
        mu, logvar = self.proj_in(x).chunk(2, dim=-1)
        kl_div = self.kl_divergence(mu, logvar)
        z = self.reparam(mu, logvar)
        xhat = self.proj_out(z)
        latent = {'mu': mu, 'logvar': logvar, 'z': z, 'kl_divergence': kl_div}
        return xhat, latent
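
# The KL term above is the closed form for a diagonal Gaussian against the standard
# normal prior: KL(N(mu, sigma^2) || N(0, I)) = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2),
# summed over the time and latent axes (dim=(1, 2)) and averaged over the batch.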

@si_module
class WaveCodec(nn.Module):
    class Config:
        resnet_config: ResNetStack.Config = None
        sample_rate: int = 16_000
        use_weight_norm: bool = False

        compressor_config: dataclass = None

        norm_stddev: float = 1.0

    def __init__(self, c: Config):
        super().__init__()
        self.norm_stddev = c.norm_stddev
        self.encoder = c.resnet_config(mode='encoder')
        self.sample_rate = c.sample_rate

        self.total_stride = 1
        for stride in c.resnet_config.strides:
            self.total_stride *= stride
        self.tokens_per_second = self.sample_rate / self.total_stride

        self.compressor = c.compressor_config(dim=self.encoder.middle_channels)
        self.decoder = c.resnet_config(mode='decoder')

        if c.use_weight_norm:
            self.encoder.apply_weight_norm()
            self.decoder.apply_weight_norm()
            self.encoder.reset_parameters()
            self.decoder.reset_parameters()

    def encode(self, data):
        return self.encoder(data / self.norm_stddev)

    def decode(self, latent):
        return self.decoder(latent.transpose(1, 2)) * self.norm_stddev

    def latent_from_data(self, data, get_parameters=False):
        x = self.encode(data)
        l_in = x.transpose(1, 2)
        l, latent = self.compressor(l_in)
        return latent['z'] if not get_parameters else {
            'mu': latent['mu'],
            'logvar': latent['logvar'],
            'z': latent['z'],
        }

    def data_from_latent(self, latent):
        l = self.compressor.repr_from_latent(latent)
        x = self.decode(l)
        return x

    def process(self, x):
        return self.latent_from_data(x)

    def unprocess(self, latent):
        return self.data_from_latent(latent)

    def forward(self, audio_input):
        x = self.encode(audio_input)
        l_in = x.transpose(1, 2)
        l, latent = self.compressor(l_in)
        xhat = self.decode(l)
        return xhat, latent
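
# Round trip: (B, 1, T) audio is scaled by 1 / norm_stddev, encoded to
# (B, middle_channels, T // total_stride), squeezed through the GaussianZ
# bottleneck along the channel axis, then decoded back to (B, 1, T) and rescaled.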

def make_tokenizer(device='cuda'):
    generator_config = WaveCodec.Config(
        resnet_config=ResNetStack.Config(
            input_channels=1,
            output_channels=1,
            encode_channels=16,
            decode_channel_multiplier=4,
            kernel_size=7,
            bias=True,
            channel_ratios=(4, 8, 16, 16, 16, 16),
            strides=(2, 2, 4, 5, 5, 5),
            mode=None,
        ),
        use_weight_norm=True,
        compressor_config=GaussianZ.Config(
            dim=None,
            latent_dim=32,
            bias=True,
            use_weight_norm=True,
        ),
        norm_stddev=0.05,
    )
    checkpoint = load_ckpt("inference_apatosaurus_95000", expected_hash="ba876edb97b988e9196e449dd176ca97")

    tokenizer = generator_config()
    load_result = tokenizer.load_state_dict(checkpoint, strict=False)
    print_colored(f"Loaded tokenizer state dict: {load_result}", "grey")

    tokenizer = tokenizer.eval()
    # Only convert to bfloat16 if using CUDA
    if device == 'cuda':
        tokenizer = tokenizer.bfloat16()
    tokenizer = tokenizer.to(device)
    tokenizer.requires_grad_(False)
    return tokenizer
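

if __name__ == "__main__":
    # Smoke test (sketch): assumes the "inference_apatosaurus_95000" checkpoint is
    # reachable through utils.load_ckpt; runs a CPU round trip over one second of noise.
    tokenizer = make_tokenizer(device='cpu')
    audio = T.randn(1, 1, 16_000)  # total_stride = 2*2*4*5*5*5 = 2000 -> 8 latent frames
    with T.no_grad():
        latent = tokenizer.latent_from_data(audio)  # (1, 8, 32)
        recon = tokenizer.data_from_latent(latent)  # (1, 1, 16000)
    print_colored(f"latent {tuple(latent.shape)} -> recon {tuple(recon.shape)}", "grey")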