diff --git "a/speech_conformer_encoder.py" "b/speech_conformer_encoder.py" deleted file mode 100644--- "a/speech_conformer_encoder.py" +++ /dev/null @@ -1,2905 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -#!/usr/bin/env python3 - -# activation_checkpointing.py -"""helper function for activation checkpointing""" - -from typing import Union, Dict, Callable -from functools import partial -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - checkpoint_wrapper, - offload_wrapper, - CheckpointImpl, -) - - -# utils.py -"""cascade basic blocks""" - -import math -import backoff -import random -import numpy as np -from typing import Optional, Tuple, Union -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F - - -# conformer_encoder.py -"""ConformerEncoder Module""" - -from typing import Optional, Tuple, List, Literal -import abc -import math -import numpy as np - -import torch -from torch import nn, Tensor - -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper -from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel - - -# activation_checkpointing.py -def validate_checkpointing_config(activation_checkpointing): - """validate activation checkpointing configuration""" - if isinstance(activation_checkpointing, str): - assert activation_checkpointing in ( - "", - "checkpoint", - "offload", - ), "activation_checkpointing has to be a dict or a str in ('', 'checkpoint', 'offload')." - elif isinstance(activation_checkpointing, dict): - assert activation_checkpointing.get("module", "transformer") in ( - "transformer", - "attention", - ), "module in activation_checkpointing has to be in ('transformer', 'attention')." 
- else: - raise ValueError("activation_checkpointing has to be a str or dict.") - - -def embedding_checkpoint_wrapper( - activation_checkpointing: Union[str, Dict], -) -> Callable: - """return encoder embedding activation checkpoint wrapper""" - validate_checkpointing_config(activation_checkpointing) - - if isinstance(activation_checkpointing, str): - if activation_checkpointing: - if activation_checkpointing == "offload": - return offload_wrapper - return partial(checkpoint_wrapper) - return lambda x: x - - if isinstance(activation_checkpointing, dict): - enabled = activation_checkpointing.get("embed", False) - if enabled: - offloading = activation_checkpointing.get("offload", False) - if offloading: - return offload_wrapper - impl = ( - CheckpointImpl.REENTRANT - if activation_checkpointing.get("reentrant", False) - else CheckpointImpl.NO_REENTRANT - ) - return partial(checkpoint_wrapper, checkpoint_impl=impl) - return lambda x: x - raise ValueError("Invalid activation_checkpointing config") - - -def encoder_checkpoint_wrapper( - activation_checkpointing: Union[str, Dict], - layer_cls: type, - idx: int = 0, -) -> Callable: - """return encoder activation checkpoint wrapper""" - validate_checkpointing_config(activation_checkpointing) - - if isinstance(activation_checkpointing, str): - if activation_checkpointing: - if activation_checkpointing == "offload": - return offload_wrapper - return partial(checkpoint_wrapper) - return lambda x: x - - if isinstance(activation_checkpointing, dict): - target_layer_cls = activation_checkpointing.get("module", "transformer") - if target_layer_cls.lower() == "transformer": - target_layer_cls = ( - "EncoderLayer", - "ConformerEncoderLayer", - ) - elif target_layer_cls.lower() == "attention": - target_layer_cls = ("MultiHeadedAttention", "MultiHeadAttention") - checkpointing_interval = activation_checkpointing.get("interval", 1) - offloading = activation_checkpointing.get("offload", False) - impl = ( - CheckpointImpl.REENTRANT - if activation_checkpointing.get("reentrant", True) - else CheckpointImpl.NO_REENTRANT - ) - - if idx % checkpointing_interval == 0 and layer_cls.__name__ in target_layer_cls: - if offloading: - return offload_wrapper - return partial(checkpoint_wrapper, checkpoint_impl=impl) - return lambda x: x - - raise ValueError("Invalid activation_checkpointing config") - - -def attn_checkpointing(activation_checkpointing: Union[str, Dict], i) -> Union[str, Dict]: - """return activation checkpointing config for attention layer""" - if isinstance(activation_checkpointing, str): - return "" - - if isinstance(activation_checkpointing, dict): - target_layer_cls = activation_checkpointing.get("module", "transformer") - checkpointing_interval = activation_checkpointing.get("interval", 1) - if target_layer_cls == "attention" and i % checkpointing_interval == 0: - return activation_checkpointing - return "" - - raise ValueError("Invalid activation_checkpointing config") - - -# utils.py -class Block(nn.Module): - """Block abstract module""" - - def __init__(self, input_size, output_size): - super().__init__() - self.input_size = input_size - self.output_size = output_size - -def get_activation(name="relu"): - """Select an activation function by name - - Args: - name: str - activation function name, - one of ["relu", "gelu", "swish", "sigmoid"], - default "relu". 
- """ - name = name.lower() - if name == "relu": - return nn.ReLU(inplace=True) - if name == "gelu": - return nn.GELU() - if name == "swish": - return Swish() - if name == "sigmoid": - return torch.nn.Sigmoid() - return nn.Identity() - -def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): - """ - The function is very important for Transformer Transducer Streaming mode - Args: - xs_len (int): sequence length - chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. It also supports adaptive chunk size [0,10,15,45] - left_window (int): how many left chunks can be seen - right_window (int): how many right chunks can be seen. It is used for chunk overlap model. - Returns: - mask (torch.Tensor): a mask tensor for streaming model - Torch 1.0.1 - tensor([[1., 1., 0., 0.], - [0., 1., 1., 0.], - [0., 0., 1., 1.]]) - Torch 1.4.1 - tensor([[True., True., False., False.], - [False., True., True., False.], - [False., False., True., True.]]) - """ - chunk_start_idx = torch.Tensor( - chunk_start_idx - ).long() # first idx of each chunk, such as [0,18,36,48]. - start_pad = torch.nn.functional.pad( - chunk_start_idx, (1, 0) - ) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] - end_pad = torch.nn.functional.pad( - chunk_start_idx, (0, 1), value=x_len - ) # append x_len to the end, so it becomes [0,18,36,48, x_len] - seq_range = torch.arange(0, x_len).unsqueeze(-1) # seq_range size: [x_len, 1] - idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1] # idx size: [x_len] - boundary = end_pad[idx] # boundary size: [x_len] - seq_range_expand = ( - torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) - ) # seq_range_expand size [x_len, x_len] - idx_left = idx - left_window - idx_left[idx_left < 0] = 0 - boundary_left = start_pad[idx_left] - mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) - idx_right = idx + right_window - idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx) - boundary_right = end_pad[idx_right] - mask_right = seq_range_expand < boundary_right.unsqueeze(-1) - return mask_left & mask_right - -class Swish(nn.Module): - """Implement Swish activation module. - From https://arxiv.org/pdf/2005.03191.pdf - - """ - - def __init__(self) -> None: - super().__init__() - self.act_fn = nn.Sigmoid() - - def forward(self, x: Tensor) -> Tensor: - """Apply Swish function - - Args: - x: torch.Tensor - Input. - """ - return x * self.act_fn(x) - -class GLU(nn.Module): - """Implement Gated Linear Unit (GLU) module""" - - def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None: - super().__init__() - self.dim = dim - self.act_name = act_name.lower() - - if self.act_name == "relu": - self.act_fn = nn.ReLU(inplace=True) - elif self.act_name == "gelu": - self.act_fn = nn.GELU() - elif self.act_name == "swish": - self.act_fn = Swish() - elif self.act_name == "sigmoid": - self.act_fn = nn.Sigmoid() - else: - self.act_fn = nn.Identity() - - def forward(self, x: Tensor) -> Tensor: - """GLU forward - Apply Swish function on the first half of input matrices - with sigmoid of the second half. - - Args: - x: torch.Tensor - Input. - - """ - half_x, gate = x.chunk(2, dim=self.dim) - return half_x * self.act_fn(gate) - -# TODO: Abdel, this can be improved using GLU module -class GLUPointWiseConv(nn.Module): - """GLUPointWiseConv module - used for conformer architecture, - for more details see: - https://arxiv.org/pdf/2005.08100v1.pdf - - Args: - input_dim: int - input channel size. - output_dim: int - output channel size. 
- kernel_size: int - kernel size - glu_type: str, optional - activation function one of - ["sigmoid", "relu", "gelu"] - default "sigmoid". - bias_in_glu: bool, optional - use addtive bias in glu - causal: bool, optional - if set to True, padding is set to the half of - kernel size, ie, convolution can't see future frames. - default False. - - """ - - def __init__( - self, input_dim, output_dim, kernel_size, glu_type="sigmoid", bias_in_glu=True, causal=False - ): - super().__init__() - - self.glu_type = glu_type - self.output_dim = output_dim - self.bias_in_glu = bias_in_glu - if causal: - self.ext_pw_conv_1d = nn.Conv1d( - input_dim, output_dim * 2, kernel_size, 1, padding=(kernel_size - 1) - ) - else: - self.ext_pw_conv_1d = nn.Conv1d( - input_dim, output_dim * 2, kernel_size, 1, padding=(kernel_size - 1) // 2 - ) - - if glu_type == "sigmoid": - self.glu_act = nn.Sigmoid() - elif glu_type == "relu": - self.glu_act = nn.ReLU() - elif glu_type == "gelu": - self.glu_act = nn.GELU() - elif glu_type == "swish": - self.glu_act = Swish() - else: - raise ValueError(f"Unsupported activation type {self.glu_act}") - - if bias_in_glu: - self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1)) - self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1)) - - def forward(self, x): - """ - Args: - x: torch.Tensor - input tensor - """ - # to be consistent with GLULinear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case - x = x.permute([0, 2, 1]) - x = self.ext_pw_conv_1d(x) - if self.glu_type == "bilinear": - if self.bias_in_glu: - x = (x[:, 0 : self.output_dim, :] + self.b1) * ( - x[:, self.output_dim : self.output_dim * 2, :] + self.b2 - ) - else: - x = (x[:, 0 : self.output_dim, :]) * ( - x[:, self.output_dim : self.output_dim * 2, :] - ) - else: - if self.bias_in_glu: - x = (x[:, 0 : self.output_dim, :] + self.b1) * self.glu_act( - x[:, self.output_dim : self.output_dim * 2, :] + self.b2 - ) - else: - x = (x[:, 0 : self.output_dim, :]) * self.glu_act( - x[:, self.output_dim : self.output_dim * 2, :] - ) - - x = x.permute([0, 2, 1]) - return x - - -class DepthWiseSeperableConv1d(nn.Module): - """DepthWiseSeperableConv1d module used in Convnet module - for the conformer, for more details see: - https://arxiv.org/pdf/2005.08100v1.pdf - - Args: - input_dim: int - input channel size. - depthwise_seperable_out_channel: int - if set different to 0, the number of depthwise_seperable_out_channel - will be used as a channel_out of the second conv1d layer. - otherwise, it equal to 0, the second conv1d layer is skipped. - kernel_size: int - kernel_size - depthwise_multiplier: int - number of input_dim channels duplication. this value - will be used to compute the hidden channels of the Conv1D. - padding: int, optional - padding for the conv1d, - default: 0. 
- - """ - - def __init__( - self, - input_dim, - depthwise_seperable_out_channel, - kernel_size, - depthwise_multiplier, - padding=0, - ): - super().__init__() - - self.dw_conv = nn.Conv1d( - input_dim, - input_dim * depthwise_multiplier, - kernel_size, - 1, - padding=padding, - groups=input_dim, - ) - - if depthwise_seperable_out_channel != 0: - self.pw_conv = nn.Conv1d( - input_dim * depthwise_multiplier, depthwise_seperable_out_channel, 1, 1, 0 - ) - else: - self.pw_conv = nn.Identity() - self.depthwise_seperable_out_channel = depthwise_seperable_out_channel - - def forward(self, x): - """ - - Args: - x: torch.Tensor - input tensor - """ - x = self.dw_conv(x) - if self.depthwise_seperable_out_channel != 0: - x = self.pw_conv(x) - return x - - -class ConvModule(nn.Module): - """ConvModule Module for the conformer block. - for more details see: - https://arxiv.org/pdf/2005.08100v1.pdf - - Args: - input_dim: int - input channel size. - ext_pw_out_channel: int - if > 0, ext_pw_out_channel is a dim channel size - for the last pointwise conv after swish activation. - depthwise_seperable_out_channel: int - if set different to 0, the number of depthwise_seperable_out_channel - will be used as a channel_out of the second conv1d layer. - otherwise, it equal to 0, the second conv1d layer is skipped. - ext_pw_kernel_size: int - kernel size of the conv pointwise of the conformer. - kernel_size: int - kernel size. - depthwise_multiplier: int - number of input_dim channels duplication. this value - will be used to compute the hidden channels of the Conv1D. - dropout_rate: float - dropout rate. - causal: bool, optional - if set to True, convolution have no access - to future frames. default False. - batch_norm: bool, optional - if set to True, apply batchnorm before activation. - default False - chunk_se: int, optional - 0 for offline SE. - 1 for streaming SE, where mean is computed - by accumulated history until current chunk_se. - 2 for streaming SE, where mean is computed - by only the current chunk. - chunk_size: int, optional - chunk size for cnn. default 18 - activation: str, optional - activation function used in ConvModule, - default: "relu". - glu_type: str, optional - activation function used for the glu, - default: "sigmoid". - bias_in_glu: bool, optional - if set to True, use additive bias in the weight module - before GLU. - linear_glu_in_convm: bool, optional - if set to True, use GLULinear module, - otherwise, used GLUPointWiseConv module. - default to False. - export: bool, optional, - if set to True, padding is equal to 0. This is for inference, - or onnx export. Typically this is set by the export program or - the decoder program, and it isn't present in your config file. 
- default False - """ - - def __init__( - self, - input_dim, - ext_pw_out_channel, - depthwise_seperable_out_channel, - ext_pw_kernel_size, - kernel_size, - depthwise_multiplier, - dropout_rate, - causal=False, - batch_norm=False, - chunk_se=0, - chunk_size=18, - activation="relu", - glu_type="sigmoid", - bias_in_glu=True, - linear_glu_in_convm=False, - export=False, - ): - super().__init__() - self.layer_norm = nn.LayerNorm(input_dim) - self.input_dim = input_dim - self.ext_pw_out_channel = ext_pw_out_channel - self.ext_pw_kernel_size = ext_pw_kernel_size - self.depthwise_seperable_out_channel = depthwise_seperable_out_channel - self.glu_type = glu_type - self.bias_in_glu = bias_in_glu - self.linear_glu_in_convm = linear_glu_in_convm - self.causal = causal - - self._add_ext_pw_layer() - - self.batch_norm = batch_norm - self.kernel_size = kernel_size - - if batch_norm: - self.bn_layer = nn.BatchNorm1d(input_dim) - - self.act = get_activation(activation) - self.dropout = nn.Dropout(dropout_rate) - self.export = export - - if causal: - if export: # Inference only. - padding = 0 # A cache is concatenated to the left. No padding in the kernel. - else: - # Training only. Padding will be added symmetrically on both sides. - # After convolution, clip off kernel_size-1 points on the right. - padding = kernel_size - 1 - else: - padding = (kernel_size - 1) // 2 - - self.dw_sep_conv_1d = DepthWiseSeperableConv1d( - input_dim, - depthwise_seperable_out_channel, - kernel_size, - depthwise_multiplier, - padding=padding, - ) - - if depthwise_seperable_out_channel != 0: - if input_dim != depthwise_seperable_out_channel: - self.ln2 = nn.Linear(depthwise_seperable_out_channel, input_dim) - else: - if depthwise_multiplier != 1: - self.ln2 = nn.Linear(input_dim * depthwise_multiplier, input_dim) - - def _add_ext_pw_layer(self): - """ - This function is an extension of __init__ function - and dedicated to the convolution module creation - of the conformer. - """ - self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = nn.Identity() # jit hacks. - self.squeeze_excitation = nn.Identity() # jit. - self.apply_ln1 = self.fix_len1 = False # jit. - - if self.ext_pw_out_channel != 0: - if self.causal: - self.ext_pw_conv_1d = nn.Conv1d( - self.input_dim, - self.ext_pw_out_channel, - self.ext_pw_kernel_size, - 1, - padding=(self.ext_pw_kernel_size - 1), - ) - if self.ext_pw_kernel_size > 1: - self.fix_len1 = True - else: - self.fix_len1 = False - else: - self.ext_pw_conv_1d = nn.Conv1d( - self.input_dim, - self.ext_pw_out_channel, - self.ext_pw_kernel_size, - 1, - padding=(self.ext_pw_kernel_size - 1) // 2, - ) - self.fix_len1 = False - - if self.linear_glu_in_convm: - self.glu = GLULinear( - self.input_dim, self.ext_pw_out_channel, self.glu_type, self.bias_in_glu - ) - else: - self.glu = GLUPointWiseConv( - self.input_dim, - self.ext_pw_out_channel, - self.ext_pw_kernel_size, - self.glu_type, - self.bias_in_glu, - self.causal, - ) - - if self.input_dim != self.ext_pw_out_channel: - self.apply_ln1 = True - self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim) - else: - self.apply_ln1 = False - else: - self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3)) - self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3)) - - def forward(self, x): - """ConvModule Forward. - - Args: - x: torch.Tensor - input tensor. 
- """ - x = self.layer_norm(x) - - if self.ext_pw_out_channel != 0: - x = self.glu(x) - if self.causal and self.ext_pw_kernel_size > 1: - x = x[:, : -(self.ext_pw_kernel_size - 1), :] - if self.apply_ln1: - x = self.ln1(x) - else: - x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0] - x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1] - x = x_0 + x_1 - - x = x.permute([0, 2, 1]) - - x = self.dw_sep_conv_1d(x) - if self.causal and self.kernel_size > 1: - x = x[:, :, : -(self.kernel_size - 1)] - if hasattr(self, "ln2"): - x = x.permute([0, 2, 1]) - x = self.ln2(x) - x = x.permute([0, 2, 1]) - if self.batch_norm: - x = self.bn_layer(x) - x = self.act(x) - - if self.ext_pw_out_channel != 0: - x = self.ext_pw_conv_1d(x) - if self.fix_len1: - x = x[:, :, : -(self.ext_pw_kernel_size - 1)] - - if self.apply_ln1: - x = x.permute([0, 2, 1]) - x = self.ln1(x) - x = x.permute([0, 2, 1]) - - x = x.permute([0, 2, 1]) - else: - x = x.unsqueeze(1).permute([0, 1, 3, 2]) - x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2] - x = x.squeeze(1) - - x = self.dropout(x) - return x - -class GLULinear(nn.Module): - """Linear + GLU module - - Args: - input_dim: int - input size - output_dim: int - output size. - glu_type: - activation function name used in glu module. - default "sigmoid" (swish function). - bias_in_glu: bool, optional - If True, the addtive bias is added. Default False. - """ - - def __init__( - self, - input_dim, - output_dim, - glu_type="sigmoid", - bias_in_glu=True, - ): - super().__init__() - self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu) - self.glu_act = GLU(-1, glu_type) - - def forward(self, x): - """GLULinear forward - - Args: - x: torch.Tensor - inpute tensor. - """ - x = self.linear(x) - return self.glu_act(x) - -class FeedForward(nn.Module): - """FeedForward Module. - For more details see Conformer paper: - https://arxiv.org/pdf/2005.08100.pdf - - Args: - d_model: int - input size. - d_inner: int - output size. - dropout_rate: float, - dropout rate. - activation: str, - activation function name, - one of ["relu", "swish", "sigmoid"], - sigmoid activation is only used with "glu_in_fnn=True", - default "sigmoid". - bias_in_glu: bool, optional - """ - - def __init__( - self, - d_model, - d_inner, - dropout_rate, - activation="sigmoid", - bias_in_glu=True, - ): - super().__init__() - self.d_model = d_model - self.d_inner = d_inner - - self.layer_norm = nn.LayerNorm(d_model) - module = GLULinear(d_model, d_inner, activation, bias_in_glu) - self.net = nn.Sequential( - module, - nn.Dropout(dropout_rate), - nn.Linear(d_inner, d_model), - nn.Dropout(dropout_rate), - ) - - def forward(self, x): - """FeedForward forward function. - - Args: - x: torch.Tensor - input tensor. - """ - out = self.net(self.layer_norm(x)) - - return out - -#### positional encoding starts here -def _pre_hook( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs -): - """Perform pre-hook in load_state_dict for backward compatibility. - - Note: - We saved self.pe until v.0.5.2 but we have omitted it later. - Therefore, we remove the item "pe" from `state_dict` for backward compatibility. 
- - """ - k = prefix + "pe" - if k in state_dict: - state_dict.pop(k) - -class T5RelativeAttentionLogitBias(nn.Module): - """ - This module implements the relative position bias described in Section 2.1 of - the T5 paper: https://arxiv.org/pdf/1910.10683.pdf - - The Huggingface implementation is used as a reference - https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/models/t5/modeling_t5.py#L435 - - Modifies attention as Q*K^T + B, where B is a learned scalar bias based on relative position - of the query and key. It is HxNxN, where H is the number of heads, N is the sequence length. - - I've made these modifications to the original T5 bias: - - Skipping of the bucketing step. Original T5 bias converted rel position distances into - logarithmically increasing buckets. This is supposed to help with length generalization. - - I just directly use rel position index as bias values, as we don't need length - generalization (40s max is good enough for ASR encoder), and it keeps ONNX export simple. - - I've also extended it so that biases can be asymmetric, the default implementation treats - L->R and R->L the same. Asymmetric was found to yield better results in my experiments. - - Args: - num_heads: int - Number of attention heads - num_buckets: int - Number of buckets to use for relative attention bias. This is the size of the learnable - bias parameter. Bucketing is not yet supported, so this defaults to -1 which means - no bucketing is used (max_distance determines size of bias param). - max_distance: int - Maximum distance to use for relative attention bias. With num_buckets=-1, this directly - controls the max size of the bias parameter. When num_buckets > 0 is supported, this - will control the maximum distance for logarithmic bucketing after which all positions - are in the same bucket. - symmetric: bool - Whether to use symmetric or asymmetric biases. symmetric=False uses 2x number of bias - params to distinguish L->R from R->L. This was found to be better for the encoder. 
- """ - - def __init__(self, num_heads, num_buckets=-1, max_distance=1000, symmetric=False): - super().__init__() - self.num_heads = num_heads - self.num_buckets = num_buckets - self.max_distance = max_distance - self.symmetric = symmetric - self._skip_bucketing = self.num_buckets < 0 - if self._skip_bucketing: - self.num_buckets = max_distance - else: - raise NotImplementedError("T5 attention bias with bucketed positions is not yet tested") - if not self.symmetric: - self.num_buckets *= 2 - self.bias_values = nn.Embedding(self.num_buckets, self.num_heads) - - def forward(self, x): - # instantiate bias compatible with shape of x - maxpos = x.size(1) - context_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[:, None] - memory_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[None, :] - relative_position = memory_position - context_position - # clipping to a maximum distance using ops that play well with ONNX export - relative_position = relative_position.masked_fill( - relative_position < -self.max_distance, -self.max_distance - ) - relative_position = relative_position.masked_fill( - relative_position > self.max_distance - 1, self.max_distance - 1 - ) - - # mapping from relative position to index in the bias parameter - if self._skip_bucketing: - bias_idx = relative_position - else: - bias_idx = self._bucket_relative_position(relative_position) - if self.symmetric: - bias_idx = bias_idx.abs() - else: - bias_idx += self.num_buckets // 2 - - t5_rel_att_bias = self.bias_values(bias_idx) # [L, L, H] - t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze(0) # [1, H, L, L] - - return t5_rel_att_bias - - def _bucket_relative_position(self, relative_position): - # this is a placeholder (isn't tested, likely buggy) using HuggingFace implem as a reference - # this also needs to be extended to support asymmetric +/- ve positions - relative_buckets = 0 - if not self.causal: - num_buckets //= 2 - relative_buckets += (relative_position > 0).to(torch.long) * num_buckets - relative_position = torch.abs(relative_position) - else: - relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) - # now relative_position is in the range [0, inf) - - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) - / math.log(self.max_distance / max_exact) - * (num_buckets - max_exact) - ).to(torch.long) - relative_position_if_large = torch.min( - relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) - ) - - relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) - return relative_buckets - -class AbsolutePositionalEncoding(nn.Module): - """Absolute Positional encoding module. - This module implement Absolute sinusoidal positional encoding - from: https://arxiv.org/pdf/1706.03762.pdf - - Args: - d_model: int - Input embedding size. 
- dropout_rate: float - dropout rate - max_len: int, optional - Maximum input length sequence, Default 5000 - - """ - - def __init__(self, d_model, dropout_rate, max_len=5000): - """Construct an PositionalEncoding object.""" - super().__init__() - self.d_model = d_model - self.xscale = math.sqrt(self.d_model) - self.dropout = torch.nn.Dropout(p=dropout_rate) - self.pe = None - self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - self._register_load_state_dict_pre_hook(_pre_hook) - - def extend_pe(self, x): - """Reset the positional encodings. - - Args: - x: torch.Tensor - """ - if self.pe is not None: - if self.pe.size(1) >= x.size(1): - if self.pe.dtype != x.dtype or self.pe.device != x.device: - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - pe = torch.zeros(x.size(1), self.d_model) - position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, self.d_model, 2, dtype=torch.float32) - * -(math.log(10000.0) / self.d_model) - ) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - self.pe = pe.to(device=x.device, dtype=x.dtype) - - def forward(self, x: torch.Tensor): - """Add positional encoding. - - Args: - x: torch.Tensor - Input tensor. shape is (batch, time, ...) - - Returns: - torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) - - """ - self.extend_pe(x) - x = x * self.xscale + self.pe[:, : x.size(1)] - return self.dropout(x) - -#### forward embedding layers starts here - -@backoff.on_exception(backoff.expo, Exception, max_tries=10) -def np_loadtxt_with_retry(filepath): - """np.loadtxt with retry - - Args: - filepath: str - file path to the numpy array. - """ - result = np.loadtxt(filepath, dtype="f") - return result - -class MeanVarianceNormLayer(nn.Module): - """Mean/variance normalization layer. - - Will substract mean and multiply input by inverted standard deviation. - Typically used as a very first layer in a model. - - Args: - input_size: int - layer input size. - """ - - def __init__(self, input_size): - super().__init__() - self.input_size = input_size - self.register_buffer("global_mean", torch.zeros(input_size)) - self.register_buffer("global_invstd", torch.ones(input_size)) - self.global_mean: Optional[Tensor] - self.global_invstd: Optional[Tensor] - - def forward(self, input_: Tensor) -> Tensor: - """MeanVarianceNormLayer Forward - - Args: - input_: torch.Tensor - input tensor. - """ - return (input_ - self.global_mean) * self.global_invstd - - def load_mean_invstd(self, mean_file, invstd_file, cuside_features=False): - """Load feature mean and variance used for normalization. - - Args: - mean_file: str - path to the feature mean statistics file. - invstd_file: str - path to the features inverted standard deviation - statistics file. - cuside_features: bool - Boolean that indicates CUSIDE is being used. 
- The statistics of CUSIDE features are copied - from the normal features - """ - self.global_mean.data = torch.from_numpy(np_loadtxt_with_retry(mean_file)) - self.global_invstd.data = torch.from_numpy(np_loadtxt_with_retry(invstd_file)) - - if cuside_features: - self.global_mean.data = torch.cat((self.global_mean.data, self.global_mean.data), 0) - self.global_invstd.data = torch.cat( - (self.global_invstd.data, self.global_invstd.data), 0 - ) - -class CausalConv1D(nn.Conv1d): - """ - A causal version of nn.Conv1d where each step would have limited access to locations on its right or left - All arguments are the same as nn.Conv1d except padding. - - If padding is set None, then paddings are set automatically to make it a causal convolution where each location would not see any steps on its right. - - If padding is set as a list (size of 2), then padding[0] would be used as left padding and padding[1] as right padding. - It would make it possible to control the number of steps to be accessible on the right and left. - This mode is not supported when stride > 1. padding[0]+padding[1] should be equal to (kernel_size - 1). - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int = 1, - padding: Union[str, int] = 0, - dilation: int = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = "zeros", - device=None, - dtype=None, - ) -> None: - self.cache_drop_size = None - if padding is None: - self._left_padding = kernel_size - 1 - self._right_padding = stride - 1 - else: - if stride != 1 and padding != kernel_size - 1: - raise ValueError("No striding allowed for non-symmetric convolutions!") - if isinstance(padding, int): - self._left_padding = padding - self._right_padding = padding - elif ( - isinstance(padding, list) - and len(padding) == 2 - and padding[0] + padding[1] == kernel_size - 1 - ): - self._left_padding = padding[0] - self._right_padding = padding[1] - else: - raise ValueError(f"Invalid padding param: {padding}!") - - self._max_cache_len = self._left_padding - - super().__init__( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=0, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode=padding_mode, - device=device, - dtype=dtype, - ) - - def update_cache(self, x, cache=None): - if cache is None: - new_x = F.pad(x, pad=(self._left_padding, self._right_padding)) - next_cache = cache - else: - new_x = F.pad(x, pad=(0, self._right_padding)) - new_x = torch.cat([cache, new_x], dim=-1) - if self.cache_drop_size > 0: - next_cache = new_x[:, :, : -self.cache_drop_size] - else: - next_cache = new_x - next_cache = next_cache[:, :, -cache.size(-1) :] - return new_x, next_cache - - def forward(self, x, cache=None): - x, cache = self.update_cache(x, cache=cache) - x = super().forward(x) - if cache is None: - return x - else: - return x, cache - - -class CausalConv2D(nn.Conv2d): - """ - A causal version of nn.Conv2d where each location in the 2D matrix would have no access to locations on its right or down - All arguments are the same as nn.Conv2d except padding which should be set as None - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int = 1, - padding: Union[str, int] = 0, - dilation: int = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = "zeros", - device=None, - dtype=None, - ) -> None: - if padding is not None: - raise ValueError("Argument padding should be set to None for CausalConv2D.") - 
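# --- Editor's illustrative sketch (not part of the deleted file) ------------------------
# How the asymmetric padding used by CausalConv1D above keeps each output step from
# seeing future frames: with left padding = kernel_size - 1 and no right padding,
# output[t] depends only on input[:t + 1]. Sizes below are arbitrary demo values.
import torch
import torch.nn.functional as F
from torch import nn

kernel_size = 3
conv = nn.Conv1d(in_channels=4, out_channels=4, kernel_size=kernel_size, padding=0)
x = torch.randn(2, 4, 10)                      # (batch, channels, time)
x_causal = F.pad(x, pad=(kernel_size - 1, 0))  # pad on the left only
y = conv(x_causal)                             # (2, 4, 10): same length, strictly causal
# -----------------------------------------------------------------------------------------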
self._left_padding = kernel_size - 1 - self._right_padding = stride - 1 - - padding = 0 - super().__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - groups, - bias, - padding_mode, - device, - dtype, - ) - - def forward( - self, - x, - ): - if self.training: - x = F.pad( - x, - pad=( - self._left_padding, - self._right_padding, - self._left_padding, - self._right_padding, - ), - ) - else: - x = F.pad( - x, - pad=(self._left_padding, self._right_padding, 0, 0), - ) - x = super().forward(x) - return x - - -class NemoConvSubsampling(torch.nn.Module): - """Convlutional subsampling module, taken from NeMo ASR - (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a34501479cf/nemo/collections/asr/parts/submodules/subsampling.py) - - Striding Subsampling: "Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for - Speech Recognition" by Linhao Dong et al. (https://ieeexplore.ieee.org/document/8462506) - - - Compared with the EncoderConv2D (`input_layer: custom`), this is a much simplified approach, - and uses no LayerNorm and far fewer Conv2Ds. Moreover, depthwise convolutions are used to reduce - FLOPs, but the first layer is kept as a regular convolution so as not to degrade accuracy. - - `Striding` and `dw_striding` are the same except that the latter uses depthwise convolutions - after the first layer, whereas the former does not. - - Args: - subsampling_factor (int): Time reduction factor - feat_in (int): size of the input features - feat_out (int): size of the output features - subsampling (str): The subsampling technique, choose from - {"striding", "dw-striding", "striding_conv1d", "dw_striding_conv1d"} - conv_channels (int): Number of channels for the convolution layers, default is 256. - subsampling_conv_chunking_factor (int): Input chunking factor which can be -1 (no chunking) - 1 (auto) or a power of 2. 
Default is 1 - activation (Module): activation function, default is nn.ReLU() - is_causal (bool): whether to use causal Conv1/2D, where each step will have limited access - to locations on its right or left - """ - - def __init__( - self, - feat_in, - feat_out, - subsampling_factor=4, - subsampling="dw_striding", - conv_channels=256, - subsampling_conv_chunking_factor=1, - activation=nn.ReLU(), - is_causal=False, - ): - super().__init__() - self._subsampling = subsampling - self._conv_channels = conv_channels - self._feat_in = feat_in - self._feat_out = feat_out - - if subsampling_factor % 2 != 0: - raise ValueError("Sampling factor should be a multiply of 2!") - self._sampling_num = int(math.log(subsampling_factor, 2)) - self.subsampling_factor = subsampling_factor - self.is_causal = is_causal - self.subsampling_causal_cond = subsampling in ("dw_striding", "striding", "striding_conv1d") - - if ( - subsampling_conv_chunking_factor != -1 - and subsampling_conv_chunking_factor != 1 - and subsampling_conv_chunking_factor % 2 != 0 - ): - raise ValueError("subsampling_conv_chunking_factor should be -1, 1, or a power of 2") - self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor - - in_channels = 1 - layers = [] - - if subsampling == "dw_striding": - self._stride = 2 - self._kernel_size = 3 - self._ceil_mode = False - - if self.is_causal: - self._left_padding = self._kernel_size - 1 - self._right_padding = self._stride - 1 - self._max_cache_len = subsampling_factor + 1 - else: - self._left_padding = (self._kernel_size - 1) // 2 - self._right_padding = (self._kernel_size - 1) // 2 - self._max_cache_len = 0 - - # Layer 1 - if self.is_causal: - layers.append( - CausalConv2D( - in_channels=in_channels, - out_channels=conv_channels, - kernel_size=self._kernel_size, - stride=self._stride, - padding=None, - ) - ) - else: - layers.append( - torch.nn.Conv2d( - in_channels=in_channels, - out_channels=conv_channels, - kernel_size=self._kernel_size, - stride=self._stride, - padding=self._left_padding, - ) - ) - in_channels = conv_channels - layers.append(activation) - - for i in range(self._sampling_num - 1): - if self.is_causal: - layers.append( - CausalConv2D( - in_channels=in_channels, - out_channels=in_channels, - kernel_size=self._kernel_size, - stride=self._stride, - padding=None, - groups=in_channels, - ) - ) - else: - layers.append( - torch.nn.Conv2d( - in_channels=in_channels, - out_channels=in_channels, - kernel_size=self._kernel_size, - stride=self._stride, - padding=self._left_padding, - groups=in_channels, - ) - ) - - layers.append( - torch.nn.Conv2d( - in_channels=in_channels, - out_channels=conv_channels, - kernel_size=1, - stride=1, - padding=0, - groups=1, - ) - ) - layers.append(activation) - in_channels = conv_channels - - elif subsampling == "striding": - self._stride = 2 - self._kernel_size = 3 - self._ceil_mode = False - - if self.is_causal: - self._left_padding = self._kernel_size - 1 - self._right_padding = self._stride - 1 - self._max_cache_len = subsampling_factor + 1 - else: - self._left_padding = (self._kernel_size - 1) // 2 - self._right_padding = (self._kernel_size - 1) // 2 - self._max_cache_len = 0 - - for i in range(self._sampling_num): - if self.is_causal: - layers.append( - CausalConv2D( - in_channels=in_channels, - out_channels=conv_channels, - kernel_size=self._kernel_size, - stride=self._stride, - padding=None, - ) - ) - else: - layers.append( - torch.nn.Conv2d( - in_channels=in_channels, - out_channels=conv_channels, - kernel_size=self._kernel_size, 
- stride=self._stride, - padding=self._left_padding, - ) - ) - layers.append(activation) - in_channels = conv_channels - - elif subsampling == "striding_conv1d": - in_channels = feat_in - - self._stride = 2 - self._kernel_size = 5 - self._ceil_mode = False - - if self.is_causal: - self._left_padding = self._kernel_size - 1 - self._right_padding = self._stride - 1 - self._max_cache_len = subsampling_factor + 1 - else: - self._left_padding = (self._kernel_size - 1) // 2 - self._right_padding = (self._kernel_size - 1) // 2 - self._max_cache_len = 0 - - for i in range(self._sampling_num): - if self.is_causal: - layers.append( - CausalConv1D( - in_channels=in_channels, - out_channels=feat_out if self._sampling_num == i + 1 else conv_channels, - kernel_size=self._kernel_size, - stride=self._stride, - padding=None, - ) - ) - else: - layers.append( - torch.nn.Conv1d( - in_channels=in_channels, - out_channels=feat_out if self._sampling_num == i + 1 else conv_channels, - kernel_size=self._kernel_size, - stride=self._stride, - padding=self._left_padding, - ) - ) - layers.append(activation) - in_channels = conv_channels - - elif subsampling == "dw_striding_conv1d": - in_channels = feat_in - - self._stride = 2 - self._kernel_size = 5 - self._ceil_mode = False - - self._left_padding = (self._kernel_size - 1) // 2 - self._right_padding = (self._kernel_size - 1) // 2 - - # Layer 1 - layers.extend( - [ - torch.nn.Conv1d( - in_channels=in_channels, - out_channels=in_channels, - kernel_size=self._kernel_size, - stride=self._stride, - padding=self._left_padding, - groups=in_channels, - ), - torch.nn.Conv1d( - in_channels=in_channels, - out_channels=feat_out if self._sampling_num == 1 else conv_channels, - kernel_size=1, - stride=1, - padding=0, - groups=1, - ), - ] - ) - in_channels = conv_channels - layers.append(activation) - - for i in range(self._sampling_num - 1): - layers.extend( - [ - torch.nn.Conv1d( - in_channels=in_channels, - out_channels=in_channels, - kernel_size=self._kernel_size, - stride=self._stride, - padding=self._left_padding, - groups=in_channels, - ), - torch.nn.Conv1d( - in_channels=in_channels, - out_channels=feat_out if self._sampling_num == i + 2 else conv_channels, - kernel_size=1, - stride=1, - padding=0, - groups=1, - ), - ] - ) - layers.append(activation) - in_channels = conv_channels - - else: - raise ValueError(f"Not valid sub-sampling: {subsampling}!") - - if subsampling in ["dw_striding", "striding"]: - in_length = torch.tensor(feat_in, dtype=torch.float) - out_length = calc_length( - lengths=in_length, - all_paddings=self._left_padding + self._right_padding, - kernel_size=self._kernel_size, - stride=self._stride, - ceil_mode=self._ceil_mode, - repeat_num=self._sampling_num, - ) - self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out) - self.conv2d_subsampling = True - elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]: - self.out = None - self.conv2d_subsampling = False - else: - raise ValueError(f"Not valid sub-sampling: {subsampling}!") - - self.conv = torch.nn.Sequential(*layers) - - def get_sampling_frames(self): - return [1, self.subsampling_factor] - - def get_streaming_cache_size(self): - return [0, self.subsampling_factor + 1] - - def forward(self, x, mask): - """ - Forward method for NeMo subsampling. 
- - Args: - x[Batch, Time, Filters]: torch.Tensor - input tensor - x_mask: torch.Tensor - input mask - - Returns: - x: torch.Tensor - Resulting tensor from subsampling (B, T // time_reduction_factor, feat_out) - pad_mask: torch.Tensor - tensor of padded hidden state sequences (B, 1, T // time_reduction_factor) - """ - # Unsqueeze Channel Axis - if self.conv2d_subsampling: - x = x.unsqueeze(1) - # Transpose to Channel First mode - else: - x = x.transpose(1, 2) - - # split inputs if chunking_factor is set - if self.subsampling_conv_chunking_factor != -1 and self.conv2d_subsampling: - if self.subsampling_conv_chunking_factor == 1: - # if subsampling_conv_chunking_factor is 1, we split only if needed - # avoiding a bug / feature limiting indexing of tensors to 2**31 - # see https://github.com/pytorch/pytorch/issues/80020 - x_ceil = 2**31 / self._conv_channels * self._stride * self._stride - if torch.numel(x) > x_ceil: - need_to_split = True - else: - need_to_split = False - else: - # if subsampling_conv_chunking_factor > 1 we always split - need_to_split = True - - if need_to_split: - x, success = self.conv_split_by_batch(x) - if not success: # if unable to split by batch, try by channel - if self._subsampling == "dw_striding": - x = self.conv_split_by_channel(x) - else: - x = self.conv(x) # try anyway - else: - x = self.conv(x) - else: - x = self.conv(x) - - # Flatten Channel and Frequency Axes - if self.conv2d_subsampling: - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).reshape(b, t, -1)) - # Transpose to Channel Last mode - else: - x = x.transpose(1, 2) - - if mask is None: - return x, None - - max_audio_length = x.shape[1] - feature_lens = mask.sum(1) - padding_length = torch.ceil(feature_lens / self.subsampling_factor) - if self.is_causal and self.subsampling_causal_cond: - feature_lens_remainder = feature_lens % self.subsampling_factor - padding_length[feature_lens_remainder != 1] += 1 - pad_mask = ( - torch.arange(0, max_audio_length, device=x.device).expand(padding_length.size(0), -1) - < padding_length.unsqueeze(1) - ) - return x, pad_mask.unsqueeze(1) - - def reset_parameters(self): - # initialize weights - if self._subsampling == "dw_striding": - with torch.no_grad(): - # init conv - scale = 1.0 / self._kernel_size - dw_max = (self._kernel_size**2) ** -0.5 - pw_max = self._conv_channels**-0.5 - - torch.nn.init.uniform_(self.conv[0].weight, -scale, scale) - torch.nn.init.uniform_(self.conv[0].bias, -scale, scale) - - for idx in range(2, len(self.conv), 3): - torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, dw_max) - torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max) - torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, pw_max) - torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, pw_max) - - # init fc (80 * 64 = 5120 from https://github.com/kssteven418/Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/src/models/conformer_encoder.py#L487 - fc_scale = (self._feat_out * self._feat_in / self._sampling_num) ** -0.5 - torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale) - torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale) - - def conv_split_by_batch(self, x): - """Tries to split input by batch, run conv and concat results""" - b, _, _, _ = x.size() - if b == 1: # can't split if batch size is 1 - return x, False - - if self.subsampling_conv_chunking_factor > 1: - cf = self.subsampling_conv_chunking_factor - else: - # avoiding a bug / feature limiting indexing of tensors to 2**31 - # see 
https://github.com/pytorch/pytorch/issues/80020 - x_ceil = 2**31 / self._conv_channels * self._stride * self._stride - p = math.ceil(math.log(torch.numel(x) / x_ceil, 2)) - cf = 2**p - - new_batch_size = b // cf - if new_batch_size == 0: # input is too big - return x, False - - return torch.cat([self.conv(chunk) for chunk in torch.split(x, new_batch_size, 0)]), True - - def conv_split_by_channel(self, x): - """For dw convs, tries to split input by time, run conv and concat results""" - x = self.conv[0](x) # full conv2D - x = self.conv[1](x) # activation - - for i in range(self._sampling_num - 1): - _, c, t, _ = x.size() - - if self.subsampling_conv_chunking_factor > 1: - cf = self.subsampling_conv_chunking_factor - else: - # avoiding a bug / feature limiting indexing of tensors to 2**31 - # see https://github.com/pytorch/pytorch/issues/80020 - p = math.ceil(math.log(torch.numel(x) / 2**31, 2)) - cf = 2**p - - new_c = int(c // cf) - if new_c == 0: - new_c = 1 - - new_t = int(t // cf) - if new_t == 0: - new_t = 1 - - x = self.channel_chunked_conv(self.conv[i * 3 + 2], new_c, x) # conv2D, depthwise - - # splitting pointwise convs by time - x = torch.cat( - [self.conv[i * 3 + 3](chunk) for chunk in torch.split(x, new_t, 2)], 2 - ) # conv2D, pointwise - x = self.conv[i * 3 + 4](x) # activation - return x - - def channel_chunked_conv(self, conv, chunk_size, x): - """Performs channel chunked convolution""" - - ind = 0 - out_chunks = [] - for chunk in torch.split(x, chunk_size, 1): - step = chunk.size()[1] - - if self.is_causal: - chunk = nn.functional.pad( - chunk, - pad=( - self._kernel_size - 1, - self._stride - 1, - self._kernel_size - 1, - self._stride - 1, - ), - ) - ch_out = nn.functional.conv2d( - chunk, - conv.weight[ind : ind + step, :, :, :], - bias=conv.bias[ind : ind + step], - stride=self._stride, - padding=0, - groups=step, - ) - else: - ch_out = nn.functional.conv2d( - chunk, - conv.weight[ind : ind + step, :, :, :], - bias=conv.bias[ind : ind + step], - stride=self._stride, - padding=self._left_padding, - groups=step, - ) - out_chunks.append(ch_out) - ind += step - - return torch.cat(out_chunks, 1) - - def change_subsampling_conv_chunking_factor(self, subsampling_conv_chunking_factor: int): - if ( - subsampling_conv_chunking_factor != -1 - and subsampling_conv_chunking_factor != 1 - and subsampling_conv_chunking_factor % 2 != 0 - ): - raise ValueError("subsampling_conv_chunking_factor should be -1, 1, or a power of 2") - self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor - - -def calc_length(lengths, all_paddings, kernel_size, stride, ceil_mode, repeat_num=1): - """Calculates the output length of a Tensor passed through a convolution or max pooling layer""" - add_pad: float = all_paddings - kernel_size - one: float = 1.0 - for i in range(repeat_num): - lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one - if ceil_mode: - lengths = torch.ceil(lengths) - else: - lengths = torch.floor(lengths) - return lengths.to(dtype=torch.int) - -#### multihead attention starts here -class AttModule(nn.Module): - """Attention abstraction module""" - - def __init__(self): - super().__init__() - self.export_mode = False - - def set_export(self, mode=True): - """set the export mode""" - self.export_mode = mode - - def forward( - self, - x: Tensor, - memory: Optional[Tensor] = None, - pos_emb: Optional[Tensor] = None, - att_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: - """AttModule forward - - Args: - x: 
torch.Tensor - input tensor. - memory: torch.Tensor, optional - memory tensor. - pos_emb: torch.Tensor, optional - positional encoder embedding. - att_mask: torch.Tensor, optional - attention mask tensor. - """ - return x, memory, pos_emb, att_mask - - -class AttBlock(Block, AttModule): - """Attention Block module to support both Attention and Block module.""" - - def memory_dims(self, max_len=False): - """memory dimensions""" - return (1, self.input_size) - -def masked_softmax( - scores, - mask: Optional[Tensor], -): - if mask is not None: - mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) - scores = scores.masked_fill(mask, -torch.inf) - attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) - else: - attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) - return attn - - -class MultiHeadedAttention(nn.Module): - """Multi-Head Attention layer with optional relative position embedding and GLU. - - Args: - n_head: int - the number of heads. - n_feat: int - input size features. - dropout_rate: float - dropout rate. - use_LN: bool - apply layer norm or not - dropout_at_output: bool - whether to apply dropout at output - attention_inner_dim: int, optional - the attention dimension used in the class, - it can be different from the input dimension n_feat. - default: -1 (equal to n_feat). - use_pt_scaled_dot_product_attention: bool, optional - if set True, use pytorch scaled dot product attention in training. NOTE: this will NOT - be used in ONNX decoding due to a lack of support. In that case, we use the original - attention implementation, which shows no regression. - default: False. - n_value: int, optional - if set to values other than -1, use a different dimension for value. With the default value (i.e. -1), it is backward compatible. - group_size: int, optional. must divide `n_head` - if group_size > 1: GQA - if group_size = 1: MHA - if group_size = n_head: MQA - """ - - inv_sqrt_d_k: torch.jit.Final[float] - h: torch.jit.Final[int] - h_k: torch.jit.Final[int] - g: torch.jit.Final[int] - - def __init__( - self, - n_head, - n_feat, - dropout_rate, - attention_inner_dim=-1, - glu_type="swish", - bias_in_glu=True, - use_pt_scaled_dot_product_attention=False, - n_value=-1, - group_size: int = 1, - ): - super().__init__() - if n_value == -1: - n_value = n_feat - if attention_inner_dim == -1: - attention_inner_dim = n_feat - assert attention_inner_dim % n_head == 0 - - # We assume d_v always equals d_k - self.d_k = attention_inner_dim // n_head - self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k) - self.h = n_head - assert n_head % group_size == 0, "group_size must divide n_head" - self.g = group_size - self.h_k = n_head // group_size - - self.linear_q = nn.Linear(n_feat, attention_inner_dim) - self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size) - self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size) - self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value) - - self.attn = torch.jit.Attribute(None, Optional[Tensor]) - self.dropout = nn.Dropout(p=dropout_rate) - self.dropout_rate = dropout_rate - self.use_pt_scaled_dot_product_attention = use_pt_scaled_dot_product_attention - - if use_pt_scaled_dot_product_attention and group_size > 1: - raise ValueError("Cannot use PT Scaled Attention with GQA") - - # Torchscript eager quantization. Note that these functions below are - # NOOPs and have very little impact on performance unless quantization is - # enabled. 
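# --- Editor's illustrative sketch (not part of the deleted file) ------------------------
# Head bookkeeping for the grouped-query attention (GQA) option above: n_head query heads
# share n_head // group_size key/value heads. The snippet mirrors the einsum used in the
# non-SDPA branch of MultiHeadedAttention.forward; all sizes are demo values.
import torch

batch, time1, time2, d_k = 2, 6, 6, 8
n_head, group_size = 8, 4
h_k = n_head // group_size                       # key/value heads (here 2)

q = torch.randn(batch, n_head, time1, d_k)       # (b, h, t1, d_k)
k = torch.randn(batch, h_k, time2, d_k)          # (b, h_k, t2, d_k)

q = q.reshape(batch, group_size, h_k, time1, d_k)
scores = torch.einsum("bghtd,bhsd->bhts", q, k)  # (b, h_k, t1, t2)
print(scores.shape)                              # torch.Size([2, 2, 6, 6])
# -----------------------------------------------------------------------------------------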
- self.quant_q = torch.ao.quantization.QuantStub() - self.quant_x = torch.ao.quantization.QuantStub() - self.dequant = torch.ao.quantization.DeQuantStub() - self.ffunc = torch.ao.nn.quantized.FloatFunctional() - - def forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_k: Tensor, - pos_v: Tensor, - mask: Optional[Tensor], - relative_attention_bias: Optional[Tensor] = None, - ): - """Compute 'Scaled Dot Product Attention'. - - Args: - query: torch.Tensor - query tensor (batch, time1, size) - key: torch.Tensor - key tensor (batch, time2, size) - value: torch.Tensor - value tensor (batch, time1, size) - pos_k: torch.Tensor - key tensor used for relative positional embedding. - pos_v: torch.Tensor - value tensor used for relative positional embedding. - mask: torch.Tensor - mask tensor (batch, time1, time2) - relative_attention_bias: torch.Tensor - bias added to attention logits w.r.t. relative positions (1, n_head, time1, time2) - """ - n_batch = query.size(0) - - q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) # (b, t, d) - k = self.linear_k(key).view(n_batch, -1, self.h_k, self.d_k) # (b, t, d) - v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k) - q = ( - q.transpose(1, 2) - if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting() - else q.transpose(1, 2) * self.inv_sqrt_d_k - ) - k = k.transpose(1, 2) # (batch, head_k, time2, d_k) - v = v.transpose(1, 2) # (batch, head_k, time2, d_k) - - if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting(): - attn_mask = None - if mask is not None: - mask = mask.unsqueeze(1) - if relative_attention_bias is not None: - attn_mask = mask + relative_attention_bias - else: - attn_mask = mask - if mask.dtype != q.dtype: - attn_mask = attn_mask.to(q.dtype) - - with torch.backends.cuda.sdp_kernel( - enable_flash=True, enable_math=True, enable_mem_efficient=True - ): - x = torch.nn.functional.scaled_dot_product_attention( - q, - k, - v, - attn_mask=attn_mask, - dropout_p=self.dropout_rate, - ) - else: - if self.h != self.h_k: - q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k) - A = torch.einsum("b g h t d, b h s d -> b h t s", q, k) - else: - A = torch.matmul(q, k.transpose(-2, -1)) - if pos_k is not None: - if self.h != self.h_k: - B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k) - else: - reshape_q = ( - q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0, 1) - ) # (t1,nh,dk) - B = torch.matmul(reshape_q, pos_k.transpose(-2, -1)) # pos_k: (t1,dk,t2) - B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1)) - scores = A + B - else: - scores = A - - if relative_attention_bias is not None: - scores = scores + relative_attention_bias - - attn = masked_softmax(scores, mask) # (batch, head, time1, time2) - - self.attn = attn - - p_attn = self.dropout(attn) - x = torch.matmul(p_attn.to(v.dtype), v) # (batch, head, time1, d_k) - if pos_v is not None: - reshape_attn = ( - p_attn.contiguous() - .view(n_batch * self.h, pos_v.size(0), pos_v.size(1)) - .transpose(0, 1) - ) # (t1, bh, t2) - - attn_v = ( - torch.matmul(reshape_attn, pos_v) - .transpose(0, 1) - .contiguous() - .view(n_batch, self.h, pos_v.size(0), self.d_k) - ) - x = x + attn_v - x = ( - x.transpose(1, 2).contiguous().view(n_batch, -1, self.h_k * self.d_k) - ) # (batch, time1, d_model) - - return self.linear_out(x) # (batch, time1, d_model) - - -def unfold_tensor(xs_pad, max_seq_len): - """ - For a given tensor with shape of (N, T, D), if sequence length T is longer than 
max_seq_len, - this function unfold it to a (NT', max_seq_len, D) where T' is T // max_seq_len. - Args: - xs_pad: N, T, D - """ - _, _, D = xs_pad.shape - xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T - # N x D x 1 x T => N x (D x max_seq_len) x T' - xs_pad = F.unfold( - xs_pad[..., None, :], - kernel_size=(1, max_seq_len), - stride=(1, max_seq_len), - ) - - new_bsz, _, slen = xs_pad.shape - # N x D x max_seq_len x T' - xs_pad = xs_pad.view(new_bsz, -1, max_seq_len, slen) - # N x T' x max_seq_len x D - xs_pad = xs_pad.permute(0, 3, 2, 1).contiguous() - # NT' x max_seq_len x D - xs_pad = xs_pad.view(-1, max_seq_len, D) - return xs_pad - -# conformer_encoder.py -class MultiSequential(torch.nn.Sequential): - """Multi-input multi-output torch.nn.Sequential""" - - @torch.jit.ignore - def forward(self, *args): - """Forward method implementation.""" - for m in self: - args = m(*args) - return args - -def repeat(repeat_num, module_gen_fn): - """repeat module N times - - :param int repeat_num: repeat time - :param function module_gen_fn: function to generate module - :return: repeated modules - :rtype: MultiSequential - """ - return MultiSequential(*[module_gen_fn(i) for i in range(repeat_num)]) - -class ConformerEncoderLayer(nn.Module): - """ConformerEncoder Layer module. - for more details see conformer paper: - https://arxiv.org/abs/2005.08100 - This module implement the Conformer block layer. - - Args: - d_model: int - attention dim. - ext_pw_out_channel: int - if > 0, ext_pw_out_channel is a dim channel size - for the last pointwise conv after swish activation. - depthwise_seperable_out_channel: int - if set different to 0, the number of depthwise_seperable_out_channel - will be used as a channel_out of the second conv1d layer. - otherwise, it equal to 0, the second conv1d layer is skipped. - depthwise_multiplier: int - number of input_dim channels duplication. this value - will be used to compute the hidden channels of the Conv1D. - n_head: int - the number of heads for multihead attention module. - d_ffn: int - output size of the feed_forward blocks. - ext_pw_kernel_size: int - kernel size of the conv pointwise of the conformer. - kernel_size: int - kernel size. - dropout_rate: float - dropout rate. - causal: bool, optional - if set to True, convolution have no access - to future frames. default False. - batch_norm: bool, optional - if set to True, apply batchnorm before activation - in ConvModule layer of the conformer. - default False - activation: str, optional - activation function name, - one of ["relu", "swish", "sigmoid"], - sigmoid activation is only used with "glu_in_fnn=True", - default "relu". - chunk_se: int, optional - 0 for offline SE. - 1 for streaming SE, where mean is computed - by accumulated history until current chunk_se. - 2 for streaming SE, where mean is computed - by only the current chunk. - default 0. - chunk_size: int, optional - chunk_size for cnn. default 18 - conv_activation: str, optional - activation function used in ConvModule part - of the conformer, default "relu". - conv_glu_type: str, optional - activation function used for the glu inside - the ConvModule part of the conformer. - default: "sigmoid". - bias_in_glu: bool, optional - if set to True, use additive bias in the weight module - before GLU. - linear_glu_in_convm: bool, optional - if set to True, use GLULinear module, - otherwise, used GLUPointWiseConv module. - default to False. - attention_innner_dim: int, otional - if equal to -1, attention dim for linears k/q/v is - equal to d_model. 
otherwise attention_innner_dim is used. - default -1. - attention_glu_type: str, optional - activation function for glu used in the multihead attention, - default "swish". - activation_checkpointing: str, optional - a dictionarry of {"module","interval","offload"}, where - "module": str - accept ["transformer", "attention"] to select - which module should do activation checkpointing. - "interval": int, default 1, - interval of applying activation checkpointing, - interval = 1 means that we apply checkpointing - on every layer (if activation), otherwise, - we apply it every x interval. - "offload": bool, default False, - if set to True, we offload activation to cpu and - reload it during backward, otherwise, - we recalculate activation in backward. - default "". - export: bool, optional - if set to True, it remove the padding from convolutional layers - and allow the onnx conversion for inference. - default False. - use_pt_scaled_dot_product_attention: bool, optional - if set to True, use pytorch's scaled dot product attention implementation in training. - attn_group_sizes: int, optional - the number of groups to use for attention, default 1 (Multi-Head Attention), - 1 = typical Multi-Head Attention, - 1 < attn_group_sizes < attention_heads = Grouped-Query Attention - attn_group_sizes = attenion_heads = Multi-Query Attention - """ - - def __init__( - self, - d_model=512, - ext_pw_out_channel=0, - depthwise_seperable_out_channel=256, - depthwise_multiplier=1, - n_head=4, - d_ffn=2048, - ext_pw_kernel_size=1, - kernel_size=3, - dropout_rate=0.1, - causal=False, - batch_norm=False, - activation="relu", - chunk_se=0, - chunk_size=18, - conv_activation="relu", - conv_glu_type="sigmoid", - bias_in_glu=True, - linear_glu_in_convm=False, - attention_innner_dim=-1, - attention_glu_type="swish", - activation_checkpointing="", - export=False, - use_pt_scaled_dot_product_attention=False, - attn_group_sizes: int = 1, - ): - super().__init__() - - self.feed_forward_in = FeedForward( - d_model=d_model, - d_inner=d_ffn, - dropout_rate=dropout_rate, - activation=activation, - bias_in_glu=bias_in_glu, - ) - - self.self_attn = encoder_checkpoint_wrapper( - activation_checkpointing, - MultiHeadedAttention, - )( - MultiHeadedAttention( - n_head, - d_model, - dropout_rate, - attention_innner_dim, - attention_glu_type, - bias_in_glu, - use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, - group_size=attn_group_sizes, - ) - ) - self.conv = ConvModule( - d_model, - ext_pw_out_channel, - depthwise_seperable_out_channel, - ext_pw_kernel_size, - kernel_size, - depthwise_multiplier, - dropout_rate, - causal, - batch_norm, - chunk_se, - chunk_size, - conv_activation, - conv_glu_type, - bias_in_glu, - linear_glu_in_convm, - export=export, - ) - - self.feed_forward_out = FeedForward( - d_model=d_model, - d_inner=d_ffn, - dropout_rate=dropout_rate, - activation=activation, - bias_in_glu=bias_in_glu, - ) - - self.layer_norm_att = nn.LayerNorm(d_model) - self.layer_norm = nn.LayerNorm(d_model) - - def forward( - self, - x, - pos_k, - pos_v, - mask, - relative_attention_bias: Optional[Tensor] = None, - ): - """ConformerEncoder forward. - - Args: - x: torch.Tensor - input feature of shape (batch, max_time_in, size) - pos_k: torch.Tensor - positional key embedding. - mask: torch.Tensor - mask for x (batch, max_time_in) - relative_attention_bias: Optional[torch.Tensor] - bias added to attention logits w.r.t. 
relative positions (1, n_head, time1, time2) - """ - x = x + 0.5 * self.feed_forward_in(x) - norm_x = self.layer_norm_att(x) - - x = x + self.self_attn( - norm_x, - norm_x, - norm_x, - pos_k, - pos_v, - mask, - relative_attention_bias=relative_attention_bias, - ) - x = x + self.conv(x) - x = x + 0.5 * self.feed_forward_out(x) - - out = self.layer_norm(x) - - return out, pos_k, pos_v, mask - -class TransformerEncoderBase(abc.ABC, nn.Module): - """The Base class for Transformer based encoders - - Please set causal = True in streaming model - Args: - input_size: int - input feature dimension. - chunk_size: int, list(int) - Number of frames for each chunk - This variable can take 2 forms: - int: Used for inference, or single chunk size training - list(int) : Used only for variable chunk size training - Some examples for the 2 cases: - chunk_size = 12 - chunk_size = [6, 8, 12, 24] - left_chunk: int, list(int) - Number of chunks used for masking in streaming mode. - This variable can take 2 forms: - int: Used for inference, or single chunk size training - list(int) : Used only for variable chunk size training. When - chunk_size is a list, left_chunk must be a list with same length. - Some examples for the 2 cases: - left_chunk = 6 - left_chunk = [12, 9, 6, 3] - attention_dim: int, optional - attention dimension. default 256. - attention_heads: int, optional - the number of heads. default 4 - input_layer: str, optional - input layer type before Conformer, - one of ["linear", "conv2d", "custom", "vgg2l", "embed"], - default "conv2d" - cnn_out: int, optional - the number of CNN channels before Conformer. - default -1. - cnn_layer_norm: bool, optional - layer norm between Conformer and the first CNN. - default False. - time_reduction: int, optional - time reduction factor - default 4 - dropout_rate: float, optional - dropout rate. default 0.1 - padding_idx: int, optional - padding index for input_layer=embed - default -1 - relative_attention_bias_args: dict, optional - use more efficient scalar bias-based relative multihead attention (Q*K^T + B) - implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias - usage: relative_attention_bias_args={"type": t5/alibi} - additional method-specific arguments can be provided (see transformer_base.py) - positional_dropout_rate: float, optional - dropout rate after positional encoding. default 0.0 - nemo_conv_settings: dict, optional - A dictionary of settings for NeMo Subsampling. - default None - conv2d_extra_padding: str, optional - Add extra padding in conv2d subsampling layers. Choices are - (feat, feat_time, none, True). - if True or feat_time, the extra padding is added into non full - supraframe utts in batch. 
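# Illustrative sketch (editor-added, with stand-in submodules) of the residual ordering in
# ConformerEncoderLayer.forward above: half-step feed-forward, self-attention on a pre-normed
# input, a convolution module, a second half-step feed-forward, then a final LayerNorm.
# nn.MultiheadAttention / nn.Conv1d / nn.Linear replace the module's own MultiHeadedAttention,
# ConvModule and FeedForward; only the data flow is shown, not the real submodules or masking.
import torch
from torch import nn


class TinyConformerBlock(nn.Module):
    def __init__(self, d_model: int = 64, n_head: int = 4, kernel_size: int = 3):
        super().__init__()
        self.ffn_in = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.SiLU(), nn.Linear(4 * d_model, d_model)
        )
        self.layer_norm_att = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
        self.conv = nn.Conv1d(d_model, d_model, kernel_size, padding=kernel_size // 2)
        self.ffn_out = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.SiLU(), nn.Linear(4 * d_model, d_model)
        )
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + 0.5 * self.ffn_in(x)                           # macaron half-step FFN
        norm_x = self.layer_norm_att(x)
        x = x + self.attn(norm_x, norm_x, norm_x, need_weights=False)[0]
        x = x + self.conv(x.transpose(1, 2)).transpose(1, 2)   # conv-module residual
        x = x + 0.5 * self.ffn_out(x)                          # second half-step FFN
        return self.layer_norm(x)                              # final post-norm


if __name__ == "__main__":
    y = TinyConformerBlock()(torch.randn(2, 10, 64))
    assert y.shape == (2, 10, 64)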
- Default: none - attention_group_size: int, optional - the number of groups to use for attention, default 1 (Multi-Head Attention), - 1 = typical Multi-Head Attention, - 1 < attention_group_size < attention_heads = Grouped-Query Attention - attention_group_size = attenion_heads = Multi-Query Attention - """ - - def __init__( - self, - input_size, - chunk_size, - left_chunk, - attention_dim=256, - attention_heads=4, - input_layer="nemo_conv", - cnn_out=-1, - cnn_layer_norm=False, - time_reduction=4, - dropout_rate=0.0, - padding_idx=-1, - relative_attention_bias_args=None, - positional_dropout_rate=0.0, - nemo_conv_settings=None, - conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", - attention_group_size=1, - encoder_embedding_config=None, - ): - super().__init__() - self.input_size = input_size - self.input_layer = input_layer - self.chunk_size = chunk_size - self.left_chunk = left_chunk - self.attention_dim = attention_dim - self.num_heads = attention_heads - self.attention_group_size = attention_group_size - self.time_reduction = time_reduction - self.nemo_conv_settings = nemo_conv_settings - self.encoder_embedding_config = encoder_embedding_config - - if self.input_layer == "nemo_conv": - default_nemo_conv_settings = { - "subsampling": "dw_striding", - "subsampling_factor": self.time_reduction, - "feat_in": input_size, - "feat_out": attention_dim, - "conv_channels": 256, - "subsampling_conv_chunking_factor": 1, - "activation": nn.ReLU(), - "is_causal": False, - } - # Override any of the defaults with the incoming, user settings - if nemo_conv_settings: - default_nemo_conv_settings.update(nemo_conv_settings) - for i in ["subsampling_factor", "feat_in", "feat_out"]: - assert ( - i not in nemo_conv_settings - ), "{i} should be specified outside of the NeMo dictionary" - - self.embed = NemoConvSubsampling( - **default_nemo_conv_settings, - ) - else: - raise ValueError("unknown input_layer: " + input_layer) - - self.pos_emb = AbsolutePositionalEncoding(attention_dim, positional_dropout_rate) - - self.relative_attention_bias_type = ( - relative_attention_bias_args.get("type") if relative_attention_bias_args else None - ) - if self.relative_attention_bias_type == "t5": - assert ( - self.num_heads % self.attention_group_size == 0 - ), "attention_group_size must divide n_head" - self.relative_attention_bias_layer = T5RelativeAttentionLogitBias( - self.num_heads // self.attention_group_size, - max_distance=relative_attention_bias_args.get("t5_bias_max_distance", 1000), - symmetric=relative_attention_bias_args.get("t5_bias_symmetric", False), - ) - else: - raise NotImplementedError - - - def post_init(self, init_model_config): - - pretrained_speech_encoder_path = init_model_config.get('pretrained_speech_encoder_path', None) - if pretrained_speech_encoder_path: - model_state = torch.load(pretrained_speech_encoder_path, map_location="cpu") - encoder_state_dict = {} - for k, v in model_state.items(): - if "encoder." 
in k: - tmp_k = k.replace("encoder.", "") - encoder_state_dict[tmp_k] = v - - if hasattr(self, "encoder_embedding"): - del self.encoder_embedding - self.load_state_dict(encoder_state_dict) - - if not hasattr(self, "encoder_embedding"): - self.encoder_embedding = MeanVarianceNormLayer(self.encoder_embedding_config["input_size"]) - - mean_file = init_model_config.get('mean_file', None) - invstd_file = init_model_config.get('invstd_file', None) - if mean_file is not None and invstd_file is not None: - self.encoder_embedding.load_mean_invstd(mean_file, invstd_file) - - def compute_lens_change(self, feature_lens): - """feature_lens: int - return updated feature lens. - - This used to return a different lambda function for each case that computed - the right thing. That does not work within Torchscript. If you really - need this to be faster, create nn.Module()-s for all the cases and return - one of them. Torchscript does support that. - """ - if self.input_layer == "nemo_conv": - # Handle the special causal case - subsampling_causal_cond = self.nemo_conv_settings.get("subsampling", "dw_striding") in [ - "dw_striding", - "striding", - "striding_conv1d", - ] - is_causal = self.nemo_conv_settings.get("is_causal", False) - if is_causal and subsampling_causal_cond: - lens_change = ( - torch.ceil(feature_lens / self.time_reduction).long() - if isinstance(feature_lens, Tensor) - else math.ceil(feature_lens / self.time_reduction) - ) - feature_lens_remainder = feature_lens % self.time_reduction - if isinstance(feature_lens, Tensor): - lens_change[feature_lens_remainder != 1] += 1 - elif feature_lens_remainder != 1: - lens_change += 1 - return lens_change - ceil_func = math.ceil if isinstance(feature_lens, int) else torch.ceil - return ceil_func(feature_lens / self.time_reduction) - - @abc.abstractmethod - def forward(self): - """Abstract forward method implementation.""" - - def _chunk_size_selection(self, chunk_size=None, left_chunk=None): - """If chunk size is a list, we will randomly select a chunk size.""" - - if chunk_size is None: - chunk_size = self.chunk_size - if left_chunk is None: - left_chunk = self.left_chunk - if isinstance(chunk_size, list): - # Variable chunk size during training - chunk_size_index = int(torch.randint(low=0, high=len(chunk_size), size=(1,))) - chunk_size_train_eff = chunk_size[chunk_size_index] - if not isinstance(left_chunk, list): - raise ValueError("Since chunk_size is a list, left_chunk must be a list") - if len(left_chunk) != len(chunk_size): - raise ValueError( - "The length of left_chunk must be the same as length of chunk_size." 
- ) - left_chunk_train_eff = left_chunk[chunk_size_index] - else: - chunk_size_train_eff = chunk_size - left_chunk_train_eff = left_chunk - - return chunk_size_train_eff, left_chunk_train_eff - - def _get_embed_class(self, embed): - # pylint: disable=protected-access - is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper) - is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel) - embed_class = embed - if is_embed_using_act_chkpt: - embed_class = embed._checkpoint_wrapped_module - if is_embed_fsdp_wrapped: - embed_class = embed.module - return embed_class - - def _forward_embeddings_core(self, input_tensor, masks): - embed_class = self._get_embed_class(self.embed) - assert isinstance(embed_class, NemoConvSubsampling) - input_tensor, masks = self.embed(input_tensor, masks) - return input_tensor, masks - - def _position_embedding(self, input_tensor): - pos_k = None - pos_v = None - if self.relative_attention_bias_layer is None: - input_tensor = self.pos_emb(input_tensor) # default to add abs sinusoid embedding - return pos_k, pos_v - - def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): - chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection( - chunk_size, left_chunk - ) - - # Create mask matrix for streaming - # S stores start index. if chunksize is 18, s is [0,18,36,....] - chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff) - # avoid randomness when run evaluation or decoding - if self.training and np.random.rand() > 0.5: - # Either first or last chunk is not complete. - # If only the last one is not complete, EOS is not effective - chunk_start_idx = seq_len - chunk_start_idx - chunk_start_idx = chunk_start_idx[::-1] - chunk_start_idx = chunk_start_idx[:-1] - chunk_start_idx = np.insert(chunk_start_idx, 0, 0) - - enc_streaming_mask = ( - adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk_train_eff) - .unsqueeze(0) - .expand([batch_size, -1, -1]) - ) - return enc_streaming_mask - - def forward_embeddings(self, xs_pad, masks, chunk_size_nc=None, left_chunk_nc=None): - """Forwarding the inputs through the top embedding layers - - Args: - xs_pad: torch.Tensor - input tensor - masks: torch.Tensor - input mask - chunk_size_nc: (optional, default is None) chunk size for non-causal layers - left_chunk_nc: (optional, default is None) # of left chunks for non-causal layers - """ - # pylint: disable=R0915 - # get new lens. - seq_len = int(self.compute_lens_change(xs_pad.shape[1])) - if seq_len <= 0: - raise ValueError( - f"""The squence length after time reduction is invalid: {seq_len}. - Your input feature is too short. 
Consider filtering out the very - short sentence from data loader""", - ) - - batch_size = xs_pad.shape[0] - - enc_streaming_mask = self._streaming_mask( - seq_len, batch_size, self.chunk_size, self.left_chunk - ) - - if xs_pad.device != "cpu": - enc_streaming_mask = enc_streaming_mask.to(xs_pad.device) - - input_tensor = xs_pad - input_tensor, masks = self._forward_embeddings_core(input_tensor, masks) - - streaming_mask = enc_streaming_mask - if streaming_mask is not None and masks is not None: - hs_mask = masks & streaming_mask - elif masks is not None: - hs_mask = masks - else: - hs_mask = streaming_mask - - if chunk_size_nc is not None: - enc_streaming_mask_nc = self._streaming_mask( - seq_len, batch_size, chunk_size_nc, left_chunk_nc - ) - if xs_pad.device != "cpu": - enc_streaming_mask_nc = enc_streaming_mask_nc.to(xs_pad.device) - if masks is not None: - hs_mask_nc = masks & enc_streaming_mask_nc - else: - hs_mask_nc = enc_streaming_mask_nc - else: - hs_mask_nc = None - - pos_k, pos_v = self._position_embedding(input_tensor) - - if chunk_size_nc is None: - return input_tensor, pos_k, pos_v, hs_mask, masks - return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc - - def get_offset(self): - """Returns offset used when retaining inputs for decoding. - - This is essentially, how many additional frames have to be added to - the front-end CNN input to ensure it can produce a single output. - So if the "padding" parameter is 0, typically offset will be > 0. - """ - return get_offset(self.input_layer, self.time_reduction) - - -def get_offset(input_layer: str, time_reduction: int): - """Get an offset. We will use the offset for determining #frames of a subsampled feature. - - Args: - input_layer (str): Type of an input layer - time_reduction (int): time reduction factor for downsampling a feature - Returns: - int: offset - """ - if input_layer in ("conv2d", "nemo_conv") and time_reduction == 4: - return 3 - if input_layer in ("conv2d",) and time_reduction == 6: - return 1 - if input_layer in ("conv2d", "nemo_conv") and time_reduction == 8: - return 7 - return 0 - - -class ConformerEncoder(TransformerEncoderBase): - """ConformerEncoder module. - see original paper for more details: - https://arxiv.org/abs/2005.08100 - - Please set causal = True in streaming model - Args: - input_size: int - input feature dimension. - chunk_size: int, list(int) - Number of frames for each chunk - This variable can take 2 forms: - int: Used for inference, or single chunk size training - list(int) : Used only for variable chunk size training - Some examples for the 2 cases: - chunk_size = 12 - chunk_size = [6, 8, 12, 24] - left_chunk: int, list(int) - Number of chunks used for masking in streaming mode. - This variable can take 2 forms: - int: Used for inference, or single chunk size training - list(int) : Used only for variable chunk size training. When - chunk_size is a list, left_chunk must be a list with same length. - Some examples for the 2 cases: - left_chunk = 6 - left_chunk = [12, 9, 6, 3] - left_chunk: int - number of chunks used for masking in streaming mode. - num_lang: int - This parameter is used to store the number of languages in the lang_dict, - only used for multiseed/multilingual models. default None. - attention_dim: int, optional - attention dimension. default 256. - attention_heads: int, optional - the number of heads. default 4 - linear_units: - the number of units of position-wise feed forward. - default 2048 - num_block: - number of Transformer layer. 
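# Illustrative sketch (editor-added) of how forward_embeddings and calculate_hs_mask combine
# masks: a per-frame padding mask derived from lengths, shape (batch, 1, time), is AND-ed with a
# chunk-wise streaming mask of shape (batch, time, time). toy_streaming_mask is a simplified,
# uniform-chunk stand-in for adaptive_enc_mask (defined elsewhere in this file), assuming
# left_window counts whole chunks of look-back.
import torch


def toy_streaming_mask(seq_len: int, chunk_size: int, left_window: int) -> torch.Tensor:
    # Frame i may attend to frames whose chunk index lies within `left_window` chunks to the
    # left of (and including) its own chunk.
    chunk_idx = torch.arange(seq_len) // chunk_size               # (T,)
    q_chunk = chunk_idx.unsqueeze(1)                              # (T, 1)
    k_chunk = chunk_idx.unsqueeze(0)                              # (1, T)
    return (k_chunk <= q_chunk) & (k_chunk >= q_chunk - left_window)  # (T, T) bool


if __name__ == "__main__":
    seq_len, batch_size = 12, 2
    feature_lens = torch.tensor([12, 9])
    # Padding mask from lengths, (batch, 1, time), as built in calculate_hs_mask.
    pad_mask = (
        torch.arange(seq_len).expand(batch_size, -1) < feature_lens.unsqueeze(1)
    ).unsqueeze(1)
    # Streaming mask expanded over the batch, (batch, time, time).
    streaming_mask = (
        toy_streaming_mask(seq_len, chunk_size=4, left_window=1)
        .unsqueeze(0)
        .expand(batch_size, -1, -1)
    )
    hs_mask = pad_mask & streaming_mask                           # broadcasts to (batch, time, time)
    assert hs_mask.shape == (batch_size, seq_len, seq_len)
    # Padded key positions (frames >= 9 in the second utterance) are masked out everywhere.
    assert not hs_mask[1, :, 9:].any()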
default 6 - dropout_rate: float, optional - dropout rate. default 0.1 - input_layer: str, optional - input layer type before Conformer, - one of ["linear", "conv2d", "custom", "vgg2l", "embed"], - default "conv2d" - causal: bool, optional - if set to True, convolution have no access - to future frames. default False. - batch_norm: bool, optional - if set to True, apply batchnorm before activation - in ConvModule layer of the conformer. - default False - cnn_out: int, optional - the number of CNN channels before Conformer. - default -1. - cnn_layer_norm: bool, optional - layer norm between Conformer and the first CNN. - default False. - ext_pw_out_channel: int, optional - the number of channel for CNN - before depthwise_seperable_CNN. - If 0 then use linear. default 0. - ext_pw_kernel_size: int, optional - kernel size of N before depthwise_seperable_CNN. - only work for ext_pw_out_channel > 0. - default 1 - depthwise_seperable_out_channel: int, optional - the number of channel for - depthwise_seperable_CNN. - default 256. - depthwise_multiplier: int, optional - the number of multiplier for - depthwise_seperable_CNN. - default 1. - chunk_se: int, optional - 0 for offline SE. - 1 for streaming SE, where mean is computed - by accumulated history until current chunk_se. - 2 for streaming SE, where mean is computed - by only the current chunk. - default 0. - kernel_size: int, optional - the number of kernels for depthwise_seperable_CNN. - default 3. - activation: str, optional - FeedForward block activation. - one of ["relu", "swish", "sigmoid"] - default "relu". - conv_activation: str, optional - activation function used in ConvModule part - of the conformer, default "relu". - conv_glu_type: str, otional - activation used use glu in depthwise_seperable_CNN, - default "sigmoid" - bias_in_glu: bool, optional - if set to True, use additive bias in the weight module - before GLU. default True - linear_glu_in_convm: bool, optional - if set to True, use GLULinear module, - otherwise, used GLUPointWiseConv module. - default to False. - attention_glu_type: str - only work for glu_in_attention !=0 - default "swish". - export: bool, optional - if set to True, it remove the padding from convolutional layers - and allow the onnx conversion for inference. - default False. - activation_checkpointing: str, optional - a dictionarry of {"module","interval","offload"}, where - "module": str - accept ["transformer", "attention"] to select - which module should do activation checkpointing. - "interval": int, default 1, - interval of applying activation checkpointing, - interval = 1 means that we apply checkpointing - on every layer (if activation), otherwise, - we apply it every x interval. - "offload": bool, default False, - if set to True, we offload activation to cpu and - reload it during backward, otherwise, - we recalculate activation in backward. - default "". - extra_layer_output_idx: int - the layer index to be exposed. - relative_attention_bias_args: dict, optional - use more efficient scalar bias-based relative multihead attention (Q*K^T + B) - implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias - usage: relative_attention_bias_args={"type": t5/alibi} - additional method-specific arguments can be provided (see transformer_base.py) - time_reduction: int optional - time reduction factor - default 4 - use_pt_scaled_dot_product_attention: whether to use pytorch scaled dot product attention - in training. 
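# Illustrative sketch (editor-added) of the attention_group_size option documented in this
# docstring: group_size 1 is ordinary Multi-Head Attention, group_size equal to attention_heads
# is Multi-Query Attention, and values in between give Grouped-Query Attention with
# n_head // group_size shared key/value heads. This is the textbook sharing pattern with random
# tensors and no learned projections; it is not the module's exact internal code path.
import torch

batch, time1, time2, d_k = 2, 7, 9, 16
n_head, group_size = 8, 2                  # attention_group_size = 2
h_k = n_head // group_size                 # 4 shared key/value heads (h_k == 1 would be MQA)

q = torch.randn(batch, n_head, time1, d_k)
k = torch.randn(batch, h_k, time2, d_k)
v = torch.randn(batch, h_k, time2, d_k)

# Each group of `group_size` consecutive query heads shares one key/value head.
k_shared = k.repeat_interleave(group_size, dim=1)          # (batch, n_head, time2, d_k)
v_shared = v.repeat_interleave(group_size, dim=1)

scores = q @ k_shared.transpose(-2, -1) / d_k ** 0.5       # (batch, n_head, time1, time2)
context = scores.softmax(dim=-1) @ v_shared                # (batch, n_head, time1, d_k)
assert context.shape == (batch, n_head, time1, d_k)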
- Default: False - nemo_conv_settings: dict, optional - A dictionary of settings for NeMo Subsampling. - default: None - usage: nemo_conv_settings= - { - "subsampling": - dw_striding/striding/dw_striding_conv1d/striding_conv1d, - "conv_channels": int, - "subsampling_conv_chunking_factor": int, - "is_causal": True/False - } - conv2d_extra_padding: str, optional - Add extra padding in conv2d subsampling layers. Choices are - (feat, feat_time, none, True) - Default: none - replication_pad_for_subsample_embedding: For batched-streaming decoding, use - "replication" padding for the cache at start of utterance. - Default: False - attention_group_size: int, optional - the number of groups to use for attention, default 1 (Multi-Head Attention), - 1 = typical Multi-Head Attention, - 1 < attention_group_size < attention_heads = Grouped-Query Attention - attention_group_size = attenion_heads = Multi-Query Attention - """ - - extra_multi_layer_output_idxs: List[int] - - def __init__( # pylint: disable-all - self, - input_size, - chunk_size, - left_chunk, - num_lang=None, - attention_dim=256, - attention_heads=4, - linear_units=2048, - num_blocks=6, - dropout_rate=0.1, - input_layer="nemo_conv", - causal=True, - batch_norm=False, - cnn_out=-1, - cnn_layer_norm=False, - ext_pw_out_channel=0, - ext_pw_kernel_size=1, - depthwise_seperable_out_channel=256, - depthwise_multiplier=1, - chunk_se=0, - kernel_size=3, - activation="relu", - conv_activation="relu", - conv_glu_type="sigmoid", - bias_in_glu=True, - linear_glu_in_convm=False, - attention_glu_type="swish", - export=False, - extra_layer_output_idx=-1, - extra_multi_layer_output_idxs=[], - activation_checkpointing="", - relative_attention_bias_args=None, - time_reduction=4, - use_pt_scaled_dot_product_attention=False, - nemo_conv_settings=None, - conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", - replication_pad_for_subsample_embedding=False, - attention_group_size=1, - encoder_embedding_config=None, - ): - super().__init__( - input_size, - chunk_size, - left_chunk, - attention_dim, - attention_heads, - input_layer, - cnn_out, - cnn_layer_norm, - time_reduction, - dropout_rate=dropout_rate, - relative_attention_bias_args=relative_attention_bias_args, - positional_dropout_rate=0.0, - nemo_conv_settings=nemo_conv_settings, - conv2d_extra_padding=conv2d_extra_padding, - attention_group_size=attention_group_size, - encoder_embedding_config=encoder_embedding_config, - ) - self.num_blocks = num_blocks - self.num_lang = num_lang - self.kernel_size = kernel_size - self.embed = embedding_checkpoint_wrapper(activation_checkpointing)(self.embed) - self.replication_pad_for_subsample_embedding: bool = replication_pad_for_subsample_embedding - assert self.num_heads % attention_group_size == 0, "attention_group_size must divide n_head" - self.num_heads_k = self.num_heads // attention_group_size - - self.encoders = repeat( - num_blocks, - lambda i: encoder_checkpoint_wrapper( - activation_checkpointing, ConformerEncoderLayer, i - )( - ConformerEncoderLayer( - d_model=attention_dim, - ext_pw_out_channel=ext_pw_out_channel, - depthwise_seperable_out_channel=depthwise_seperable_out_channel, - depthwise_multiplier=depthwise_multiplier, - n_head=attention_heads, - d_ffn=linear_units, - ext_pw_kernel_size=ext_pw_kernel_size, - kernel_size=kernel_size, - dropout_rate=dropout_rate, - causal=causal, - batch_norm=batch_norm, - activation=activation, - chunk_se=chunk_se, - chunk_size=chunk_size, - conv_activation=conv_activation, - 
conv_glu_type=conv_glu_type, - bias_in_glu=bias_in_glu, - linear_glu_in_convm=linear_glu_in_convm, - attention_glu_type=attention_glu_type, - activation_checkpointing=attn_checkpointing(activation_checkpointing, i), - export=export, - use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, - attn_group_sizes=attention_group_size, - ) - ), - ) - self.extra_layer_output_idx = extra_layer_output_idx - self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs - # Make a zeros scalar we can use in get_initial_state to determine - # the device and the needed dtype: - self.register_buffer("dev_type", torch.zeros(()), persistent=False) - - def init_relative_attention_bias(self, input_tensor): - if self.relative_attention_bias_layer: - return self.relative_attention_bias_layer(input_tensor) - - def calculate_hs_mask(self, xs_pad, device, mask): - max_audio_length = xs_pad.shape[1] - batch_size = xs_pad.shape[0] - enc_streaming_mask = self._streaming_mask( - max_audio_length, batch_size, self.chunk_size, self.left_chunk - ) - enc_streaming_mask = enc_streaming_mask.to(device) - if mask is None: - return enc_streaming_mask - - feature_lens = mask.sum(1) - padding_length = feature_lens - pad_mask = ( - torch.arange(0, max_audio_length, device=device).expand(padding_length.size(0), -1) - < padding_length.unsqueeze(1) - ) - pad_mask = pad_mask.unsqueeze(1) - pad_mask = pad_mask & enc_streaming_mask - return pad_mask - - @torch.jit.ignore - def forward(self, xs_pad, masks): - """Conformer Forward function - - Args: - xs_pad: torch.Tensor - input tensor - masks: torch.Tensor - post-embedding input lengths - """ - xs_pad = self.encoder_embedding(xs_pad) - input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings(xs_pad, masks) - - unfolded = False - ori_bz, seq_len, D = input_tensor.shape - max_seq_len = 500 #maxium position for absolute positional encoding - if seq_len > max_seq_len: - # audio sequence is longer than max_seq_len, unfold it into chunks of max_seq_len - unfolded = True - # the unfold op will drop residual frames, pad it to the multiple of max_seq_len - if seq_len % max_seq_len > 0: - chunk_pad_size = max_seq_len - (seq_len % max_seq_len) - else: - chunk_pad_size = 0 - if chunk_pad_size > 0: - input_tensor_pad = F.pad(input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0) - input_tensor = input_tensor_pad.to(input_tensor.device) - - input_tensor = unfold_tensor(input_tensor, max_seq_len) - if masks is not None: - # revise hs_mask here because the previous calculated hs_mask did not consider extra pad - subsampled_pad_mask = masks.squeeze(1) # [bz, subsampled_unmask_seq_len] - extra_padded_subsamlped_pad_mask = F.pad(subsampled_pad_mask, (0, chunk_pad_size), "constant", False) # extra padding to the pad mask - extra_padded_subsamlped_pad_mask = extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() - masks_unfold = unfold_tensor(extra_padded_subsamlped_pad_mask, max_seq_len) # unfold the pad mask like we did to the input tensor - masks_unfold = masks_unfold.squeeze(-1).bool() # unfold op does not support bool tensor - else: - masks_unfold = None - hs_mask = self.calculate_hs_mask(input_tensor, input_tensor.device, masks_unfold) # calculate hs_mask based on the unfolded pad mask - layer_emb = None - - relative_attention_bias = self.init_relative_attention_bias(input_tensor) - - _simplified_path = ( - self.extra_layer_output_idx == -1 - and relative_attention_bias is None - ) - - if _simplified_path: - input_tensor, *_ = self.encoders(input_tensor, pos_k, 
pos_v, hs_mask)
-        else:
-            for i, layer in enumerate(self.encoders):
-                input_tensor, _, _, _ = layer(
-                    input_tensor,
-                    pos_k,
-                    pos_v,
-                    hs_mask,
-                    relative_attention_bias=relative_attention_bias,
-                )
-
-                if i == self.extra_layer_output_idx:
-                    layer_emb = input_tensor
-
-        if unfolded:
-            embed_dim = input_tensor.shape[-1]
-            input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim)
-            # If we padded before unfolding, remove that padding here.
-            if chunk_pad_size > 0:
-                input_tensor = input_tensor[:, :-chunk_pad_size, :]
-        return input_tensor, masks  # , layer_emb
-
-    def gradient_checkpointing_enable(self):
-        """No-op stub: activation checkpointing is configured through the
-        `activation_checkpointing` constructor argument instead."""
-        pass
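# Illustrative sketch (editor-added) of the long-input handling in ConformerEncoder.forward above:
# sequences longer than the 500-frame absolute-position limit are right-padded to a multiple of
# 500, unfolded into (N*T', 500, D) chunks, encoded, reshaped back to (N, T_padded, D), and the
# padding is stripped. An identity function stands in for self.encoders; chunk_process_restore is
# a hypothetical helper name, not an API of this module.
import torch
import torch.nn.functional as F

MAX_SEQ_LEN = 500  # maximum position handled by the absolute positional encoding


def chunk_process_restore(x: torch.Tensor, encoder=lambda t: t) -> torch.Tensor:
    ori_bz, seq_len, dim = x.shape
    if seq_len <= MAX_SEQ_LEN:
        return encoder(x)
    # Pad the time axis so it divides evenly into MAX_SEQ_LEN-sized chunks.
    chunk_pad_size = (MAX_SEQ_LEN - seq_len % MAX_SEQ_LEN) % MAX_SEQ_LEN
    x = F.pad(x, (0, 0, 0, chunk_pad_size), "constant", 0)
    # (N, T_pad, D) -> (N * T_pad // MAX_SEQ_LEN, MAX_SEQ_LEN, D), matching unfold_tensor.
    x = x.reshape(ori_bz, -1, MAX_SEQ_LEN, dim).reshape(-1, MAX_SEQ_LEN, dim)
    x = encoder(x)
    # Fold the chunks back and drop the padding added above.
    x = x.reshape(ori_bz, -1, dim)
    if chunk_pad_size > 0:
        x = x[:, :-chunk_pad_size, :]
    return x


if __name__ == "__main__":
    feats = torch.randn(2, 1234, 80)
    out = chunk_process_restore(feats)      # identity "encoder": output equals input
    assert out.shape == feats.shape
    assert torch.equal(out, feats)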