# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

#!/usr/bin/env python3

# activation_checkpointing.py
"""helper function for activation checkpointing"""

from typing import Union, Dict, Callable
from functools import partial
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    offload_wrapper,
    CheckpointImpl,
)


# utils.py
"""cascade basic blocks"""

import math
import backoff
import random
import numpy as np
from typing import Optional, Tuple, Union
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F


# conformer_encoder.py
"""ConformerEncoder Module"""

from typing import Optional, Tuple, List, Literal
import abc
import math
import numpy as np

import torch
from torch import nn, Tensor

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel


# activation_checkpointing.py
def validate_checkpointing_config(activation_checkpointing):
    """Check that an activation-checkpointing setting has an accepted form.

    Args:
        activation_checkpointing: either one of the strings ``""``, ``"checkpoint"``,
            ``"offload"``, or a dict whose optional ``"module"`` entry (default
            ``"transformer"``) is ``"transformer"`` or ``"attention"``.

    Raises:
        AssertionError: the str/dict has a value outside the accepted set.
        ValueError: the setting is neither a str nor a dict.
    """
    if isinstance(activation_checkpointing, str):
        accepted_modes = ("", "checkpoint", "offload")
        assert activation_checkpointing in accepted_modes, (
            "activation_checkpointing has to be a dict or a str in ('', 'checkpoint', 'offload')."
        )
        return

    if isinstance(activation_checkpointing, dict):
        module_kind = activation_checkpointing.get("module", "transformer")
        assert module_kind in ("transformer", "attention"), (
            "module in activation_checkpointing has to be in ('transformer', 'attention')."
        )
        return

    raise ValueError("activation_checkpointing has to be a str or dict.")


def embedding_checkpoint_wrapper(
    activation_checkpointing: Union[str, Dict],
) -> Callable:
    """Choose the activation-checkpoint wrapper for the encoder embedding.

    Args:
        activation_checkpointing: ``""`` (disabled), ``"checkpoint"``, ``"offload"``,
            or a dict. In the dict form the embedding is wrapped only when
            ``"embed"`` is truthy; ``"offload"`` (default False) selects the
            offload wrapper, otherwise ``"reentrant"`` (default False here)
            chooses between ``CheckpointImpl.REENTRANT`` and ``NO_REENTRANT``.

    Returns:
        A callable taking a module and returning it wrapped, or the module
        unchanged when checkpointing does not apply.

    Raises:
        AssertionError / ValueError: propagated from
            ``validate_checkpointing_config`` for malformed settings.
    """
    validate_checkpointing_config(activation_checkpointing)

    if isinstance(activation_checkpointing, str):
        if not activation_checkpointing:
            return lambda x: x
        if activation_checkpointing == "offload":
            return offload_wrapper
        return partial(checkpoint_wrapper)

    if not isinstance(activation_checkpointing, dict):
        raise ValueError("Invalid activation_checkpointing config")

    if not activation_checkpointing.get("embed", False):
        return lambda x: x
    if activation_checkpointing.get("offload", False):
        return offload_wrapper

    use_reentrant = activation_checkpointing.get("reentrant", False)
    impl = CheckpointImpl.REENTRANT if use_reentrant else CheckpointImpl.NO_REENTRANT
    return partial(checkpoint_wrapper, checkpoint_impl=impl)


def encoder_checkpoint_wrapper(
    activation_checkpointing: Union[str, Dict],
    layer_cls: type,
    idx: int = 0,
) -> Callable:
    """Choose the activation-checkpoint wrapper for one encoder layer.

    Args:
        activation_checkpointing: ``""`` (disabled), ``"checkpoint"``, ``"offload"``,
            or a dict with optional keys ``"module"`` (``"transformer"`` or
            ``"attention"``, default ``"transformer"``), ``"interval"`` (default 1),
            ``"offload"`` (default False) and ``"reentrant"`` (default True here).
        layer_cls: class of the layer being considered; compared by ``__name__``
            against the class names associated with ``"module"``.
        idx: index of the layer; in the dict form only layers with
            ``idx % interval == 0`` are wrapped.

    Returns:
        A callable taking a module and returning it wrapped, or the module
        unchanged when checkpointing does not apply.

    Raises:
        AssertionError / ValueError: propagated from
            ``validate_checkpointing_config`` for malformed settings.
    """
    validate_checkpointing_config(activation_checkpointing)

    if isinstance(activation_checkpointing, str):
        if not activation_checkpointing:
            return lambda x: x
        if activation_checkpointing == "offload":
            return offload_wrapper
        return partial(checkpoint_wrapper)

    if not isinstance(activation_checkpointing, dict):
        raise ValueError("Invalid activation_checkpointing config")

    module_kind = activation_checkpointing.get("module", "transformer")
    # Map the configured module kind to the layer class names it targets; any
    # other value is used verbatim.
    wrapped_names = {
        "transformer": ("EncoderLayer", "ConformerEncoderLayer"),
        "attention": ("MultiHeadedAttention", "MultiHeadAttention"),
    }.get(module_kind.lower(), module_kind)
    interval = activation_checkpointing.get("interval", 1)
    use_offload = activation_checkpointing.get("offload", False)
    use_reentrant = activation_checkpointing.get("reentrant", True)
    impl = CheckpointImpl.REENTRANT if use_reentrant else CheckpointImpl.NO_REENTRANT

    if idx % interval == 0 and layer_cls.__name__ in wrapped_names:
        if use_offload:
            return offload_wrapper
        return partial(checkpoint_wrapper, checkpoint_impl=impl)
    return lambda x: x


def attn_checkpointing(activation_checkpointing: Union[str, Dict], i) -> Union[str, Dict]:
    """Derive the checkpointing setting to hand to the attention sub-layer of layer ``i``.

    Args:
        activation_checkpointing: the encoder-level setting (str or dict).
        i: index of the encoder layer.

    Returns:
        The dict itself when it targets ``"attention"`` and ``i`` is a multiple of
        its ``"interval"`` (default 1); otherwise ``""`` (no checkpointing).

    Raises:
        ValueError: the setting is neither a str nor a dict.
    """
    if isinstance(activation_checkpointing, str):
        return ""
    if not isinstance(activation_checkpointing, dict):
        raise ValueError("Invalid activation_checkpointing config")

    targets_attention = activation_checkpointing.get("module", "transformer") == "attention"
    interval = activation_checkpointing.get("interval", 1)
    if targets_attention and i % interval == 0:
        return activation_checkpointing
    return ""


# utils.py
class Block(nn.Module):
    """Abstract base module that records its input and output sizes.

    Args:
        input_size: size of the block input, stored as ``self.input_size``.
        output_size: size of the block output, stored as ``self.output_size``.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size, self.output_size = input_size, output_size

def get_activation(name="relu"):
    """Build an activation module from its (case-insensitive) name.

    Args:
        name: str
            one of "relu" (in-place ReLU), "gelu", "swish", "sigmoid";
            default "relu". Any other name yields ``nn.Identity``.

    Returns:
        A freshly constructed ``nn.Module``.
    """
    # Factories (not instances) so each call returns a new module; the swish
    # entry is a lambda so ``Swish`` is only looked up when actually requested.
    factories = {
        "relu": lambda: nn.ReLU(inplace=True),
        "gelu": nn.GELU,
        "swish": lambda: Swish(),
        "sigmoid": torch.nn.Sigmoid,
    }
    return factories.get(name.lower(), nn.Identity)()

def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0):
    """Build a chunk-based attention mask for streaming (chunked) Transformer Transducers.

    Frames are grouped into chunks starting at the indices in ``chunk_start_idx``.
    A frame may attend to every frame of its own chunk, of the ``left_window``
    chunks before it and of the ``right_window`` chunks after it.

    Args:
        x_len (int): sequence length.
        chunk_start_idx (list): first index of each chunk, e.g. [0, 18, 36, 48];
            chunk sizes may vary, e.g. [0, 10, 15, 45].
        left_window (int): how many previous chunks each frame can see.
        right_window (int): how many following chunks each frame can see
            (used for chunk-overlap models).

    Returns:
        torch.Tensor: boolean mask of shape (x_len, x_len); entry [i, j] is True
        when frame i may attend to frame j. For example, x_len=4,
        chunk_start_idx=[0, 2] and no windows gives:
            tensor([[ True,  True, False, False],
                    [ True,  True, False, False],
                    [False, False,  True,  True],
                    [False, False,  True,  True]])
    """
    starts = torch.Tensor(chunk_start_idx).long()
    num_chunks = len(chunk_start_idx)
    # Chunk edges: [0, s0, s1, ...] and [s0, s1, ..., x_len].
    lower_edges = torch.nn.functional.pad(starts, (1, 0))
    upper_edges = torch.nn.functional.pad(starts, (0, 1), value=x_len)

    frame_pos = torch.arange(0, x_len).unsqueeze(-1)  # [x_len, 1]
    in_chunk = (frame_pos >= lower_edges) & (frame_pos < upper_edges)
    chunk_of_frame = in_chunk.nonzero()[:, 1]  # [x_len]

    key_pos = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1)  # [x_len, x_len]
    first_visible_chunk = (chunk_of_frame - left_window).clamp(min=0)
    last_visible_chunk = (chunk_of_frame + right_window).clamp(max=num_chunks)

    visible_from = lower_edges[first_visible_chunk].unsqueeze(-1)
    visible_until = upper_edges[last_visible_chunk].unsqueeze(-1)
    return (key_pos >= visible_from) & (key_pos < visible_until)

class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``.

    Reference: https://arxiv.org/pdf/2005.03191.pdf
    """

    def __init__(self) -> None:
        super().__init__()
        self.act_fn = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        """Apply Swish element-wise.

        Args:
            x: torch.Tensor
                Input of any shape.
        """
        gate = self.act_fn(x)
        return x * gate

class GLU(nn.Module):
    """Gated Linear Unit (GLU) module.

    Splits the input into two halves along ``dim`` and returns
    ``first_half * act(second_half)``.

    Args:
        dim: int
            dimension along which the input is split in two; default -1.
        act_name: str
            gate activation, one of "relu", "gelu", "swish", "sigmoid"
            (case-insensitive); any other name means no activation (identity).
            Default "sigmoid".
    """

    def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None:
        super().__init__()
        self.dim = dim
        self.act_name = act_name.lower()

        if self.act_name == "relu":
            # Not in-place: the gate is a view produced by ``chunk``. An in-place
            # ReLU would overwrite part of the caller's input tensor, and autograd
            # rejects in-place ops on outputs of multi-view functions.
            self.act_fn = nn.ReLU()
        elif self.act_name == "gelu":
            self.act_fn = nn.GELU()
        elif self.act_name == "swish":
            self.act_fn = Swish()
        elif self.act_name == "sigmoid":
            self.act_fn = nn.Sigmoid()
        else:
            self.act_fn = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """GLU forward.

        Multiplies the first half of ``x`` (along ``self.dim``) by the activation
        of the second half.

        Args:
            x: torch.Tensor
                Input; its size along ``self.dim`` is expected to be even.

        Returns:
            torch.Tensor with half the size of ``x`` along ``self.dim``.
        """
        half_x, gate = x.chunk(2, dim=self.dim)
        return half_x * self.act_fn(gate)

# TODO: Abdel, this can be improved using GLU module
class GLUPointWiseConv(nn.Module):
    """GLUPointWiseConv module
    used for conformer architecture,
    for more details see:
    https://arxiv.org/pdf/2005.08100v1.pdf

    A Conv1d produces ``2 * output_dim`` channels; the first half is multiplied
    by the activation of the second half (optionally with additive biases).

    Args:
        input_dim: int
            input channel size.
        output_dim: int
            output channel size.
        kernel_size: int
            kernel size
        glu_type: str, optional
            gate activation, one of
             ["sigmoid", "relu", "gelu", "swish", "bilinear"],
             default "sigmoid". "bilinear" multiplies the two halves with
             no activation.
        bias_in_glu: bool, optional
            use additive biases on both halves before gating.
        causal: bool, optional
            if True, the convolution is padded by kernel_size - 1 on both sides,
             so the output is kernel_size - 1 frames longer than the input; the
             caller must drop the last kernel_size - 1 frames so that no output
             sees future frames. If False, padding is (kernel_size - 1) // 2.
             default False.

    Raises:
        ValueError: ``glu_type`` is not one of the supported names.
    """

    def __init__(
        self, input_dim, output_dim, kernel_size, glu_type="sigmoid", bias_in_glu=True, causal=False
    ):
        super().__init__()

        self.glu_type = glu_type
        self.output_dim = output_dim
        self.bias_in_glu = bias_in_glu
        padding = (kernel_size - 1) if causal else (kernel_size - 1) // 2
        self.ext_pw_conv_1d = nn.Conv1d(
            input_dim, output_dim * 2, kernel_size, 1, padding=padding
        )

        if glu_type == "sigmoid":
            self.glu_act = nn.Sigmoid()
        elif glu_type == "relu":
            self.glu_act = nn.ReLU()
        elif glu_type == "gelu":
            self.glu_act = nn.GELU()
        elif glu_type == "swish":
            self.glu_act = Swish()
        elif glu_type == "bilinear":
            # forward() multiplies the two halves directly; no gate activation.
            self.glu_act = nn.Identity()
        else:
            # Report the offending argument; ``self.glu_act`` does not exist here.
            raise ValueError(f"Unsupported activation type {glu_type}")

        if bias_in_glu:
            self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1))
            self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1))

    def forward(self, x):
        """
        Args:
            x: torch.Tensor
                input tensor of shape (batch, time, input_dim); channels are
                last to be consistent with GLULinear.

        Returns:
            torch.Tensor of shape (batch, time', output_dim), where time' is
            time + kernel_size - 1 in causal mode.
        """
        # Conv1d expects (batch, channels, time).
        x = x.permute([0, 2, 1])
        x = self.ext_pw_conv_1d(x)

        value = x[:, 0 : self.output_dim, :]
        gate = x[:, self.output_dim : self.output_dim * 2, :]
        if self.bias_in_glu:
            value = value + self.b1
            gate = gate + self.b2

        if self.glu_type == "bilinear":
            x = value * gate
        else:
            x = value * self.glu_act(gate)

        x = x.permute([0, 2, 1])
        return x


class DepthWiseSeperableConv1d(nn.Module):
    """Depthwise-separable 1-D convolution used by the conformer ConvModule.

    A grouped (depthwise) Conv1d expands ``input_dim`` channels to
    ``input_dim * depthwise_multiplier``; an optional 1x1 (pointwise) Conv1d
    then projects to ``depthwise_seperable_out_channel``.
    See https://arxiv.org/pdf/2005.08100v1.pdf

    Args:
        input_dim: int
            input channel size.
        depthwise_seperable_out_channel: int
            output channels of the pointwise conv; 0 skips the pointwise conv
             (an ``nn.Identity`` is stored in its place).
        kernel_size: int
            kernel size of the depthwise conv.
        depthwise_multiplier: int
            channel multiplier of the depthwise conv.
        padding: int, optional
            padding of the depthwise conv, default 0.
    """

    def __init__(
        self,
        input_dim,
        depthwise_seperable_out_channel,
        kernel_size,
        depthwise_multiplier,
        padding=0,
    ):
        super().__init__()

        hidden_channels = input_dim * depthwise_multiplier
        self.dw_conv = nn.Conv1d(
            in_channels=input_dim,
            out_channels=hidden_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            groups=input_dim,
        )
        self.pw_conv = (
            nn.Conv1d(hidden_channels, depthwise_seperable_out_channel, kernel_size=1, stride=1, padding=0)
            if depthwise_seperable_out_channel != 0
            else nn.Identity()
        )
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel

    def forward(self, x):
        """Apply the depthwise conv, then the pointwise conv when enabled.

        Args:
            x: torch.Tensor
                input of shape (batch, input_dim, time) as required by Conv1d.
        """
        x = self.dw_conv(x)
        if self.depthwise_seperable_out_channel == 0:
            return x
        return self.pw_conv(x)


class ConvModule(nn.Module):
    """ConvModule Module for the conformer block.
    for more details see:
    https://arxiv.org/pdf/2005.08100v1.pdf

    Args:
        input_dim: int
            input channel size.
        ext_pw_out_channel: int
            if > 0, ext_pw_out_channel is a dim channel size
             for the GLU and for the last pointwise conv after the activation.
             If 0, both are replaced by learned scalar scale-and-shift terms
             (``pw_conv_simplify_w`` / ``pw_conv_simplify_b``).
        depthwise_seperable_out_channel: int
            if set different to 0, the number of depthwise_seperable_out_channel
             will be used as a channel_out of the second conv1d layer.
             otherwise (equal to 0), the second conv1d layer is skipped.
        ext_pw_kernel_size: int
            kernel size of the conv pointwise of the conformer.
        kernel_size: int
            kernel size of the depthwise convolution.
        depthwise_multiplier: int
            number of input_dim channels duplication. this value
             will be used to compute the hidden channels of the Conv1D.
        dropout_rate: float
            dropout rate.
        causal: bool, optional
            if set to True, convolution have no access
             to future frames. default False.
        batch_norm: bool, optional
            if set to True, apply batchnorm before activation.
            default False
        chunk_se: int, optional
            0 for offline SE.
            1 for streaming SE, where mean is computed
             by accumulated history until current chunk_se.
            2 for streaming SE, where mean is computed
             by only the current chunk.
            NOTE: accepted but not used by this class as written.
        chunk_size: int, optional
            chunk size for cnn. default 18.
            NOTE: accepted but not used by this class as written.
        activation: str, optional
            activation function used in ConvModule,
            default: "relu".
        glu_type: str, optional
            activation function used for the glu,
            default: "sigmoid".
        bias_in_glu: bool, optional
            if set to True, use additive bias in the weight module
             before GLU.
        linear_glu_in_convm: bool, optional
            if set to True, use GLULinear module,
             otherwise, used GLUPointWiseConv module.
              default to False.
        export: bool, optional,
            if set to True, padding is equal to 0.  This is for inference,
             or onnx export.  Typically this is set by the export program or
             the decoder program, and it isn't present in your config file.
             default False
    """

    def __init__(
        self,
        input_dim,
        ext_pw_out_channel,
        depthwise_seperable_out_channel,
        ext_pw_kernel_size,
        kernel_size,
        depthwise_multiplier,
        dropout_rate,
        causal=False,
        batch_norm=False,
        chunk_se=0,
        chunk_size=18,
        activation="relu",
        glu_type="sigmoid",
        bias_in_glu=True,
        linear_glu_in_convm=False,
        export=False,
    ):
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_dim)
        self.input_dim = input_dim
        self.ext_pw_out_channel = ext_pw_out_channel
        self.ext_pw_kernel_size = ext_pw_kernel_size
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
        self.glu_type = glu_type
        self.bias_in_glu = bias_in_glu
        self.linear_glu_in_convm = linear_glu_in_convm
        self.causal = causal

        self._add_ext_pw_layer()

        self.batch_norm = batch_norm
        self.kernel_size = kernel_size

        if batch_norm:
            self.bn_layer = nn.BatchNorm1d(input_dim)

        self.act = get_activation(activation)
        self.dropout = nn.Dropout(dropout_rate)
        self.export = export

        if causal:
            if export:  # Inference only.
                padding = 0  # A cache is concatenated to the left. No padding in the kernel.
            else:
                # Training only. Padding will be added symmetrically on both sides.
                # After convolution, clip off kernel_size-1 points on the right.
                padding = kernel_size - 1
        else:
            padding = (kernel_size - 1) // 2

        self.dw_sep_conv_1d = DepthWiseSeperableConv1d(
            input_dim,
            depthwise_seperable_out_channel,
            kernel_size,
            depthwise_multiplier,
            padding=padding,
        )

        # Project the depthwise-separable output back to input_dim when needed.
        if depthwise_seperable_out_channel != 0:
            if input_dim != depthwise_seperable_out_channel:
                self.ln2 = nn.Linear(depthwise_seperable_out_channel, input_dim)
        else:
            if depthwise_multiplier != 1:
                self.ln2 = nn.Linear(input_dim * depthwise_multiplier, input_dim)

    def _add_ext_pw_layer(self):
        """
        This function is an extension of __init__ function
        and dedicated to the convolution module creation
        of the conformer: the GLU, the final pointwise conv (ext_pw_conv_1d)
        and the optional ln1 projection back to input_dim.
        """
        self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = nn.Identity()  # jit hacks.
        self.squeeze_excitation = nn.Identity()  # jit.
        self.apply_ln1 = self.fix_len1 = False  # jit.

        if self.ext_pw_out_channel != 0:
            if self.causal:
                self.ext_pw_conv_1d = nn.Conv1d(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    1,
                    padding=(self.ext_pw_kernel_size - 1),
                )
                # The extra right-hand frames produced by the padding must be cut.
                self.fix_len1 = self.ext_pw_kernel_size > 1
            else:
                self.ext_pw_conv_1d = nn.Conv1d(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    1,
                    padding=(self.ext_pw_kernel_size - 1) // 2,
                )
                self.fix_len1 = False

            if self.linear_glu_in_convm:
                self.glu = GLULinear(
                    self.input_dim, self.ext_pw_out_channel, self.glu_type, self.bias_in_glu
                )
            else:
                self.glu = GLUPointWiseConv(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    self.glu_type,
                    self.bias_in_glu,
                    self.causal,
                )

            if self.input_dim != self.ext_pw_out_channel:
                self.apply_ln1 = True
                self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim)
            else:
                self.apply_ln1 = False
        else:
            self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3))
            self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3))

    def forward(self, x):
        """ConvModule Forward.

        Args:
            x: torch.Tensor
                input tensor of shape (batch, time, input_dim).

        Returns:
            torch.Tensor of shape (batch, time, input_dim).
        """
        x = self.layer_norm(x)

        if self.ext_pw_out_channel != 0:
            x = self.glu(x)
            # In causal mode GLUPointWiseConv pads ext_pw_kernel_size - 1 frames on
            # both sides, so the surplus right-hand frames are dropped here.
            # GLULinear is frame-wise and keeps the length, so it must not be trimmed.
            if (
                self.causal
                and self.ext_pw_kernel_size > 1
                and not self.linear_glu_in_convm
            ):
                x = x[:, : -(self.ext_pw_kernel_size - 1), :]
            if self.apply_ln1:
                x = self.ln1(x)
        else:
            x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0]
            x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1]
            x = x_0 + x_1

        # Conv1d layers work on (batch, channels, time).
        x = x.permute([0, 2, 1])

        x = self.dw_sep_conv_1d(x)
        # NOTE(review): with export=True the depthwise conv uses padding=0, yet this
        # trim still runs; confirm the caller's left cache accounts for it.
        if self.causal and self.kernel_size > 1:
            x = x[:, :, : -(self.kernel_size - 1)]
        if hasattr(self, "ln2"):
            x = x.permute([0, 2, 1])
            x = self.ln2(x)
            x = x.permute([0, 2, 1])
        if self.batch_norm:
            x = self.bn_layer(x)
        x = self.act(x)

        if self.ext_pw_out_channel != 0:
            x = self.ext_pw_conv_1d(x)
            if self.fix_len1:
                x = x[:, :, : -(self.ext_pw_kernel_size - 1)]

            if self.apply_ln1:
                x = x.permute([0, 2, 1])
                x = self.ln1(x)
                x = x.permute([0, 2, 1])

            x = x.permute([0, 2, 1])
        else:
            x = x.unsqueeze(1).permute([0, 1, 3, 2])
            x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2]
            x = x.squeeze(1)

        x = self.dropout(x)
        return x

class GLULinear(nn.Module):
    """Linear projection followed by a gated linear unit.

    The input is projected to ``2 * output_dim`` features, and ``GLU`` gates
    the first half with the activation of the second half (split on the last dim).

    Args:
        input_dim: int
            input feature size.
        output_dim: int
            output feature size.
        glu_type: str, optional
            gate activation name passed to ``GLU``; default "sigmoid".
        bias_in_glu: bool, optional
            whether the linear projection has a bias. Default True.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        glu_type="sigmoid",
        bias_in_glu=True,
    ):
        super().__init__()
        self.linear = nn.Linear(input_dim, 2 * output_dim, bias=bias_in_glu)
        self.glu_act = GLU(-1, glu_type)

    def forward(self, x):
        """Project ``x`` and apply the gate.

        Args:
            x: torch.Tensor
                input tensor whose last dimension is ``input_dim``.
        """
        projected = self.linear(x)
        return self.glu_act(projected)

class FeedForward(nn.Module):
    """Conformer feed-forward module (pre-LayerNorm, GLU-gated).

    Computes ``Dropout(Linear(Dropout(GLULinear(LayerNorm(x)))))``.
    See the Conformer paper: https://arxiv.org/pdf/2005.08100.pdf

    Args:
        d_model: int
            input and output feature size.
        d_inner: int
            hidden (inner) feature size produced by the GLU.
        dropout_rate: float
            dropout rate applied after the GLU and after the output projection.
        activation: str
            gate activation passed to ``GLULinear`` (e.g. "relu", "swish",
            "sigmoid"); default "sigmoid".
        bias_in_glu: bool, optional
            whether the GLU input projection has a bias; default True.
    """

    def __init__(
        self,
        d_model,
        d_inner,
        dropout_rate,
        activation="sigmoid",
        bias_in_glu=True,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_inner = d_inner

        self.layer_norm = nn.LayerNorm(d_model)
        self.net = nn.Sequential(
            GLULinear(d_model, d_inner, activation, bias_in_glu),
            nn.Dropout(dropout_rate),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout_rate),
        )

    def forward(self, x):
        """Apply layer norm, then the gated feed-forward stack.

        Args:
            x: torch.Tensor
                input tensor whose last dimension is ``d_model``.
        """
        normed = self.layer_norm(x)
        return self.net(normed)

#### positional encoding starts here
def _pre_hook(
    state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
    """Perform pre-hook in load_state_dict for backward compatibility.

    Note:
        We saved self.pe until v.0.5.2 but we have omitted it later.
        Therefore, we remove the item "pe" from `state_dict` for backward compatibility.

    """
    k = prefix + "pe"
    if k in state_dict:
        state_dict.pop(k)

class T5RelativeAttentionLogitBias(nn.Module):
    """
    This module implements the relative position bias described in Section 2.1 of
    the T5 paper: https://arxiv.org/pdf/1910.10683.pdf

    The Huggingface implementation is used as a reference
    https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/models/t5/modeling_t5.py#L435

    Modifies attention as Q*K^T + B, where B is a learned scalar bias based on relative position
    of the query and key. It is HxNxN, where H is the number of heads, N is the sequence length.

    I've made these modifications to the original T5 bias:
    - Skipping of the bucketing step. Original T5 bias converted rel position distances into
      logarithmically increasing buckets. This is supposed to help with length generalization.
    - I just directly use rel position index as bias values, as we don't need length
      generalization (40s max is good enough for ASR encoder), and it keeps ONNX export simple.
    - I've also extended it so that biases can be asymmetric, the default implementation treats
      L->R and R->L the same. Asymmetric was found to yield better results in my experiments.

    Args:
        num_heads: int
            Number of attention heads
        num_buckets: int
            Number of buckets to use for relative attention bias. This is the size of the learnable
            bias parameter. Bucketing is not yet supported, so this defaults to -1 which means
            no bucketing is used (max_distance determines size of bias param).
        max_distance: int
            Maximum distance to use for relative attention bias. With num_buckets=-1, this directly
            controls the max size of the bias parameter. Relative positions beyond it share the
            bias of the furthest representable position. When num_buckets > 0 is supported, this
            will control the maximum distance for logarithmic bucketing after which all positions
            are in the same bucket.
        symmetric: bool
            Whether to use symmetric or asymmetric biases. symmetric=False uses 2x number of bias
            params to distinguish L->R from R->L. This was found to be better for the encoder.

    Raises:
        NotImplementedError: num_buckets >= 0 (bucketing) was requested.
    """

    def __init__(self, num_heads, num_buckets=-1, max_distance=1000, symmetric=False):
        super().__init__()
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.symmetric = symmetric
        self._skip_bucketing = self.num_buckets < 0
        if self._skip_bucketing:
            self.num_buckets = max_distance
        else:
            raise NotImplementedError("T5 attention bias with bucketed positions is not yet tested")
        if not self.symmetric:
            self.num_buckets *= 2
        self.bias_values = nn.Embedding(self.num_buckets, self.num_heads)

    def forward(self, x):
        """Return the relative-position bias for the sequence length of ``x``.

        Args:
            x: torch.Tensor
                only ``x.size(1)`` (sequence length L) and ``x.device`` are used.

        Returns:
            torch.Tensor of shape [1, num_heads, L, L].
        """
        # instantiate bias compatible with shape of x
        maxpos = x.size(1)
        context_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[:, None]
        memory_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position

        # Clip to the representable range using ops that play well with ONNX export.
        # Symmetric biases index the table with |relative_position|, whose largest valid
        # value is max_distance - 1, so the lower bound must be -(max_distance - 1).
        # Asymmetric biases shift by max_distance, so -max_distance is still valid.
        if self.symmetric:
            min_position = -(self.max_distance - 1)
        else:
            min_position = -self.max_distance
        relative_position = relative_position.masked_fill(
            relative_position < min_position, min_position
        )
        relative_position = relative_position.masked_fill(
            relative_position > self.max_distance - 1, self.max_distance - 1
        )

        # mapping from relative position to index in the bias parameter
        if self._skip_bucketing:
            bias_idx = relative_position
        else:
            bias_idx = self._bucket_relative_position(relative_position)
        if self.symmetric:
            bias_idx = bias_idx.abs()
        else:
            bias_idx += self.num_buckets // 2

        t5_rel_att_bias = self.bias_values(bias_idx)  # [L, L, H]
        t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze(0)  # [1, H, L, L]

        return t5_rel_att_bias

    def _bucket_relative_position(self, relative_position):
        # this is a placeholder (isn't tested, likely buggy) using HuggingFace implem as a reference
        # this also needs to be extended to support asymmetric +/- ve positions
        # NOTE(review): unreachable today (__init__ rejects bucketing). The bucket count is read
        # from self.num_buckets, and `causal` defaults to False because __init__ never sets it.
        num_buckets = self.num_buckets
        causal = getattr(self, "causal", False)
        relative_buckets = 0
        if not causal:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(self.max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

class AbsolutePositionalEncoding(nn.Module):
    """Absolute Positional encoding module.
    This module implement Absolute sinusoidal positional encoding
    from: https://arxiv.org/pdf/1706.03762.pdf

    The input is scaled by sqrt(d_model), the sinusoidal table is added, and
    dropout is applied. The table is not a registered buffer; it is rebuilt or
    moved lazily by ``extend_pe`` to match the input's length, dtype and device.

    Args:
        d_model: int
            Input embedding size (odd sizes are supported).
        dropout_rate: float
            dropout rate
        max_len: int, optional
            Initial length of the precomputed table, Default 5000. Longer
            inputs cause the table to be recomputed.

    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Construct an AbsolutePositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self, x):
        """Reset the positional encodings.

        Reuses the cached table when it is long enough (casting it to ``x``'s
        dtype/device if needed); otherwise recomputes it for ``x.size(1)`` positions.

        Args:
            x: torch.Tensor
                tensor whose dim 1 is the required number of positions.
        """
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # Odd columns number d_model // 2, one fewer than div_term when d_model is odd.
        pe[:, 1::2] = torch.cos(position * div_term[: self.d_model // 2])
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x: torch.Tensor
                Input tensor. shape is (batch, time, d_model)

        Returns:
            torch.Tensor: Encoded tensor. Its shape is (batch, time, d_model)

        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)

#### forward embedding layers starts here

@backoff.on_exception(backoff.expo, Exception, max_tries=10)
def np_loadtxt_with_retry(filepath):
    """Load a text file into a float32 numpy array, retrying on failure.

    Any exception triggers a retry with exponential back-off, for at most 10
    attempts in total (see the ``backoff.on_exception`` decorator).

    Args:
        filepath: str
            file path to the numpy array.

    Returns:
        numpy.ndarray with dtype float32 ("f").
    """
    return np.loadtxt(filepath, dtype="f")

class MeanVarianceNormLayer(nn.Module):
    """Mean/variance normalization layer.

    Subtracts a per-feature mean and multiplies the input by a per-feature
    inverted standard deviation. Typically used as the very first layer in a model.

    Args:
        input_size: int
            layer input size (size of the feature dimension the buffers broadcast over).
    """

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        # Identity statistics until load_mean_invstd is called.
        self.register_buffer("global_mean", torch.zeros(input_size))
        self.register_buffer("global_invstd", torch.ones(input_size))
        self.global_mean: Optional[Tensor]
        self.global_invstd: Optional[Tensor]

    def forward(self, input_: Tensor) -> Tensor:
        """MeanVarianceNormLayer Forward

        Args:
            input_: torch.Tensor
                input tensor; its last dimension presumably matches ``input_size``
                (the buffers broadcast against it).

        Returns:
            torch.Tensor: ``(input_ - global_mean) * global_invstd``.
        """
        return (input_ - self.global_mean) * self.global_invstd

    def load_mean_invstd(self, mean_file, invstd_file, cuside_features=False):
        """Load feature mean and inverted standard deviation used for normalization.

        The loaded statistics are placed on the device the buffers currently
        live on, so calling this after moving the module (e.g. to GPU) does not
        move the buffers back to CPU.

        Args:
            mean_file: str
                path to the feature mean statistics file.
            invstd_file: str
                path to the features inverted standard deviation
                 statistics file.
            cuside_features: bool
                Boolean that indicates CUSIDE is being used.
                The statistics of CUSIDE features are copied
                from the normal features
        """
        mean = torch.from_numpy(np_loadtxt_with_retry(mean_file))
        invstd = torch.from_numpy(np_loadtxt_with_retry(invstd_file))

        if cuside_features:
            # CUSIDE features reuse the statistics of the normal features.
            mean = torch.cat((mean, mean), 0)
            invstd = torch.cat((invstd, invstd), 0)

        # Assigning a CPU tensor to `.data` would silently relocate the buffer
        # to CPU; keep it on the module's current device instead.
        self.global_mean.data = mean.to(self.global_mean.device)
        self.global_invstd.data = invstd.to(self.global_invstd.device)

class CausalConv1D(nn.Conv1d):
    """
    A causal version of nn.Conv1d where each step would have limited access to locations on its right or left
    All arguments are the same as nn.Conv1d except padding.

    If padding is set None, then paddings are set automatically to make it a causal convolution where each location would not see any steps on its right.

    If padding is set as a list/tuple (size of 2), then padding[0] would be used as left padding and padding[1] as right padding.
    It would make it possible to control the number of steps to be accessible on the right and left.
    This mode is not supported when stride > 1. padding[0]+padding[1] should be equal to (kernel_size - 1).

    If padding is an int, the same amount is used on both sides.

    Streaming: when ``forward`` receives a ``cache`` tensor, the cache replaces
    the left padding and the updated cache is returned alongside the output.
    ``cache_drop_size`` (None or 0 disables it) excludes that many trailing
    frames from the next cache.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: Union[str, int, List[int], None] = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        # Number of trailing frames excluded from the streaming cache (None = none dropped).
        self.cache_drop_size = None
        if padding is None:
            self._left_padding = kernel_size - 1
            self._right_padding = stride - 1
        else:
            if stride != 1 and padding != kernel_size - 1:
                raise ValueError("No striding allowed for non-symmetric convolutions!")
            if isinstance(padding, int):
                self._left_padding = padding
                self._right_padding = padding
            elif (
                isinstance(padding, (list, tuple))
                and len(padding) == 2
                and padding[0] + padding[1] == kernel_size - 1
            ):
                self._left_padding = padding[0]
                self._right_padding = padding[1]
            else:
                raise ValueError(f"Invalid padding param: {padding}!")

        self._max_cache_len = self._left_padding

        # Padding is applied manually in update_cache, so the base conv uses none.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def update_cache(self, x, cache=None):
        """Pad ``x`` for the convolution and compute the next streaming cache.

        Args:
            x: input of shape (..., time); padding is applied on the last dim.
            cache: previous cache (same leading dims as ``x``) or None.

        Returns:
            (padded_x, next_cache): ``next_cache`` is None when ``cache`` is None,
            otherwise the last ``cache.size(-1)`` frames of the padded input
            (after dropping ``cache_drop_size`` trailing frames, if set).
        """
        if cache is None:
            new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
            next_cache = cache
        else:
            # The cache supplies the left context instead of zero padding.
            new_x = F.pad(x, pad=(0, self._right_padding))
            new_x = torch.cat([cache, new_x], dim=-1)
            # cache_drop_size defaults to None; treat None/0/negative as "drop nothing".
            if self.cache_drop_size is not None and self.cache_drop_size > 0:
                next_cache = new_x[:, :, : -self.cache_drop_size]
            else:
                next_cache = new_x
            next_cache = next_cache[:, :, -cache.size(-1) :]
        return new_x, next_cache

    def forward(self, x, cache=None):
        """Run the causal convolution.

        Returns the output alone when ``cache`` is None, otherwise the tuple
        ``(output, next_cache)``.
        """
        x, cache = self.update_cache(x, cache=cache)
        x = super().forward(x)
        if cache is None:
            return x
        else:
            return x, cache


class CausalConv2D(nn.Conv2d):
    """
    A causal version of nn.Conv2d where each location in the 2D matrix would have no access to locations on its right or down
    All arguments are the same as nn.Conv2d except padding which must be None (the default).

    Padding is applied manually in ``forward``: ``kernel_size - 1`` on the left/top
    and ``stride - 1`` on the right/bottom. In training mode both spatial dims are
    padded. In eval mode only the last dim is padded; the second-to-last dim is left
    unpadded (NOTE(review): presumably the caller provides that context, e.g. via a
    streaming cache — confirm).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: Union[str, int, None] = None,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        # Padding is computed from kernel/stride, so an explicit value is rejected.
        if padding is not None:
            raise ValueError("Argument padding should be set to None for CausalConv2D.")
        self._left_padding = kernel_size - 1
        self._right_padding = stride - 1

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(
        self,
        x,
    ):
        """Pad ``x`` causally (see class docstring) and apply the convolution."""
        if self.training:
            x = F.pad(
                x,
                pad=(
                    self._left_padding,
                    self._right_padding,
                    self._left_padding,
                    self._right_padding,
                ),
            )
        else:
            # Only the last dimension is padded in eval mode.
            x = F.pad(
                x,
                pad=(self._left_padding, self._right_padding, 0, 0),
            )
        x = super().forward(x)
        return x


class NemoConvSubsampling(torch.nn.Module):
    """Convolutional subsampling module, taken from NeMo ASR
    (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a34501479cf/nemo/collections/asr/parts/submodules/subsampling.py)

    Striding Subsampling: "Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for
    Speech Recognition" by Linhao Dong et al. (https://ieeexplore.ieee.org/document/8462506)


    Compared with the EncoderConv2D (`input_layer: custom`), this is a much simplified approach,
    and uses no LayerNorm and far fewer Conv2Ds.  Moreover, depthwise convolutions are used to reduce
    FLOPs, but the first layer is kept as a regular convolution so as not to degrade accuracy.

    `striding` and `dw_striding` are the same except that the latter uses depthwise convolutions
    after the first layer, whereas the former does not.

    Args:
        feat_in (int): size of the input features
        feat_out (int): size of the output features
        subsampling_factor (int): Time reduction factor. Must be a power of 2 (>= 2), since one
            stride-2 stage is built per factor of 2.
        subsampling (str): The subsampling technique, choose from
            {"striding", "dw_striding", "striding_conv1d", "dw_striding_conv1d"}
        conv_channels (int): Number of channels for the convolution layers, default is 256.
        subsampling_conv_chunking_factor (int): Input chunking factor which can be -1 (no chunking)
            1 (auto) or a power of 2. Default is 1. Only used by the Conv2D variants.
        activation (Module, optional): activation function; a fresh nn.ReLU() is created
            when None (the default).
        is_causal (bool): whether to use causal Conv1/2D, where each step will have limited access
            to locations on its right or left. "dw_striding_conv1d" ignores this flag and
            always uses symmetric padding.

    Raises:
        ValueError: for an unknown ``subsampling``, a ``subsampling_factor`` that is not a
            power of 2 (>= 2), or an invalid ``subsampling_conv_chunking_factor``.
    """

    def __init__(
        self,
        feat_in,
        feat_out,
        subsampling_factor=4,
        subsampling="dw_striding",
        conv_channels=256,
        subsampling_conv_chunking_factor=1,
        activation=None,
        is_causal=False,
    ):
        super().__init__()
        # Create the activation per instance instead of sharing a default module object.
        if activation is None:
            activation = nn.ReLU()
        self._subsampling = subsampling
        self._conv_channels = conv_channels
        self._feat_in = feat_in
        self._feat_out = feat_out

        # Each stage halves the time axis, so only powers of two give the requested factor.
        if subsampling_factor < 2 or not self._is_power_of_two(subsampling_factor):
            raise ValueError("Sampling factor should be a power of 2!")
        self._sampling_num = int(subsampling_factor).bit_length() - 1  # exact log2
        self.subsampling_factor = subsampling_factor
        self.is_causal = is_causal
        self.subsampling_causal_cond = subsampling in ("dw_striding", "striding", "striding_conv1d")

        self._validate_chunking_factor(subsampling_conv_chunking_factor)
        self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor

        in_channels = 1
        layers = []

        if subsampling == "dw_striding":
            self._stride = 2
            self._kernel_size = 3
            self._ceil_mode = False
            self._init_paddings(subsampling_factor)

            # Layer 1: regular (full) convolution.
            layers.append(self._make_conv2d(in_channels, conv_channels))
            in_channels = conv_channels
            layers.append(activation)

            # Remaining layers: depthwise conv + pointwise (1x1) conv + activation.
            for _ in range(self._sampling_num - 1):
                layers.append(self._make_conv2d(in_channels, in_channels, groups=in_channels))
                layers.append(
                    torch.nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=conv_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        groups=1,
                    )
                )
                layers.append(activation)
                in_channels = conv_channels

        elif subsampling == "striding":
            self._stride = 2
            self._kernel_size = 3
            self._ceil_mode = False
            self._init_paddings(subsampling_factor)

            for _ in range(self._sampling_num):
                layers.append(self._make_conv2d(in_channels, conv_channels))
                layers.append(activation)
                in_channels = conv_channels

        elif subsampling == "striding_conv1d":
            in_channels = feat_in

            self._stride = 2
            self._kernel_size = 5
            self._ceil_mode = False
            self._init_paddings(subsampling_factor)

            for i in range(self._sampling_num):
                # The last stage projects directly to feat_out.
                out_channels = feat_out if self._sampling_num == i + 1 else conv_channels
                if self.is_causal:
                    layers.append(
                        CausalConv1D(
                            in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=None,
                        )
                    )
                else:
                    layers.append(
                        torch.nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=self._left_padding,
                        )
                    )
                layers.append(activation)
                in_channels = conv_channels

        elif subsampling == "dw_striding_conv1d":
            in_channels = feat_in

            self._stride = 2
            self._kernel_size = 5
            self._ceil_mode = False

            # Always symmetric: this variant has no causal implementation.
            self._left_padding = (self._kernel_size - 1) // 2
            self._right_padding = (self._kernel_size - 1) // 2
            self._max_cache_len = 0

            # Layer 1: depthwise conv + pointwise conv.
            layers.extend(
                [
                    torch.nn.Conv1d(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        kernel_size=self._kernel_size,
                        stride=self._stride,
                        padding=self._left_padding,
                        groups=in_channels,
                    ),
                    torch.nn.Conv1d(
                        in_channels=in_channels,
                        out_channels=feat_out if self._sampling_num == 1 else conv_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        groups=1,
                    ),
                ]
            )
            in_channels = conv_channels
            layers.append(activation)

            for i in range(self._sampling_num - 1):
                layers.extend(
                    [
                        torch.nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=in_channels,
                            kernel_size=self._kernel_size,
                            stride=self._stride,
                            padding=self._left_padding,
                            groups=in_channels,
                        ),
                        torch.nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=feat_out if self._sampling_num == i + 2 else conv_channels,
                            kernel_size=1,
                            stride=1,
                            padding=0,
                            groups=1,
                        ),
                    ]
                )
                layers.append(activation)
                in_channels = conv_channels

        else:
            raise ValueError(f"Not valid sub-sampling: {subsampling}!")

        if subsampling in ["dw_striding", "striding"]:
            # Frequency axis length after the conv stack, used to size the output projection.
            in_length = torch.tensor(feat_in, dtype=torch.float)
            out_length = calc_length(
                lengths=in_length,
                all_paddings=self._left_padding + self._right_padding,
                kernel_size=self._kernel_size,
                stride=self._stride,
                ceil_mode=self._ceil_mode,
                repeat_num=self._sampling_num,
            )
            self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
            self.conv2d_subsampling = True
        elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]:
            self.out = None
            self.conv2d_subsampling = False
        else:
            raise ValueError(f"Not valid sub-sampling: {subsampling}!")

        self.conv = torch.nn.Sequential(*layers)

    @staticmethod
    def _is_power_of_two(value) -> bool:
        """Return True if ``value`` is a positive integral power of two (1, 2, 4, ...)."""
        if value <= 0 or int(value) != value:
            return False
        value = int(value)
        return value & (value - 1) == 0

    def _validate_chunking_factor(self, factor) -> None:
        """Raise ValueError unless ``factor`` is -1, 1 or a power of 2."""
        if factor != -1 and not self._is_power_of_two(factor):
            raise ValueError("subsampling_conv_chunking_factor should be -1, 1, or a power of 2")

    def _init_paddings(self, subsampling_factor) -> None:
        """Set left/right padding and the streaming cache length from kernel size and stride."""
        if self.is_causal:
            # Causal: full kernel context on the left, only stride-1 look-ahead on the right.
            self._left_padding = self._kernel_size - 1
            self._right_padding = self._stride - 1
            self._max_cache_len = subsampling_factor + 1
        else:
            self._left_padding = (self._kernel_size - 1) // 2
            self._right_padding = (self._kernel_size - 1) // 2
            self._max_cache_len = 0

    def _make_conv2d(self, in_channels, out_channels, groups=1):
        """Build one strided Conv2D layer; causal when ``self.is_causal`` is set."""
        if self.is_causal:
            return CausalConv2D(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=self._kernel_size,
                stride=self._stride,
                padding=None,
                groups=groups,
            )
        return torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self._kernel_size,
            stride=self._stride,
            padding=self._left_padding,
            groups=groups,
        )

    def get_sampling_frames(self):
        """Return ``[1, subsampling_factor]``."""
        return [1, self.subsampling_factor]

    def get_streaming_cache_size(self):
        """Return ``[0, subsampling_factor + 1]``."""
        return [0, self.subsampling_factor + 1]

    def forward(self, x, mask):
        """
        Forward method for NeMo subsampling.

        Args:
            x[Batch, Time, Filters]: torch.Tensor
                input tensor
            mask: torch.Tensor or None
                input mask; summed over dim 1 to obtain the valid input lengths
                (presumably shaped (Batch, Time) — confirm against callers).

        Returns:
            x: torch.Tensor
                Resulting tensor from subsampling (B, T // time_reduction_factor, feat_out)
            pad_mask: torch.Tensor or None
                tensor of padded hidden state sequences (B, 1, T // time_reduction_factor);
                None when ``mask`` is None.
        """
        if self.conv2d_subsampling:
            # Add a channel axis: (B, 1, T, F).
            x = x.unsqueeze(1)
        else:
            # Channel-first for Conv1d: (B, F, T).
            x = x.transpose(1, 2)

        # split inputs if chunking_factor is set
        if self.subsampling_conv_chunking_factor != -1 and self.conv2d_subsampling:
            if self.subsampling_conv_chunking_factor == 1:
                # if subsampling_conv_chunking_factor is 1, we split only if needed
                # avoiding a bug / feature limiting indexing of tensors to 2**31
                # see https://github.com/pytorch/pytorch/issues/80020
                x_ceil = 2**31 / self._conv_channels * self._stride * self._stride
                need_to_split = torch.numel(x) > x_ceil
            else:
                # if subsampling_conv_chunking_factor > 1 we always split
                need_to_split = True

            if need_to_split:
                x, success = self.conv_split_by_batch(x)
                if not success:  # if unable to split by batch, try by channel
                    if self._subsampling == "dw_striding":
                        x = self.conv_split_by_channel(x)
                    else:
                        x = self.conv(x)  # try anyway
            else:
                x = self.conv(x)
        else:
            x = self.conv(x)

        if self.conv2d_subsampling:
            # Flatten channel and frequency axes, then project to feat_out.
            b, _, t, _ = x.size()
            x = self.out(x.transpose(1, 2).reshape(b, t, -1))
        else:
            # Back to channel-last: (B, T', feat_out).
            x = x.transpose(1, 2)

        if mask is None:
            return x, None

        max_audio_length = x.shape[1]
        feature_lens = mask.sum(1)
        padding_length = torch.ceil(feature_lens / self.subsampling_factor)
        if self.is_causal and self.subsampling_causal_cond:
            feature_lens_remainder = feature_lens % self.subsampling_factor
            padding_length[feature_lens_remainder != 1] += 1
        pad_mask = (
            torch.arange(0, max_audio_length, device=x.device).expand(padding_length.size(0), -1)
            < padding_length.unsqueeze(1)
        )
        return x, pad_mask.unsqueeze(1)

    def reset_parameters(self):
        """Re-initialize weights uniformly (only implemented for ``dw_striding``)."""
        if self._subsampling == "dw_striding":
            with torch.no_grad():
                # init conv
                scale = 1.0 / self._kernel_size
                dw_max = (self._kernel_size**2) ** -0.5
                pw_max = self._conv_channels**-0.5

                torch.nn.init.uniform_(self.conv[0].weight, -scale, scale)
                torch.nn.init.uniform_(self.conv[0].bias, -scale, scale)

                # Layout after layer 1 + activation: [depthwise, pointwise, activation] * N.
                for idx in range(2, len(self.conv), 3):
                    torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, dw_max)
                    torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max)
                    torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, pw_max)
                    torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, pw_max)

                # init fc (80 * 64 = 5120 from https://github.com/kssteven418/Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/src/models/conformer_encoder.py#L487
                fc_scale = (self._feat_out * self._feat_in / self._sampling_num) ** -0.5
                torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale)
                torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale)

    def conv_split_by_batch(self, x):
        """Tries to split input by batch, run conv and concat results.

        Returns:
            (output, True) on success, or (x, False) unchanged when the batch
            cannot be split (batch size 1 or too few samples per chunk).
        """
        b, _, _, _ = x.size()
        if b == 1:  # can't split if batch size is 1
            return x, False

        if self.subsampling_conv_chunking_factor > 1:
            cf = self.subsampling_conv_chunking_factor
        else:
            # avoiding a bug / feature limiting indexing of tensors to 2**31
            # see https://github.com/pytorch/pytorch/issues/80020
            x_ceil = 2**31 / self._conv_channels * self._stride * self._stride
            p = math.ceil(math.log(torch.numel(x) / x_ceil, 2))
            cf = 2**p

        new_batch_size = b // cf
        if new_batch_size == 0:  # input is too big
            return x, False

        return torch.cat([self.conv(chunk) for chunk in torch.split(x, new_batch_size, 0)]), True

    def conv_split_by_channel(self, x):
        """For dw convs: run depthwise convs split by channel and pointwise convs split by time."""
        x = self.conv[0](x)  # full conv2D
        x = self.conv[1](x)  # activation

        for i in range(self._sampling_num - 1):
            _, c, t, _ = x.size()

            if self.subsampling_conv_chunking_factor > 1:
                cf = self.subsampling_conv_chunking_factor
            else:
                # avoiding a bug / feature limiting indexing of tensors to 2**31
                # see https://github.com/pytorch/pytorch/issues/80020
                p = math.ceil(math.log(torch.numel(x) / 2**31, 2))
                cf = 2**p

            new_c = int(c // cf)
            if new_c == 0:
                new_c = 1

            new_t = int(t // cf)
            if new_t == 0:
                new_t = 1

            x = self.channel_chunked_conv(self.conv[i * 3 + 2], new_c, x)  # conv2D, depthwise

            # splitting pointwise convs by time
            x = torch.cat(
                [self.conv[i * 3 + 3](chunk) for chunk in torch.split(x, new_t, 2)], 2
            )  # conv2D, pointwise
            x = self.conv[i * 3 + 4](x)  # activation
        return x

    def channel_chunked_conv(self, conv, chunk_size, x):
        """Performs a depthwise convolution ``chunk_size`` channels at a time.

        Uses ``conv``'s weights and bias sliced per chunk; in causal mode the
        chunk is padded on both spatial dims the same way CausalConv2D pads in
        training mode.
        """
        ind = 0
        out_chunks = []
        for chunk in torch.split(x, chunk_size, 1):
            step = chunk.size()[1]

            if self.is_causal:
                chunk = nn.functional.pad(
                    chunk,
                    pad=(
                        self._kernel_size - 1,
                        self._stride - 1,
                        self._kernel_size - 1,
                        self._stride - 1,
                    ),
                )
                ch_out = nn.functional.conv2d(
                    chunk,
                    conv.weight[ind : ind + step, :, :, :],
                    bias=conv.bias[ind : ind + step],
                    stride=self._stride,
                    padding=0,
                    groups=step,
                )
            else:
                ch_out = nn.functional.conv2d(
                    chunk,
                    conv.weight[ind : ind + step, :, :, :],
                    bias=conv.bias[ind : ind + step],
                    stride=self._stride,
                    padding=self._left_padding,
                    groups=step,
                )
            out_chunks.append(ch_out)
            ind += step

        return torch.cat(out_chunks, 1)

    def change_subsampling_conv_chunking_factor(self, subsampling_conv_chunking_factor: int):
        """Validate and set a new chunking factor (-1, 1, or a power of 2)."""
        self._validate_chunking_factor(subsampling_conv_chunking_factor)
        self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor


def calc_length(lengths, all_paddings, kernel_size, stride, ceil_mode, repeat_num=1):
    """Calculates the output length of a Tensor passed through a convolution or max pooling layer

    Applies ``L -> round((L + all_paddings - kernel_size) / stride + 1)`` for
    ``repeat_num`` identical layers, where ``round`` is ceil or floor depending
    on ``ceil_mode``. The result is returned as an int tensor.
    """
    round_fn = torch.ceil if ceil_mode else torch.floor
    offset: float = all_paddings - kernel_size
    for _ in range(repeat_num):
        lengths = round_fn(torch.div(lengths.to(dtype=torch.float) + offset, stride) + 1.0)
    return lengths.to(dtype=torch.int)

####  multihead attention starts here
class AttModule(nn.Module):
    """Attention abstraction module: an identity pass-through with an export flag."""

    def __init__(self):
        super().__init__()
        # Toggled through set_export; False until requested.
        self.export_mode = False

    def set_export(self, mode=True):
        """Turn export mode on (default) or off."""
        self.export_mode = mode

    def forward(
        self,
        x: Tensor,
        memory: Optional[Tensor] = None,
        pos_emb: Optional[Tensor] = None,
        att_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        """Return all inputs unchanged as ``(x, memory, pos_emb, att_mask)``.

        Args:
            x: torch.Tensor
                input tensor.
            memory: torch.Tensor, optional
                memory tensor.
            pos_emb: torch.Tensor, optional
                positional encoder embedding.
            att_mask: torch.Tensor, optional
                attention mask tensor.
        """
        passthrough = (x, memory, pos_emb, att_mask)
        return passthrough


class AttBlock(Block, AttModule):
    """Attention Block module to support both Attention and Block module."""

    def memory_dims(self, max_len=False):
        """Return the memory dimensions ``(1, self.input_size)``.

        ``max_len`` is accepted for interface compatibility but is not used.
        """
        dims = (1, self.input_size)
        return dims

def masked_softmax(
    scores,
    mask: Optional[Tensor],
):
    """Softmax over the last dim of ``scores``, zeroing positions where ``mask`` is 0.

    Args:
        scores: attention scores, (batch, head, time1, time2).
        mask: optional mask, (batch, time1, time2); broadcast over heads via
            ``unsqueeze(1)``. Zero entries are excluded.

    Returns:
        Attention weights with the same shape as ``scores``. Rows whose entries are
        all masked come out as zeros rather than NaN.
    """
    if mask is None:
        return torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
    blocked = mask.unsqueeze(1).eq(0)  # (batch, 1, time1, time2)
    probs = torch.softmax(scores.masked_fill(blocked, -torch.inf), dim=-1)
    # Fully-masked rows give NaN from softmax(-inf, ...); the fill resets them to 0.
    return probs.masked_fill(blocked, 0.0)


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer with optional relative position embedding and grouped heads.

    Args:
        n_head: int
            the number of heads.
        n_feat: int
            input size features (input dim of the query/key projections).
        dropout_rate: float
            dropout rate applied to the attention probabilities; only active
            while the module is in training mode.
        attention_inner_dim: int, optional
            the attention dimension used in the class,
            it can be different from the input dimension n_feat.
            default: -1 (equal to n_feat).
        glu_type: str, optional
            accepted for backward compatibility; not used by this class.
        bias_in_glu: bool, optional
            accepted for backward compatibility; not used by this class.
        use_pt_scaled_dot_product_attention: bool, optional
            if set True, use pytorch scaled dot product attention when not
            running under TorchScript.  NOTE: this will NOT be used in ONNX
            decoding / TorchScript due to a lack of support.  In that case, we
            use the original attention implementation, which shows no
            regression.  Cannot be combined with group_size > 1.
            default: False.
        n_value: int, optional
            if set to values other than -1, use a different dimension for value
            (input size of linear_v and output size of linear_out). With the
            default value (i.e. -1), it is backward compatible.
        group_size: int, optional. must divide `n_head`
            if group_size > 1:       GQA
            if group_size = 1:       MHA
            if group_size = n_head:  MQA
    """

    inv_sqrt_d_k: torch.jit.Final[float]
    h: torch.jit.Final[int]
    h_k: torch.jit.Final[int]
    g: torch.jit.Final[int]

    def __init__(
        self,
        n_head,
        n_feat,
        dropout_rate,
        attention_inner_dim=-1,
        glu_type="swish",
        bias_in_glu=True,
        use_pt_scaled_dot_product_attention=False,
        n_value=-1,
        group_size: int = 1,
    ):
        super().__init__()
        if n_value == -1:
            n_value = n_feat
        if attention_inner_dim == -1:
            attention_inner_dim = n_feat
        assert attention_inner_dim % n_head == 0

        # We assume d_v always equals d_k
        self.d_k = attention_inner_dim // n_head
        self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k)
        self.h = n_head
        assert n_head % group_size == 0, "group_size must divide n_head"
        self.g = group_size
        self.h_k = n_head // group_size

        self.linear_q = nn.Linear(n_feat, attention_inner_dim)
        self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size)
        self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size)
        self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value)

        # Holds the last attention probabilities computed by the eager path.
        self.attn = torch.jit.Attribute(None, Optional[Tensor])
        self.dropout = nn.Dropout(p=dropout_rate)
        self.dropout_rate = dropout_rate
        self.use_pt_scaled_dot_product_attention = use_pt_scaled_dot_product_attention

        if use_pt_scaled_dot_product_attention and group_size > 1:
            raise ValueError("Cannot use PT Scaled Attention with GQA")

        # Torchscript eager quantization.  Note that these functions below are
        # NOOPs and have very little impact on performance unless quantization is
        # enabled.
        self.quant_q = torch.ao.quantization.QuantStub()
        self.quant_x = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.ffunc = torch.ao.nn.quantized.FloatFunctional()

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        pos_k: Optional[Tensor],
        pos_v: Optional[Tensor],
        mask: Optional[Tensor],
        relative_attention_bias: Optional[Tensor] = None,
    ):
        """Compute 'Scaled Dot Product Attention'.

        Args:
            query: torch.Tensor
                query tensor (batch, time1, size)
            key: torch.Tensor
                key tensor (batch, time2, size)
            value: torch.Tensor
                value tensor (batch, time2, size)
            pos_k: torch.Tensor, optional
                key tensor used for relative positional embedding
                (eager path only; ignored by the SDPA path).
            pos_v: torch.Tensor, optional
                value tensor used for relative positional embedding
                (eager path only; ignored by the SDPA path).
            mask: torch.Tensor, optional
                mask tensor (batch, time1, time2); non-zero entries may be
                attended, zero entries are blocked.
            relative_attention_bias: torch.Tensor, optional
                bias added to attention logits w.r.t. relative positions (1, n_head, time1, time2)

        Returns:
            torch.Tensor of shape (batch, time1, n_value).
        """
        n_batch = query.size(0)

        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)  # (b, t, h, d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h_k, self.d_k)  # (b, t, h_k, d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k)
        # SDPA applies the 1/sqrt(d_k) scaling itself; the eager path pre-scales q.
        q = (
            q.transpose(1, 2)
            if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting()
            else q.transpose(1, 2) * self.inv_sqrt_d_k
        )
        k = k.transpose(1, 2)  # (batch, head_k, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head_k, time2, d_k)

        if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting():
            attn_mask: Optional[Tensor] = None
            has_key: Optional[Tensor] = None
            if mask is not None:
                # `mask` is a keep-mask (non-zero = attend), the same convention
                # masked_softmax uses.  A floating-point attn_mask is *added* to
                # the logits by SDPA, so it must be converted to 0 / -inf rather
                # than cast from 0/1 directly.
                keep = mask.unsqueeze(1).ne(0)  # (batch, 1, time1, time2)
                has_key = keep.any(dim=-1, keepdim=True)  # (batch, 1, time1, 1)
                attn_mask = torch.zeros(keep.shape, dtype=q.dtype, device=q.device)
                if relative_attention_bias is not None:
                    attn_mask = attn_mask + relative_attention_bias.to(q.dtype)
                # Leave rows with no attendable key unblocked so they stay finite
                # (no NaN in forward/backward); their output is zeroed below.
                attn_mask = attn_mask.masked_fill(~keep & has_key, float("-inf"))
            elif relative_attention_bias is not None:
                attn_mask = relative_attention_bias.to(q.dtype)

            x = torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_mask,
                # SDPA has no notion of train/eval mode: only drop while training.
                dropout_p=self.dropout_rate if self.training else 0.0,
            )
            if has_key is not None:
                # Match masked_softmax: a query row with no attendable key is all zeros.
                x = x.masked_fill(~has_key, 0.0)
        else:
            if self.h != self.h_k:
                q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k)
                A = torch.einsum("b g h t d, b h s d -> b h t s", q, k)
            else:
                A = torch.matmul(q, k.transpose(-2, -1))
            if pos_k is not None:
                if self.h != self.h_k:
                    B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k)
                else:
                    reshape_q = (
                        q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0, 1)
                    )  # (t1,nh,dk)
                    B = torch.matmul(reshape_q, pos_k.transpose(-2, -1))  # pos_k: (t1,dk,t2)
                    B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1))
                scores = A + B
            else:
                scores = A

            if relative_attention_bias is not None:
                scores = scores + relative_attention_bias

            attn = masked_softmax(scores, mask)  # (batch, head, time1, time2)

            self.attn = attn

            p_attn = self.dropout(attn)
            x = torch.matmul(p_attn.to(v.dtype), v)  # (batch, head, time1, d_k)
            if pos_v is not None:
                reshape_attn = (
                    p_attn.contiguous()
                    .view(n_batch * self.h, pos_v.size(0), pos_v.size(1))
                    .transpose(0, 1)
                )  # (t1, bh, t2)

                attn_v = (
                    torch.matmul(reshape_attn, pos_v)
                    .transpose(0, 1)
                    .contiguous()
                    .view(n_batch, self.h, pos_v.size(0), self.d_k)
                )
                x = x + attn_v
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h_k * self.d_k)
        )  # (batch, time1, h_k * d_k)

        return self.linear_out(x)  # (batch, time1, n_value)


def unfold_tensor(xs_pad, max_seq_len):
    """Cut a (N, T, D) tensor into consecutive, non-overlapping windows of time.

    Each sequence is split into T' = T // max_seq_len windows of exactly
    ``max_seq_len`` frames, and the windows are stacked along the batch axis.
    Frames past T' * max_seq_len are dropped, so callers that must keep every
    frame should pad T to a multiple of ``max_seq_len`` first.  ``F.unfold``
    raises if T < max_seq_len.

    Args:
        xs_pad: N, T, D
        max_seq_len: window length in frames.

    Returns:
        Tensor of shape (N * T', max_seq_len, D).  Row ``n * T' + w`` holds
        frames ``[w * max_seq_len, (w + 1) * max_seq_len)`` of sequence ``n``.
    """
    _, _, feat_dim = xs_pad.shape
    # (N, T, D) -> (N, D, 1, T): a height-1 "image" so F.unfold slides over time.
    as_image = xs_pad.transpose(-1, -2)[..., None, :]
    # -> (N, D * max_seq_len, T'), channel index = d * max_seq_len + offset
    windows = F.unfold(
        as_image,
        kernel_size=(1, max_seq_len),
        stride=(1, max_seq_len),
    )
    batch, _, n_windows = windows.shape
    return (
        windows.view(batch, -1, max_seq_len, n_windows)  # (N, D, max_seq_len, T')
        .permute(0, 3, 2, 1)  # (N, T', max_seq_len, D)
        .contiguous()
        .view(-1, max_seq_len, feat_dim)  # (N * T', max_seq_len, D)
    )

# conformer_encoder.py
class MultiSequential(torch.nn.Sequential):
    """``torch.nn.Sequential`` variant whose children take and return several values."""

    @torch.jit.ignore
    def forward(self, *args):
        """Feed the inputs through each child in registration order.

        Every child is called with the previous child's output unpacked as
        positional arguments, so each child must return an iterable (usually a
        tuple).  With no children, the input tuple is returned as-is.
        """
        carried = args
        for layer in self:
            carried = layer(*carried)
        return carried

def repeat(repeat_num, module_gen_fn):
    """Build a MultiSequential out of ``repeat_num`` generated modules.

    :param int repeat_num: how many modules to create
    :param function module_gen_fn: called as ``module_gen_fn(i)`` for
        ``i = 0 .. repeat_num - 1``; each call returns one module
    :return: the generated modules chained in index order
    :rtype: MultiSequential
    """
    layers = []
    for index in range(repeat_num):
        layers.append(module_gen_fn(index))
    return MultiSequential(*layers)

class ConformerEncoderLayer(nn.Module):
    """ConformerEncoder Layer module.
    for more details see conformer paper:
        https://arxiv.org/abs/2005.08100
    This module implements the Conformer block layer: a half-step feed-forward,
    self-attention, a convolution module, a second half-step feed-forward and a
    final layer norm.

    Args:
        d_model: int
            attention dim.
        ext_pw_out_channel: int
            if > 0, ext_pw_out_channel is a dim channel size
             for the last pointwise conv after swish activation.
        depthwise_seperable_out_channel: int
            if set different to 0, the number of depthwise_seperable_out_channel
             will be used as a channel_out of the second conv1d layer.
             otherwise, if it is equal to 0, the second conv1d layer is skipped.
        depthwise_multiplier: int
            number of input_dim channels duplication. this value
             will be used to compute the hidden channels of the Conv1D.
        n_head: int
            the number of heads for multihead attention module.
        d_ffn: int
            inner (hidden) size of the feed_forward blocks.
        ext_pw_kernel_size: int
            kernel size of the conv pointwise of the conformer.
        kernel_size: int
            kernel size.
        dropout_rate: float
            dropout rate.
        causal: bool, optional
            if set to True, convolution have no access
             to future frames. default False.
        batch_norm: bool, optional
            if set to True, apply batchnorm before activation
            in ConvModule layer of the conformer.
            default False
        activation: str, optional
            activation function name,
            one of ["relu", "swish", "sigmoid"],
            sigmoid activation is only used with "glu_in_fnn=True",
            default "relu".
        chunk_se: int, optional
            0 for offline SE.
            1 for streaming SE, where mean is computed
             by accumulated history until current chunk_se.
            2 for streaming SE, where mean is computed
             by only the current chunk.
            default 0.
        chunk_size: int, optional
            chunk_size for cnn. default 18
        conv_activation: str, optional
            activation function used in ConvModule part
            of the conformer, default "relu".
        conv_glu_type: str, optional
            activation function used for the glu inside
            the ConvModule part of the conformer.
            default: "sigmoid".
        bias_in_glu: bool, optional
            if set to True, use additive bias in the weight module
             before GLU.
        linear_glu_in_convm: bool, optional
            if set to True, use GLULinear module,
             otherwise, use GLUPointWiseConv module.
              default to False.
        attention_innner_dim: int, optional
            if equal to -1, attention dim for linears k/q/v is
            equal to d_model. otherwise attention_innner_dim is used.
            default -1.
        attention_glu_type: str, optional
            activation function for glu used in the multihead attention,
             default "swish".
        activation_checkpointing: str or dict, optional
            either a str in ("", "checkpoint", "offload") or
            a dictionary of {"module","interval","offload"}, where
                "module": str
                    accept ["transformer", "attention"] to select
                    which module should do activation checkpointing.
                "interval": int, default 1,
                    interval of applying activation checkpointing,
                    interval = 1 means that we apply checkpointing
                    on every layer (if activation), otherwise,
                    we apply it every x interval.
                "offload": bool, default False,
                    if set to True, we offload activation to cpu and
                    reload it during backward, otherwise,
                    we recalculate activation in backward.
            default "".
        export: bool, optional
            if set to True, it removes the padding from convolutional layers
             and allows the onnx conversion for inference.
              default False.
        use_pt_scaled_dot_product_attention: bool, optional
            if set to True, use pytorch's scaled dot product attention implementation in training.
        attn_group_sizes: int, optional
            the number of groups to use for attention, default 1 (Multi-Head Attention),
            1 = typical Multi-Head Attention,
            1 < attn_group_sizes < attention_heads = Grouped-Query Attention
            attn_group_sizes = attention_heads = Multi-Query Attention
    """

    def __init__(
        self,
        d_model=512,
        ext_pw_out_channel=0,
        depthwise_seperable_out_channel=256,
        depthwise_multiplier=1,
        n_head=4,
        d_ffn=2048,
        ext_pw_kernel_size=1,
        kernel_size=3,
        dropout_rate=0.1,
        causal=False,
        batch_norm=False,
        activation="relu",
        chunk_se=0,
        chunk_size=18,
        conv_activation="relu",
        conv_glu_type="sigmoid",
        bias_in_glu=True,
        linear_glu_in_convm=False,
        attention_innner_dim=-1,
        attention_glu_type="swish",
        activation_checkpointing="",
        export=False,
        use_pt_scaled_dot_product_attention=False,
        attn_group_sizes: int = 1,
    ):
        super().__init__()

        # First (pre-attention) feed-forward; its output is added at half weight in forward().
        self.feed_forward_in = FeedForward(
            d_model=d_model,
            d_inner=d_ffn,
            dropout_rate=dropout_rate,
            activation=activation,
            bias_in_glu=bias_in_glu,
        )

        # encoder_checkpoint_wrapper returns a callable that (depending on
        # activation_checkpointing) wraps the attention module; see its definition.
        self.self_attn = encoder_checkpoint_wrapper(
            activation_checkpointing,
            MultiHeadedAttention,
        )(
            MultiHeadedAttention(
                n_head,
                d_model,
                dropout_rate,
                attention_innner_dim,
                attention_glu_type,
                bias_in_glu,
                use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention,
                group_size=attn_group_sizes,
            )
        )
        self.conv = ConvModule(
            d_model,
            ext_pw_out_channel,
            depthwise_seperable_out_channel,
            ext_pw_kernel_size,
            kernel_size,
            depthwise_multiplier,
            dropout_rate,
            causal,
            batch_norm,
            chunk_se,
            chunk_size,
            conv_activation,
            conv_glu_type,
            bias_in_glu,
            linear_glu_in_convm,
            export=export,
        )

        # Second (post-convolution) feed-forward, also added at half weight.
        self.feed_forward_out = FeedForward(
            d_model=d_model,
            d_inner=d_ffn,
            dropout_rate=dropout_rate,
            activation=activation,
            bias_in_glu=bias_in_glu,
        )

        # Pre-norm for the attention input, and the final output norm.
        self.layer_norm_att = nn.LayerNorm(d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(
        self,
        x,
        pos_k,
        pos_v,
        mask,
        relative_attention_bias: Optional[Tensor] = None,
    ):
        """ConformerEncoderLayer forward.

        Args:
            x: torch.Tensor
                input feature of shape (batch, max_time_in, size)
            pos_k: torch.Tensor
                positional key embedding (may be None); passed to self-attention.
            pos_v: torch.Tensor
                positional value embedding (may be None); passed to self-attention.
            mask: torch.Tensor
                attention mask forwarded unchanged to self-attention, which
                documents it as (batch, time1, time2).
            relative_attention_bias: Optional[torch.Tensor]
                bias added to attention logits w.r.t. relative positions (1, n_head, time1, time2)

        Returns:
            (out, pos_k, pos_v, mask): the layer output plus the unchanged
            positional embeddings and mask, so layers can be chained with
            MultiSequential.
        """
        # Half-step residual feed-forward.
        x = x + 0.5 * self.feed_forward_in(x)
        norm_x = self.layer_norm_att(x)

        # Self-attention on the normalized input, with a residual connection.
        x = x + self.self_attn(
            norm_x,
            norm_x,
            norm_x,
            pos_k,
            pos_v,
            mask,
            relative_attention_bias=relative_attention_bias,
        )
        # Convolution module receives the un-normalized residual stream
        # (NOTE(review): presumably ConvModule normalizes internally — confirm).
        x = x + self.conv(x)
        x = x + 0.5 * self.feed_forward_out(x)

        out = self.layer_norm(x)

        return out, pos_k, pos_v, mask
        
class TransformerEncoderBase(abc.ABC, nn.Module):
    """The Base class for Transformer based encoders

    Please set causal = True in streaming model
    Args:
        input_size: int
            input feature dimension.
        chunk_size: int, list(int)
            Number of frames for each chunk
            This variable can take 2 forms:
            int:  Used for inference, or single chunk size training
            list(int) : Used only for variable chunk size training
            Some examples for the 2 cases:
            chunk_size = 12
            chunk_size = [6, 8, 12, 24]
        left_chunk: int, list(int)
            Number of chunks used for masking in streaming mode.
            This variable can take 2 forms:
            int:  Used for inference, or single chunk size training
            list(int) : Used only for variable chunk size training. When
            chunk_size is a list, left_chunk must be a list with same length.
            Some examples for the 2 cases:
            left_chunk = 6
            left_chunk = [12, 9, 6, 3]
        attention_dim: int, optional
            attention dimension. default 256.
        attention_heads: int, optional
            the number of heads. default 4
        input_layer: str, optional
            input layer type before Conformer. Only "nemo_conv" is supported
            by this base class; any other value raises ValueError.
            default "nemo_conv"
        cnn_out: int, optional
            accepted for compatibility; not used by this base class.
            default -1.
        cnn_layer_norm: bool, optional
            accepted for compatibility; not used by this base class.
            default False.
        time_reduction: int, optional
            time reduction factor
            default 4
        dropout_rate: float, optional
            accepted for compatibility; not used by this base class.
            default 0.0
        padding_idx: int, optional
            accepted for compatibility; not used by this base class.
            default -1
        relative_attention_bias_args: dict, optional
            use more efficient scalar bias-based relative multihead attention (Q*K^T + B)
            implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias
            usage: relative_attention_bias_args={"type": t5/alibi}
            additional method-specific arguments can be provided (see transformer_base.py)
            NOTE: only {"type": "t5"} is currently implemented; any other value
            (including None) raises NotImplementedError.
        positional_dropout_rate: float, optional
            dropout rate after positional encoding. default 0.0
        nemo_conv_settings: dict, optional
            A dictionary of settings for NeMo Subsampling that override the
            defaults built in __init__ ("subsampling_factor", "feat_in" and
            "feat_out" must not be given here).
            default None
        conv2d_extra_padding: str, optional
            Add extra padding in conv2d subsampling layers. Choices are
            (feat, feat_time, none, True).
            if True or feat_time, the extra padding is added into non full
            supraframe utts in batch.
            Not used by this base class.
            Default: none
        attention_group_size: int, optional
            the number of groups to use for attention, default 1 (Multi-Head Attention),
            1 = typical Multi-Head Attention,
            1 < attention_group_size < attention_heads = Grouped-Query Attention
            attention_group_size = attention_heads = Multi-Query Attention
        encoder_embedding_config: dict, optional
            configuration read by post_init(); its "input_size" entry sizes
            the MeanVarianceNormLayer created there.
    """

    def __init__(
        self,
        input_size,
        chunk_size,
        left_chunk,
        attention_dim=256,
        attention_heads=4,
        input_layer="nemo_conv",
        cnn_out=-1,
        cnn_layer_norm=False,
        time_reduction=4,
        dropout_rate=0.0,
        padding_idx=-1,
        relative_attention_bias_args=None,
        positional_dropout_rate=0.0,
        nemo_conv_settings=None,
        conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none",
        attention_group_size=1,
        encoder_embedding_config=None,
    ):
        super().__init__()
        self.input_size = input_size
        self.input_layer = input_layer
        self.chunk_size = chunk_size
        self.left_chunk = left_chunk
        self.attention_dim = attention_dim
        self.num_heads = attention_heads
        self.attention_group_size = attention_group_size
        self.time_reduction = time_reduction
        self.nemo_conv_settings = nemo_conv_settings
        self.encoder_embedding_config = encoder_embedding_config

        if self.input_layer == "nemo_conv":
            default_nemo_conv_settings = {
                "subsampling": "dw_striding",
                "subsampling_factor": self.time_reduction,
                "feat_in": input_size,
                "feat_out": attention_dim,
                "conv_channels": 256,
                "subsampling_conv_chunking_factor": 1,
                "activation": nn.ReLU(),
                "is_causal": False,
            }
            # Override any of the defaults with the incoming, user settings
            if nemo_conv_settings:
                default_nemo_conv_settings.update(nemo_conv_settings)
                for i in ["subsampling_factor", "feat_in", "feat_out"]:
                    # These are derived from explicit constructor arguments.
                    assert (
                        i not in nemo_conv_settings
                    ), f"{i} should be specified outside of the NeMo dictionary"

            self.embed = NemoConvSubsampling(
                **default_nemo_conv_settings,
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.pos_emb = AbsolutePositionalEncoding(attention_dim, positional_dropout_rate)

        self.relative_attention_bias_type = (
            relative_attention_bias_args.get("type") if relative_attention_bias_args else None
        )
        if self.relative_attention_bias_type == "t5":
            assert (
                self.num_heads % self.attention_group_size == 0
            ), "attention_group_size must divide n_head"
            # One bias per key/value head group.
            self.relative_attention_bias_layer = T5RelativeAttentionLogitBias(
                self.num_heads // self.attention_group_size,
                max_distance=relative_attention_bias_args.get("t5_bias_max_distance", 1000),
                symmetric=relative_attention_bias_args.get("t5_bias_symmetric", False),
            )
        else:
            raise NotImplementedError

    def post_init(self, init_model_config):
        """Optionally load pretrained encoder weights and mean/invstd statistics.

        Args:
            init_model_config: dict-like with optional keys
                "pretrained_speech_encoder_path", "mean_file" and "invstd_file".
        """
        pretrained_speech_encoder_path = init_model_config.get('pretrained_speech_encoder_path', None)
        if pretrained_speech_encoder_path:
            # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
            model_state = torch.load(pretrained_speech_encoder_path, map_location="cpu")
            # Keep entries whose key contains "encoder.", with every occurrence
            # of that substring removed.
            encoder_state_dict = {
                k.replace("encoder.", ""): v
                for k, v in model_state.items()
                if "encoder." in k
            }

            # The checkpoint is loaded without the embedding; it is rebuilt below.
            if hasattr(self, "encoder_embedding"):
                del self.encoder_embedding
            self.load_state_dict(encoder_state_dict)

        if not hasattr(self, "encoder_embedding"):
            self.encoder_embedding = MeanVarianceNormLayer(self.encoder_embedding_config["input_size"])

        mean_file = init_model_config.get('mean_file', None)
        invstd_file = init_model_config.get('invstd_file', None)
        if mean_file is not None and invstd_file is not None:
            self.encoder_embedding.load_mean_invstd(mean_file, invstd_file)

    def compute_lens_change(self, feature_lens):
        """feature_lens: int or torch.Tensor
        return updated feature lens after time reduction.

        This used to return a different lambda function for each case that computed
        the right thing.  That does not work within Torchscript.  If you really
        need this to be faster, create nn.Module()-s for all the cases and return
        one of them.  Torchscript does support that.

        Returns None when input_layer is not "nemo_conv" (which __init__ rejects).
        """
        if self.input_layer == "nemo_conv":
            # nemo_conv_settings defaults to None; treat that as "no overrides".
            settings = self.nemo_conv_settings or {}
            # Handle the special causal case
            subsampling_causal_cond = settings.get("subsampling", "dw_striding") in [
                "dw_striding",
                "striding",
                "striding_conv1d",
            ]
            is_causal = settings.get("is_causal", False)
            if is_causal and subsampling_causal_cond:
                lens_change = (
                    torch.ceil(feature_lens / self.time_reduction).long()
                    if isinstance(feature_lens, Tensor)
                    else math.ceil(feature_lens / self.time_reduction)
                )
                feature_lens_remainder = feature_lens % self.time_reduction
                if isinstance(feature_lens, Tensor):
                    lens_change[feature_lens_remainder != 1] += 1
                elif feature_lens_remainder != 1:
                    lens_change += 1
                return lens_change
            ceil_func = math.ceil if isinstance(feature_lens, int) else torch.ceil
            return ceil_func(feature_lens / self.time_reduction)

    @abc.abstractmethod
    def forward(self):
        """Abstract forward method implementation."""

    def _chunk_size_selection(self, chunk_size=None, left_chunk=None):
        """If chunk size is a list, we will randomly select a chunk size.

        Returns:
            (chunk_size, left_chunk) to use for this call.

        Raises:
            ValueError: chunk_size is a list but left_chunk is not a list of
                the same length.
        """

        if chunk_size is None:
            chunk_size = self.chunk_size
        if left_chunk is None:
            left_chunk = self.left_chunk
        if isinstance(chunk_size, list):
            # Variable chunk size during training
            chunk_size_index = int(torch.randint(low=0, high=len(chunk_size), size=(1,)))
            chunk_size_train_eff = chunk_size[chunk_size_index]
            if not isinstance(left_chunk, list):
                raise ValueError("Since chunk_size is a list, left_chunk must be a list")
            if len(left_chunk) != len(chunk_size):
                raise ValueError(
                    "The length of left_chunk must be the same as length of chunk_size."
                )
            left_chunk_train_eff = left_chunk[chunk_size_index]
        else:
            chunk_size_train_eff = chunk_size
            left_chunk_train_eff = left_chunk

        return chunk_size_train_eff, left_chunk_train_eff

    def _get_embed_class(self, embed):
        """Return the module underneath any CheckpointWrapper / FSDP wrappers."""
        # pylint: disable=protected-access
        embed_class = embed
        # Peel one wrapper at a time so nested combinations such as
        # CheckpointWrapper(FullyShardedDataParallel(module)) are fully unwrapped.
        while True:
            if isinstance(embed_class, CheckpointWrapper):
                embed_class = embed_class._checkpoint_wrapped_module
            elif isinstance(embed_class, FullyShardedDataParallel):
                embed_class = embed_class.module
            else:
                return embed_class

    def _forward_embeddings_core(self, input_tensor, masks):
        """Run the (possibly wrapped) NemoConvSubsampling embedding."""
        embed_class = self._get_embed_class(self.embed)
        assert isinstance(embed_class, NemoConvSubsampling)
        input_tensor, masks = self.embed(input_tensor, masks)
        return input_tensor, masks

    def _position_embedding(self, input_tensor):
        """Return (pos_k, pos_v); both are always None here.

        NOTE(review): when no relative-bias layer exists, pos_emb is applied but
        its result is discarded (only rebinds the local). __init__ currently
        always creates the bias layer, so that branch never runs.
        """
        pos_k = None
        pos_v = None
        if self.relative_attention_bias_layer is None:
            input_tensor = self.pos_emb(input_tensor)  # default to add abs sinusoid embedding
        return pos_k, pos_v

    def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk):
        """Build a (batch_size, seq_len, seq_len) chunk-based streaming mask."""
        chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection(
            chunk_size, left_chunk
        )

        # Create mask matrix for streaming
        # S stores start index. if chunksize is 18, s is [0,18,36,....]
        chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff)
        # avoid randomness when run evaluation or decoding
        if self.training and np.random.rand() > 0.5:
            # Either first or last chunk is not complete.
            # If only the last one is not complete, EOS is not effective
            # Re-anchor chunk starts at the sequence end so the FIRST chunk
            # becomes the partial one.
            chunk_start_idx = seq_len - chunk_start_idx
            chunk_start_idx = chunk_start_idx[::-1]
            chunk_start_idx = chunk_start_idx[:-1]
            chunk_start_idx = np.insert(chunk_start_idx, 0, 0)

        enc_streaming_mask = (
            adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk_train_eff)
            .unsqueeze(0)
            .expand([batch_size, -1, -1])
        )
        return enc_streaming_mask

    def forward_embeddings(self, xs_pad, masks, chunk_size_nc=None, left_chunk_nc=None):
        """Forwarding the inputs through the top embedding layers

        Args:
            xs_pad: torch.Tensor
                input tensor
            masks: torch.Tensor
                input mask
            chunk_size_nc: (optional, default is None) chunk size for non-causal layers
            left_chunk_nc: (optional, default is None) # of left chunks for non-causal layers

        Returns:
            (input_tensor, pos_k, pos_v, hs_mask, masks) and, when chunk_size_nc
            is given, additionally hs_mask_nc as a sixth element.

        Raises:
            ValueError: the input is too short to survive time reduction.
        """
        # pylint: disable=R0915
        # get new lens.
        seq_len = int(self.compute_lens_change(xs_pad.shape[1]))
        if seq_len <= 0:
            raise ValueError(
                f"""The squence length after time reduction is invalid: {seq_len}.
                Your input feature is too short. Consider filtering out the very
                short sentence from data loader""",
            )

        batch_size = xs_pad.shape[0]

        # Place the mask on xs_pad's own device (not merely the *current* CUDA
        # device, which is what .cuda() would use) so multi-GPU inputs work.
        enc_streaming_mask = self._streaming_mask(
            seq_len, batch_size, self.chunk_size, self.left_chunk
        ).to(xs_pad.device)

        input_tensor, masks = self._forward_embeddings_core(xs_pad, masks)

        # Combine the padding mask from the embedding with the streaming mask.
        if masks is not None:
            hs_mask = masks & enc_streaming_mask
        else:
            hs_mask = enc_streaming_mask

        if chunk_size_nc is not None:
            enc_streaming_mask_nc = self._streaming_mask(
                seq_len, batch_size, chunk_size_nc, left_chunk_nc
            ).to(xs_pad.device)
            if masks is not None:
                hs_mask_nc = masks & enc_streaming_mask_nc
            else:
                hs_mask_nc = enc_streaming_mask_nc
        else:
            hs_mask_nc = None

        pos_k, pos_v = self._position_embedding(input_tensor)

        if chunk_size_nc is None:
            return input_tensor, pos_k, pos_v, hs_mask, masks
        return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc

    def get_offset(self):
        """Returns offset used when retaining inputs for decoding.

        This is essentially, how many additional frames have to be added to
        the front-end CNN input to ensure it can produce a single output.
        So if the "padding" parameter is 0, typically offset will be > 0.
        """
        return get_offset(self.input_layer, self.time_reduction)


def get_offset(input_layer: str, time_reduction: int):
    """Return the frame offset used to size a subsampled feature.

    Args:
        input_layer (str): front-end input layer type, e.g. "conv2d" or "nemo_conv".
        time_reduction (int): temporal downsampling factor of that layer.

    Returns:
        int: the offset for a known (input_layer, time_reduction) pair listed
        below, otherwise 0.
    """
    # (accepted input layer types, time reduction, offset) -- first match wins.
    offset_rules = (
        (("conv2d", "nemo_conv"), 4, 3),
        (("conv2d",), 6, 1),
        (("conv2d", "nemo_conv"), 8, 7),
    )
    for layer_types, reduction, offset in offset_rules:
        if input_layer in layer_types and time_reduction == reduction:
            return offset
    return 0


class ConformerEncoder(TransformerEncoderBase):
    """ConformerEncoder module.
    see original paper for more details:
        https://arxiv.org/abs/2005.08100

    Please set causal = True in streaming model
    Args:
        input_size: int
            input feature dimension.
        chunk_size: int, list(int)
            Number of frames for each chunk
            This variable can take 2 forms:
            int:  Used for inference, or single chunk size training
            list(int) : Used only for variable chunk size training
            Some examples for the 2 cases:
            chunk_size = 12
            chunk_size = [6, 8, 12, 24]
        left_chunk: int, list(int)
            Number of chunks used for masking in streaming mode.
            This variable can take 2 forms:
            int:  Used for inference, or single chunk size training
            list(int) : Used only for variable chunk size training. When
            chunk_size is a list, left_chunk must be a list with same length.
            Some examples for the 2 cases:
            left_chunk = 6
            left_chunk = [12, 9, 6, 3]
        num_lang: int
            This parameter is used to store the number of languages in the lang_dict,
            only used for multiseed/multilingual models. default None.
        attention_dim: int, optional
            attention dimension. default 256.
        attention_heads: int, optional
            the number of heads. default 4
        linear_units:
            the number of units of position-wise feed forward.
            default 2048
        num_blocks:
            number of Conformer layers. default 6
        dropout_rate: float, optional
            dropout rate. default 0.1
        input_layer: str, optional
            input layer type before Conformer,
            e.g. one of ["nemo_conv", "linear", "conv2d", "custom", "vgg2l", "embed"],
            default "nemo_conv"
        causal: bool, optional
            if set to True, convolution have no access
             to future frames. default True.
        batch_norm: bool, optional
            if set to True, apply batchnorm before activation
            in ConvModule layer of the conformer.
            default False
        cnn_out: int, optional
            the number of CNN channels before Conformer.
            default -1.
        cnn_layer_norm: bool, optional
            layer norm between Conformer and the first CNN.
            default False.
        ext_pw_out_channel: int, optional
            the number of channel for CNN
            before depthwise_seperable_CNN.
            If 0 then use linear. default 0.
        ext_pw_kernel_size: int, optional
            kernel size of N before depthwise_seperable_CNN.
            only work for ext_pw_out_channel > 0.
            default 1
        depthwise_seperable_out_channel: int, optional
            the number of channel for
            depthwise_seperable_CNN.
            default 256.
        depthwise_multiplier: int, optional
            the number of multiplier for
            depthwise_seperable_CNN.
            default 1.
        chunk_se: int, optional
            0 for offline SE.
            1 for streaming SE, where mean is computed
             by accumulated history until current chunk_se.
            2 for streaming SE, where mean is computed
             by only the current chunk.
            default 0.
        kernel_size: int, optional
            the number of kernels for depthwise_seperable_CNN.
            default 3.
        activation: str, optional
            FeedForward block activation.
            one of ["relu", "swish", "sigmoid"]
            default "relu".
        conv_activation: str, optional
            activation function used in ConvModule part
            of the conformer, default "relu".
        conv_glu_type: str, optional
            activation used use glu in depthwise_seperable_CNN,
            default "sigmoid"
        bias_in_glu: bool, optional
            if set to True, use additive bias in the weight module
             before GLU. default True
        linear_glu_in_convm: bool, optional
            if set to True, use GLULinear module,
             otherwise, used GLUPointWiseConv module.
              default to False.
        attention_glu_type: str
            only work for glu_in_attention !=0
            default "swish".
        export: bool, optional
            if set to True, it remove the padding from convolutional layers
             and allow the onnx conversion for inference.
              default False.
        extra_layer_output_idx: int
            the layer index to be exposed. default -1 (none). When it is not
            -1, forward() runs the layers one by one instead of calling
            ``self.encoders`` as a single module.
        extra_multi_layer_output_idxs: list(int), optional
            stored on the module as ``extra_multi_layer_output_idxs``;
            not read by forward() in this class. default None (stored as []).
        activation_checkpointing: str or dict, optional
            either one of "", "checkpoint", "offload", or
            a dictionary of {"module","interval","offload"}, where
                "module": str
                    accept ["transformer", "attention"] to select
                    which module should do activation checkpointing.
                "interval": int, default 1,
                    interval of applying activation checkpointing,
                    interval = 1 means that we apply checkpointing
                    on every layer (if activation), otherwise,
                    we apply it every x interval.
                "offload": bool, default False,
                    if set to True, we offload activation to cpu and
                    reload it during backward, otherwise,
                    we recalculate activation in backward.
            default "".
        relative_attention_bias_args: dict, optional
            use more efficient scalar bias-based relative multihead attention (Q*K^T + B)
            implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias
            usage: relative_attention_bias_args={"type": t5/alibi}
            additional method-specific arguments can be provided (see transformer_base.py)
        time_reduction: int optional
            time reduction factor
            default 4
        use_pt_scaled_dot_product_attention: whether to use pytorch scaled dot product attention
            in training.
            Default: False
        nemo_conv_settings: dict, optional
            A dictionary of settings for NeMo Subsampling.
            default: None
            usage: nemo_conv_settings=
                {
                    "subsampling":
                        dw_striding/striding/dw_striding_conv1d/striding_conv1d,
                    "conv_channels": int,
                    "subsampling_conv_chunking_factor": int,
                    "is_causal": True/False
                }
        conv2d_extra_padding: str, optional
            Add extra padding in conv2d subsampling layers. Choices are
            (feat, feat_time, none, True)
            Default: none
        replication_pad_for_subsample_embedding:  For batched-streaming decoding, use
            "replication" padding for the cache at start of utterance.
             Default: False
        attention_group_size: int, optional
            the number of groups to use for attention, default 1 (Multi-Head Attention),
            1 = typical Multi-Head Attention,
            1 < attention_group_size < attention_heads = Grouped-Query Attention
            attention_group_size = attention_heads = Multi-Query Attention
        encoder_embedding_config: dict, optional
            passed through unchanged to the base class constructor.
            default None.
    """

    extra_multi_layer_output_idxs: List[int]

    def __init__(  # pylint: disable-all
        self,
        input_size,
        chunk_size,
        left_chunk,
        num_lang=None,
        attention_dim=256,
        attention_heads=4,
        linear_units=2048,
        num_blocks=6,
        dropout_rate=0.1,
        input_layer="nemo_conv",
        causal=True,
        batch_norm=False,
        cnn_out=-1,
        cnn_layer_norm=False,
        ext_pw_out_channel=0,
        ext_pw_kernel_size=1,
        depthwise_seperable_out_channel=256,
        depthwise_multiplier=1,
        chunk_se=0,
        kernel_size=3,
        activation="relu",
        conv_activation="relu",
        conv_glu_type="sigmoid",
        bias_in_glu=True,
        linear_glu_in_convm=False,
        attention_glu_type="swish",
        export=False,
        extra_layer_output_idx=-1,
        extra_multi_layer_output_idxs=None,
        activation_checkpointing="",
        relative_attention_bias_args=None,
        time_reduction=4,
        use_pt_scaled_dot_product_attention=False,
        nemo_conv_settings=None,
        conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none",
        replication_pad_for_subsample_embedding=False,
        attention_group_size=1,
        encoder_embedding_config=None,
    ):
        super().__init__(
            input_size,
            chunk_size,
            left_chunk,
            attention_dim,
            attention_heads,
            input_layer,
            cnn_out,
            cnn_layer_norm,
            time_reduction,
            dropout_rate=dropout_rate,
            relative_attention_bias_args=relative_attention_bias_args,
            positional_dropout_rate=0.0,
            nemo_conv_settings=nemo_conv_settings,
            conv2d_extra_padding=conv2d_extra_padding,
            attention_group_size=attention_group_size,
            encoder_embedding_config=encoder_embedding_config,
        )
        self.num_blocks = num_blocks
        self.num_lang = num_lang
        self.kernel_size = kernel_size
        self.embed = embedding_checkpoint_wrapper(activation_checkpointing)(self.embed)
        self.replication_pad_for_subsample_embedding: bool = replication_pad_for_subsample_embedding
        assert self.num_heads % attention_group_size == 0, "attention_group_size must divide n_head"
        self.num_heads_k = self.num_heads // attention_group_size

        self.encoders = repeat(
            num_blocks,
            lambda i: encoder_checkpoint_wrapper(
                activation_checkpointing, ConformerEncoderLayer, i
            )(
                ConformerEncoderLayer(
                    d_model=attention_dim,
                    ext_pw_out_channel=ext_pw_out_channel,
                    depthwise_seperable_out_channel=depthwise_seperable_out_channel,
                    depthwise_multiplier=depthwise_multiplier,
                    n_head=attention_heads,
                    d_ffn=linear_units,
                    ext_pw_kernel_size=ext_pw_kernel_size,
                    kernel_size=kernel_size,
                    dropout_rate=dropout_rate,
                    causal=causal,
                    batch_norm=batch_norm,
                    activation=activation,
                    chunk_se=chunk_se,
                    chunk_size=chunk_size,
                    conv_activation=conv_activation,
                    conv_glu_type=conv_glu_type,
                    bias_in_glu=bias_in_glu,
                    linear_glu_in_convm=linear_glu_in_convm,
                    attention_glu_type=attention_glu_type,
                    activation_checkpointing=attn_checkpointing(activation_checkpointing, i),
                    export=export,
                    use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention,
                    attn_group_sizes=attention_group_size,
                )
            ),
        )
        self.extra_layer_output_idx = extra_layer_output_idx
        # Copy into a fresh list so instances never share (and mutate) one
        # list object, which a mutable default argument would cause.
        self.extra_multi_layer_output_idxs = (
            [] if extra_multi_layer_output_idxs is None else list(extra_multi_layer_output_idxs)
        )
        # Make a zeros scalar we can use in get_initial_state to determine
        # the device and the needed dtype:
        self.register_buffer("dev_type", torch.zeros(()), persistent=False)

    def init_relative_attention_bias(self, input_tensor):
        """Return the relative attention bias for ``input_tensor``, or None.

        None is returned when ``self.relative_attention_bias_layer`` is falsy
        (i.e. no relative attention bias is configured).
        """
        if self.relative_attention_bias_layer:
            return self.relative_attention_bias_layer(input_tensor)
        return None

    def calculate_hs_mask(self, xs_pad, device, mask):
        """Build the attention mask for ``xs_pad``.

        Args:
            xs_pad: tensor whose dim 0 is the batch and dim 1 the sequence length.
            device: device on which the mask is created.
            mask: optional pad mask; ``mask.sum(1)`` is used as the number of
                valid frames per row (assumes valid frames form a prefix -- TODO confirm).

        Returns:
            The streaming mask from ``self._streaming_mask`` when ``mask`` is
            None, otherwise that streaming mask AND-ed with a length-based
            pad mask of shape (batch, 1, seq_len).
        """
        batch_size, max_audio_length = xs_pad.shape[0], xs_pad.shape[1]
        enc_streaming_mask = self._streaming_mask(
            max_audio_length, batch_size, self.chunk_size, self.left_chunk
        ).to(device)
        if mask is None:
            return enc_streaming_mask

        feature_lens = mask.sum(1)
        positions = torch.arange(0, max_audio_length, device=device).expand(feature_lens.size(0), -1)
        pad_mask = (positions < feature_lens.unsqueeze(1)).unsqueeze(1)
        return pad_mask & enc_streaming_mask

    @torch.jit.ignore
    def forward(self, xs_pad, masks):
        """Conformer Forward function

        Sequences longer than 500 frames after the embeddings are padded to a
        multiple of 500 and unfolded into 500-frame chunks; the attention mask
        is rebuilt for the chunks, and the output is folded back and unpadded.

        Args:
            xs_pad: torch.Tensor
                input tensor, passed through ``self.encoder_embedding`` and
                ``self.forward_embeddings``.
            masks: torch.Tensor or None
                pad mask forwarded to ``self.forward_embeddings``.

        Returns:
            (output tensor, masks), where ``masks`` is the mask returned by
            ``self.forward_embeddings``.
        """
        xs_pad = self.encoder_embedding(xs_pad)
        input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings(xs_pad, masks)

        unfolded = False
        chunk_pad_size = 0
        ori_bz, seq_len, _ = input_tensor.shape
        max_seq_len = 500  # maximum position for absolute positional encoding
        if seq_len > max_seq_len:
            # audio sequence is longer than max_seq_len, unfold it into chunks of max_seq_len
            unfolded = True
            # the unfold op drops residual frames, so pad to a multiple of max_seq_len first
            remainder = seq_len % max_seq_len
            chunk_pad_size = max_seq_len - remainder if remainder > 0 else 0
            if chunk_pad_size > 0:
                input_tensor = F.pad(input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0)

            input_tensor = unfold_tensor(input_tensor, max_seq_len)
            if masks is not None:
                # The hs_mask from forward_embeddings ignores the extra padding and
                # the unfolding, so rebuild it from the (padded, unfolded) pad mask.
                subsampled_pad_mask = masks.squeeze(1)  # [bz, subsampled_unmask_seq_len]
                padded_pad_mask = F.pad(subsampled_pad_mask, (0, chunk_pad_size), "constant", False)
                # unfold op does not support bool tensors: go through float and back
                padded_pad_mask = padded_pad_mask.unsqueeze(-1).float()
                masks_unfold = unfold_tensor(padded_pad_mask, max_seq_len).squeeze(-1).bool()
            else:
                masks_unfold = None
            hs_mask = self.calculate_hs_mask(input_tensor, input_tensor.device, masks_unfold)

        # layer_emb captures the output of layer ``extra_layer_output_idx``;
        # it is currently not returned (see the trailing comment on the return).
        layer_emb = None

        relative_attention_bias = self.init_relative_attention_bias(input_tensor)

        _simplified_path = (
            self.extra_layer_output_idx == -1
            and relative_attention_bias is None
        )

        if _simplified_path:
            input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, hs_mask)
        else:
            for i, layer in enumerate(self.encoders):
                input_tensor, _, _, _ = layer(
                    input_tensor,
                    pos_k,
                    pos_v,
                    hs_mask,
                    relative_attention_bias=relative_attention_bias,
                )

                if i == self.extra_layer_output_idx:
                    layer_emb = input_tensor

        if unfolded:
            embed_dim = input_tensor.shape[-1]
            input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim)
            # if we padded before unfolding, remove the padding again
            if chunk_pad_size > 0:
                input_tensor = input_tensor[:, :-chunk_pad_size, :]
        return input_tensor, masks  # , layer_emb

    def gradient_checkpointing_enable(self):
        """No-op; activation checkpointing is configured via the constructor's
        ``activation_checkpointing`` argument instead."""
        pass