import math
from dataclasses import dataclass
from typing import Union, Tuple, Literal

import torch as T
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import weight_norm

from utils import load_ckpt
from utils.interp import print_colored
from utils import si_module, get_activation

# Adapted from https://github.com/facebookresearch/AudioDec


def Conv1d1x1(in_channels, out_channels, bias=True):
    """Return a pointwise (kernel_size=1) ``nn.Conv1d``."""
    return nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=bias)


class NonCausalConv1d(nn.Module):
    """1D noncausal convolution w/ 2-sides padding.

    A negative ``padding`` selects ``(kernel_size - 1) // 2 * dilation``, which
    keeps the time length unchanged for odd kernel sizes at stride 1.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=-1,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        if padding < 0:
            padding = (kernel_size - 1) // 2 * dilation
        self.dilation = dilation
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).

        Returns:
            Tensor: Float tensor variable with the shape (B, C, T).
        """
        return self.conv(x)


class NonCausalConvTranspose1d(nn.Module):
    """1D noncausal transpose convolution.

    Negative ``padding`` / ``output_padding`` select the defaults
    ``(stride + 1) // 2`` and ``1 if stride is odd else 0`` respectively.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=-1,
        output_padding=-1,
        groups=1,
        bias=True,
    ):
        super().__init__()
        if padding < 0:
            padding = (stride + 1) // 2
        if output_padding < 0:
            output_padding = 1 if stride % 2 else 0
        self.deconv = nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).

        Returns:
            Tensor: Float tensor variable with the shape (B, C', T').
        """
        return self.deconv(x)


class CausalConv1d(NonCausalConv1d):
    """1D causal convolution w/ left-side (past-only) zero padding.

    ``forward`` zero-pads ``pad_length = (kernel_size - 1) * dilation`` samples
    on the left of the whole sequence. ``inference`` processes a stream chunk
    by chunk, carrying the last ``pad_length`` input samples between calls in
    ``pad_buffer`` (zeros initially and after ``reset_buffer``). When every
    chunk length is a multiple of ``stride``, the concatenated chunk outputs
    equal ``forward`` on the full sequence.

    The default ``pad_buffer`` has batch size 1, so ``inference`` expects
    batch-1 chunks unless a matching ``pad_buffer`` is supplied.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
        pad_buffer=None,
    ):
        super(CausalConv1d, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.stride = stride
        self.pad_length = (kernel_size - 1) * dilation
        if pad_buffer is None:
            pad_buffer = T.zeros(1, in_channels, self.pad_length)
        # Non-persistent: streaming state is not part of the state_dict, so
        # checkpoints saved without this buffer keep loading unchanged.
        self.register_buffer("pad_buffer", pad_buffer, persistent=False)

    def forward(self, x):
        # Same call nn.ConstantPad1d((pad_length, 0), 0.0) makes internally.
        return self.conv(F.pad(x, (self.pad_length, 0)))

    def inference(self, x):
        """Streaming forward: prepend the cached left context, then convolve."""
        if self.pad_length > 0:
            x = T.cat((self.pad_buffer, x), -1)
            # Guarded above: a slice with -0 would keep the whole tensor.
            self.pad_buffer = x[:, :, -self.pad_length:]
        return self.conv(x)

    def reset_buffer(self):
        """Clear the streaming context (start a new stream)."""
        self.pad_buffer.zero_()


class CausalConvTranspose1d(NonCausalConvTranspose1d):
    """Causal 1D transpose convolution.

    ``forward`` replicate-pads ``pad_length = ceil(kernel_size / stride) - 1``
    samples on the left and trims ``stride`` samples from both ends of the
    transposed-conv output. ``inference`` prepends ``pad_buffer`` instead
    (zeros initially), so the first streamed chunk can differ from
    ``forward``, which replicates the first sample.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        bias=True,
        pad_buffer=None,
    ):
        super(CausalConvTranspose1d, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            output_padding=0,
            bias=bias,
        )
        self.stride = stride
        self.pad_length = math.ceil(kernel_size / stride) - 1
        if pad_buffer is None:
            pad_buffer = T.zeros(1, in_channels, self.pad_length)
        self.register_buffer("pad_buffer", pad_buffer)

    def forward(self, x):
        # Same call nn.ReplicationPad1d((pad_length, 0)) makes internally.
        x = F.pad(x, (self.pad_length, 0), mode="replicate")
        return self.deconv(x)[:, :, self.stride : -self.stride]

    def inference(self, x):
        """Streaming forward: prepend the cached left context, then deconvolve."""
        if self.pad_length > 0:
            x = T.cat((self.pad_buffer, x), -1)
            self.pad_buffer = x[:, :, -self.pad_length:]
        return self.deconv(x)[:, :, self.stride : -self.stride]

    def reset_buffer(self):
        """Clear the streaming context (start a new stream)."""
        self.pad_buffer.zero_()


class NonCausalResUnit(nn.Module):
    """Residual unit: ``x + conv2(elu(conv1(elu(x))))``.

    ``conv2`` is pointwise, so the residual add requires
    ``in_channels == out_channels``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=7,
        dilation=1,
        bias=False,
    ):
        super().__init__()
        self.activation = nn.ELU()
        self.conv1 = NonCausalConv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=dilation,
            bias=bias,
        )
        self.conv2 = Conv1d1x1(out_channels, out_channels, bias)

    def forward(self, x):
        y = self.conv1(self.activation(x))
        y = self.conv2(self.activation(y))
        return x + y


class CausalResUnit(NonCausalResUnit):
    """Residual unit whose first convolution is causal (supports streaming)."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=7,
        dilation=1,
        bias=False,
    ):
        super(CausalResUnit, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            bias=bias,
        )
        # Replaces the non-causal conv1 built by the parent; the attribute name
        # (and thus the state_dict key) stays "conv1".
        self.conv1 = CausalConv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=dilation,
            bias=bias,
        )

    def inference(self, x):
        y = self.conv1.inference(self.activation(x))
        y = self.conv2(self.activation(y))
        return x + y


class ResNetBlock(nn.Module):
    """One resolution stage of the encoder or decoder.

    encoder: residual units, then downsampling. Downsampling uses AvgPool1d
    when channels are unchanged, otherwise a strided CausalConv1d with
    kernel 2*stride.
    decoder: upsampling, then residual units. Upsampling uses nearest Upsample
    when channels are unchanged, otherwise a CausalConvTranspose1d with
    kernel 2*stride.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        kernel_size=7,
        dilations=(1, 3, 9),
        bias=True,
        mode='encoder',
    ):
        super().__init__()
        assert mode in ('encoder', 'decoder'), f"Mode ({mode}) is not supported!"
        self.mode = mode
        self.stride = stride
        ConvUnit = CausalConv1d if mode == 'encoder' else CausalConvTranspose1d
        # Residual units run at the input width (encoder) or output width (decoder).
        res_channels = in_channels if mode == 'encoder' else out_channels
        res_units = [
            CausalResUnit(
                res_channels,
                res_channels,
                kernel_size=kernel_size,
                dilation=dilation,
            )
            for dilation in dilations
        ]
        if in_channels == out_channels:
            if mode == 'encoder':
                self.pool = nn.AvgPool1d(kernel_size=stride, stride=stride)
            if mode == 'decoder':
                self.upsample = nn.Upsample(scale_factor=stride, mode='nearest')
            conv_unit = nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias=bias,
            ) if in_channels != out_channels else nn.Identity()
        else:
            conv_unit = ConvUnit(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(2 * stride),
                stride=stride,
                bias=bias,
            )
        if mode == 'encoder':
            if in_channels == out_channels:
                self.res_block = nn.Sequential(*res_units, self.pool, conv_unit)
            else:
                self.res_block = nn.Sequential(*res_units, conv_unit)
        elif mode == 'decoder':
            if in_channels == out_channels:
                self.res_block = nn.Sequential(self.upsample, conv_unit, *res_units)
            else:
                self.res_block = nn.Sequential(conv_unit, *res_units)

    def forward(self, x):
        out = x
        for unit in self.res_block:
            out = unit(out)
        return out

    def inference(self, x):
        """Streaming forward.

        Stateful units (causal convs / res units) use their ``inference``.
        Stateless units (AvgPool1d, Upsample, Identity, 1x1 Conv1d) have no
        ``inference`` method and are simply called.
        """
        for unit in self.res_block:
            step = getattr(unit, "inference", None)
            x = step(x) if step is not None else unit(x)
        return x


@si_module
class ResNetStack(nn.Module):
    """
    ResNet encoder or decoder stack.

    ``channel_ratios`` and ``strides`` are listed from the data/IO layer toward
    the middle of the model; the decoder uses them in reverse order.
    """
    class Config:
        input_channels: int = 1
        output_channels: int = 1
        encode_channels: int = 32
        decode_channel_multiplier: int = 1
        latent_dim: int = None  # None disables latent_proj
        kernel_size: int = 7
        bias: bool = True
        channel_ratios: Tuple[int, ...] = (2, 4, 8, 16)
        strides: Tuple[int, ...] = (3, 4, 5, 5)
        mode: Literal['encoder', 'decoder'] = 'encoder'

    def __init__(self, c: Config):
        super().__init__()
        assert c.mode in ('encoder', 'decoder'), f"Mode ({c.mode}) is not supported!"
        self.mode = c.mode

        assert len(c.channel_ratios) == len(c.strides)
        # tuple(...) so list-valued configs work with the tuple concatenation.
        channel_ratios = (1,) + tuple(c.channel_ratios)
        strides = tuple(c.strides)
        self.middle_channels = c.encode_channels * channel_ratios[-1]
        if c.mode == 'decoder':
            channel_ratios = tuple(reversed(channel_ratios))
            strides = tuple(reversed(strides))

        # The decoder's internal width is scaled by decode_channel_multiplier.
        self.multiplier = c.decode_channel_multiplier if c.mode == 'decoder' else 1
        res_blocks = [
            ResNetBlock(
                c.encode_channels * channel_ratios[s_idx] * self.multiplier,
                c.encode_channels * channel_ratios[s_idx + 1] * self.multiplier,
                stride,
                kernel_size=c.kernel_size,
                bias=c.bias,
                mode=c.mode,
            )
            for s_idx, stride in enumerate(strides)
        ]

        data_conv = CausalConv1d(
            in_channels=c.input_channels if c.mode == 'encoder' else c.encode_channels * self.multiplier,
            out_channels=c.encode_channels if c.mode == 'encoder' else c.output_channels,
            kernel_size=c.kernel_size,
            stride=1,
            bias=False,
        )

        if c.mode == 'encoder':
            self.res_stack = nn.Sequential(data_conv, *res_blocks)
        elif c.mode == 'decoder':
            self.res_stack = nn.Sequential(*res_blocks, data_conv)

        if c.latent_dim is not None:
            self.latent_proj = (
                Conv1d1x1(self.middle_channels, c.latent_dim, bias=c.bias)
                if c.mode == 'encoder'
                else Conv1d1x1(c.latent_dim, self.middle_channels, bias=c.bias)
            )

        if self.multiplier != 1:
            # Lifts the middle_channels input to the multiplied decoder width.
            self.multiplier_proj = Conv1d1x1(
                self.middle_channels, self.middle_channels * self.multiplier, bias=c.bias
            )

    def forward(self, x, return_feats=False):
        """Run the stack.

        Returns the final tensor. When ``return_feats`` is True, it instead
        returns the list of every block's output (plus the projected latent
        for an encoder with ``latent_dim``).
        """
        if self.c.latent_dim is not None and self.mode == 'decoder':
            x = self.latent_proj(x)
        if self.multiplier != 1:
            x = self.multiplier_proj(x)

        feats = []
        for block in self.res_stack:
            x = block(x)
            if return_feats:
                feats.append(x)

        if self.c.latent_dim is not None and self.mode == 'encoder':
            x = self.latent_proj(x)
            if return_feats:
                feats.append(x)

        if return_feats:
            return feats
        return x

    def inference(self, x):
        """Streaming counterpart of ``forward`` (without ``return_feats``).

        Applies the same input/output projections as ``forward`` so that the
        channel counts match.
        """
        if self.c.latent_dim is not None and self.mode == 'decoder':
            x = self.latent_proj(x)
        if self.multiplier != 1:
            x = self.multiplier_proj(x)
        for block in self.res_stack:
            x = block.inference(x)
        if self.c.latent_dim is not None and self.mode == 'encoder':
            x = self.latent_proj(x)
        return x

    def reset_buffer(self):
        """Zero the streaming buffers of every causal (transpose) convolution."""
        def _reset_buffer(m):
            if isinstance(m, (CausalConv1d, CausalConvTranspose1d)):
                m.reset_buffer()
        self.apply(_reset_buffer)

    def reset_parameters(self):
        """Re-draw every Conv1d / ConvTranspose1d weight from N(0, 0.01^2).

        This works whether or not weight norm has been applied. For a
        parametrized weight, ``m.weight`` is recomputed on each access, so
        writing into ``m.weight.data`` would be discarded. Assigning to
        ``m.weight`` instead stores the value through the parametrization's
        ``right_inverse``.
        """
        def _reset_parameters(m):
            if not isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
                return
            if parametrize.is_parametrized(m, "weight"):
                with T.no_grad():
                    m.weight = T.empty_like(m.weight).normal_(0.0, 0.01)
            else:
                m.weight.data.normal_(0.0, 0.01)
        self.apply(_reset_parameters)

    def apply_weight_norm(self):
        """Apply (parametrization-based) weight norm to every Conv1d / ConvTranspose1d."""
        def _apply_weight_norm(m):
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
                weight_norm(m)
        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Remove the weight parametrization added by ``apply_weight_norm``.

        The current effective weight is kept as a plain parameter. Modules
        without a ``weight`` parametrization are left untouched. (The legacy
        ``nn.utils.remove_weight_norm`` cannot undo the parametrizations API.)
        """
        def _remove_weight_norm(m):
            if parametrize.is_parametrized(m, "weight"):
                parametrize.remove_parametrizations(m, "weight")
        self.apply(_remove_weight_norm)


@si_module
class GaussianZ(nn.Module):
    """Diagonal-Gaussian bottleneck over the last dimension.

    ``proj_in`` maps ``dim`` to ``(mu, logvar)`` (each ``latent_dim`` wide).
    A sample ``z`` is drawn with the reparameterization trick, and
    ``proj_out`` maps ``z`` back to ``dim``.
    """
    class Config:
        dim: int
        latent_dim: int
        bias: bool = False
        use_weight_norm: bool = False

    def __init__(self, c: Config):
        super().__init__()
        self.proj_in = nn.Linear(c.dim, c.latent_dim * 2, bias=c.bias)
        self.proj_out = nn.Linear(c.latent_dim, c.dim, bias=c.bias)
        if c.use_weight_norm:
            self.proj_in = weight_norm(self.proj_in)
            self.proj_out = weight_norm(self.proj_out)

    def reparam(self, mu, logvar):
        """Sample ``mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, 1)``."""
        std = T.exp(logvar / 2)
        eps = T.randn_like(std)
        return mu + eps * std

    def kl_divergence(self, mu, logvar):
        """KL(N(mu, exp(logvar)) || N(0, 1)), summed over dims 1 and 2, averaged over dim 0.

        Assumes 3-D inputs; WaveCodec passes (B, T, C)-shaped tensors.
        """
        return T.mean(-0.5 * T.sum(
            1 + logvar - mu.pow(2) - logvar.exp(), dim=(1, 2))
        )

    def repr_from_latent(self, latent: Union[dict, T.Tensor]):
        """Project a latent back to ``dim``.

        ``latent`` is either a sampled ``z`` tensor, or a dict with ``'mu'``
        and ``'logvar'``, from which a fresh ``z`` is sampled.
        """
        if isinstance(latent, T.Tensor):
            z = latent
        else:
            z = self.reparam(latent['mu'], latent['logvar'])
        return self.proj_out(z)

    def forward(self, x: T.Tensor) -> Tuple[T.Tensor, dict]:
        mu, logvar = self.proj_in(x).chunk(2, dim=-1)
        kl_div = self.kl_divergence(mu, logvar)
        z = self.reparam(mu, logvar)
        xhat = self.proj_out(z)
        latent = {'mu': mu, 'logvar': logvar, 'z': z, 'kl_divergence': kl_div}
        return xhat, latent


@si_module
class WaveCodec(nn.Module):
    """Causal ResNet encoder -> GaussianZ bottleneck -> causal ResNet decoder.

    Inputs are divided by ``norm_stddev`` before encoding, and decoder outputs
    are multiplied by it.
    """
    class Config:
        resnet_config: ResNetStack.Config = None
        sample_rate: int = 16_000
        use_weight_norm: bool = False
        compressor_config: dataclass = None
        norm_stddev: float = 1.0

    def __init__(self, c: Config):
        super().__init__()
        self.norm_stddev = c.norm_stddev
        self.encoder = c.resnet_config(mode='encoder')
        self.sample_rate = c.sample_rate

        # Overall temporal downsampling factor of the encoder.
        self.total_stride = 1
        for stride in c.resnet_config.strides:
            self.total_stride *= stride
        self.tokens_per_second = self.sample_rate / self.total_stride

        self.compressor = c.compressor_config(dim=self.encoder.middle_channels)

        self.decoder = c.resnet_config(mode='decoder')

        if c.use_weight_norm:
            self.encoder.apply_weight_norm()
            self.decoder.apply_weight_norm()
        # reset_parameters handles weight-normed (parametrized) convs as well.
        self.encoder.reset_parameters()
        self.decoder.reset_parameters()

    def encode(self, data):
        return self.encoder(data / self.norm_stddev)

    def decode(self, latent):
        # latent is channels-last (as produced by the compressor); the decoder
        # expects channels in dim 1.
        return self.decoder(latent.transpose(1, 2)) * self.norm_stddev

    @T.no_grad()
    def latent_from_data(self, data, get_parameters=False):
        """Encode ``data`` and return the sampled ``z``.

        With ``get_parameters=True``, returns the dict with ``'mu'``,
        ``'logvar'`` and ``'z'`` instead.
        """
        x = self.encode(data)
        l_in = x.transpose(1, 2)
        l, latent = self.compressor(l_in)
        return latent['z'] if not get_parameters else {
            'mu': latent['mu'],
            'logvar': latent['logvar'],
            'z': latent['z'],
        }

    @T.no_grad()
    def data_from_latent(self, latent):
        """Decode a ``z`` tensor or a ``{'mu', 'logvar'}`` dict back to data."""
        l = self.compressor.repr_from_latent(latent)
        x = self.decode(l)
        return x

    def process(self, x):
        return self.latent_from_data(x)

    def unprocess(self, latent):
        return self.data_from_latent(latent)

    def forward(self, audio_input):
        x = self.encode(audio_input)
        l_in = x.transpose(1, 2)
        l, latent = self.compressor(l_in)
        xhat = self.decode(l)
        return xhat, latent


def make_tokenizer(device='cuda'):
    """Build the pretrained WaveCodec tokenizer.

    Builds the model from the hard-coded config and loads checkpoint
    ``inference_apatosaurus_95000`` non-strictly. The model is then put in
    eval mode, cast to bfloat16 when ``device == 'cuda'``, moved to
    ``device``, and frozen (no parameter requires grad).
    """
    generator_config = WaveCodec.Config(
        resnet_config=ResNetStack.Config(
            input_channels=1,
            output_channels=1,
            encode_channels=16,
            decode_channel_multiplier=4,
            kernel_size=7,
            bias=True,
            channel_ratios=(4, 8, 16, 16, 16, 16),
            strides=(2, 2, 4, 5, 5, 5),
            mode=None,  # overridden per stack by WaveCodec (encoder / decoder)
        ),
        use_weight_norm=True,
        compressor_config=GaussianZ.Config(
            dim=None,  # filled in by WaveCodec from encoder.middle_channels
            latent_dim=32,
            bias=True,
            use_weight_norm=True,
        ),
        norm_stddev=0.05,
    )

    checkpoint = load_ckpt("inference_apatosaurus_95000", expected_hash="ba876edb97b988e9196e449dd176ca97")

    tokenizer = generator_config()

    load_result = tokenizer.load_state_dict(checkpoint, strict=False)
    print_colored(f"Loaded tokenizer state dict: {load_result}", "grey")

    tokenizer = tokenizer.eval()
    # Only convert to bfloat16 if using CUDA
    if device == 'cuda':
        tokenizer = tokenizer.bfloat16()
    tokenizer = tokenizer.to(device)
    # Call the method; assigning `requires_grad_ = False` would only shadow it.
    tokenizer.requires_grad_(False)
    return tokenizer