# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import logging
from ast import literal_eval
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.models import (
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
    register_model,
    register_model_architecture,
)
from .modules.text_encoder_prenet import TextEncoderPrenet
from .modules.text_decoder_prenet import TextDecoderPrenet
from .modules.text_decoder_postnet import TextDecoderPostnet
from .modules.speech_encoder_prenet import SpeechEncoderPrenet
from .modules.speech_encoder_postnet import SpeechEncoderPostnet
from .modules.speech_decoder_prenet import SpeechDecoderPrenet
from .modules.speech_decoder_postnet import SpeechDecoderPostnet
from .modules.speaker_decoder_postnet import SpeakerDecoderPostnet
from .modules.encoder import TransformerEncoder
from .modules.decoder import TransformerDecoder
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from fairseq.models.transformer import Embedding
from fairseq.modules import (
    GumbelVectorQuantizer,
)
from torch import Tensor


# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)

# Default position limits for text and speech sequences. They are not
# referenced in this part of the file; presumably consumed by the
# architecture defaults defined further down.
DEFAULT_MAX_TEXT_POSITIONS: int = 450
DEFAULT_MAX_SPEECH_POSITIONS: int = 4000


@register_model("artst_transformer")
class ArTSTTransformerModel(FairseqEncoderDecoderModel):
    """Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
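    speech and text tasks, following SpeechT5. One Transformer encoder/decoder
    is shared across modalities; modality-specific pre-nets map text tokens or
    speech into the model dimension, and post-nets map outputs back to text,
    speech or speaker predictions. The speech encoder pre-net (e.g. the "conv"
    subsampler) projects inputs into the encoder dimension and downsamples the
    input sequence for computational efficiency."""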

    def __init__(
            self, 
            args,
            encoder, decoder,
            text_encoder_prenet, speech_encoder_prenet,
            text_decoder_prenet, speech_decoder_prenet,
            text_decoder_postnet, speech_decoder_postnet,
            speaker_decoder_postnet, speech_encoder_postnet, 
        ):
        """Store the pre-built sub-modules and create the optional extra layers.

        Args:
            args: model configuration namespace. Reads ``reduction_factor``,
                ``spk_embed_dim``, ``spk_embed_integration_type``,
                ``decoder_embed_dim``, ``use_codebook``, ``codebook_prob``
                (defaults to 0.5 when absent), the ``latent_*`` /
                ``quantizer_*`` / ``encoder_embed_dim`` settings when
                ``use_codebook`` is set, ``bert_init`` and ``modules_filter``.
            encoder, decoder: the Transformer encoder and decoder.
            text_encoder_prenet, speech_encoder_prenet: encoder input
                networks for text tokens and speech, respectively.
            text_decoder_prenet, speech_decoder_prenet: decoder input
                networks for text and speech targets.
            text_decoder_postnet, speech_decoder_postnet: decoder output
                networks for text and speech.
            speaker_decoder_postnet: speaker-classification head, or ``None``.
            speech_encoder_postnet: HuBERT-style head, stored as
                ``self.hubert_layer``; may be ``None``.
        """
        super().__init__(encoder, decoder)

        # NOTE: sub-modules are registered in assignment order, which fixes the
        # order of parameters()/state_dict(); keep this order stable so that
        # existing checkpoints and optimizer states remain compatible.
        # (Presumably redundant with the base-class __init__, which already
        # receives encoder/decoder; re-assignment keeps their position.)
        self.encoder = encoder
        self.decoder = decoder

        self.text_encoder_prenet = text_encoder_prenet
        self.speech_encoder_prenet = speech_encoder_prenet

        self.text_decoder_prenet = text_decoder_prenet
        self.speech_decoder_prenet = speech_decoder_prenet

        self.text_decoder_postnet = text_decoder_postnet
        self.speech_decoder_postnet = speech_decoder_postnet
        self.speaker_decoder_postnet = speaker_decoder_postnet

        # forward() calls this head with encoder output, padding mask, mask
        # indices and HuBERT targets.
        self.hubert_layer = speech_encoder_postnet

        self.reduction_factor = args.reduction_factor
        self.spk_embed_dim = args.spk_embed_dim

        # define projection layer
        # Speaker-embedding projection. Not created for "pre" integration: in
        # that mode forward() hands spkembs to the speech decoder prenet.
        # "add" maps the speaker embedding to decoder_embed_dim; any other
        # type (the CLI only offers "pre"/"add") projects a
        # (decoder_embed_dim + spk_embed_dim) input, presumably the
        # concatenation of decoder state and speaker embedding.
        # Note that self.projection only exists when this branch runs.
        self.spk_embed_integration_type = args.spk_embed_integration_type
        if self.spk_embed_dim is not None and self.spk_embed_integration_type != 'pre':
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim, args.decoder_embed_dim)
            else:
                self.projection = torch.nn.Linear(
                    args.decoder_embed_dim + self.spk_embed_dim, args.decoder_embed_dim
                )

        # TODO(Hawau): language embedding integration could be added here.

        # Optional Gumbel vector quantizer over encoder outputs. In forward(),
        # a random fraction `codebook_prob` of time steps is replaced by the
        # quantized codes. The code dimension falls back to encoder_embed_dim
        # when latent_dim <= 0.
        self.use_codebook = args.use_codebook
        self.codebook_prob = getattr(args, "codebook_prob", 0.5) # args.codebook_prob
        if self.use_codebook:
            vq_dim = args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim
            self.quantizer = GumbelVectorQuantizer(
                dim=args.encoder_embed_dim,
                num_vars=args.latent_vars,
                temp=args.latent_temp,
                groups=args.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
                weight_proj_depth=args.quantizer_depth,
                weight_proj_factor=args.quantizer_factor,
            )

        # Training-update counter; presumably advanced by a set_num_updates
        # hook defined outside this view — TODO confirm.
        self.num_updates = 0

        # Follow BERT's random weight initialization (for BART). This runs
        # after every sub-module above is created, so it reaches all of them
        # (init_bert_params decides which layer types it re-initializes).
        if args.bert_init:
            self.apply(init_bert_params)
        # args must be stored before pruning; prune_modules is defined outside
        # this view — NOTE(review): confirm whether it reads self.args.
        self.args = args
        # Drop sub-modules selected by --modules-filter (e.g. for SID).
        self.prune_modules(args.modules_filter)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to ``parser``.

        Flags are grouped by concern: Transformer core, speech prenet,
        speech enhancement (SE), speaker identification (SID), HuBERT
        masking, codebook quantizer, relative positions, feature extractor
        and miscellaneous. Arguments without an explicit ``default`` are
        presumably filled in by ``base_architecture``, which ``build_model``
        calls before building any module.
        """
        # Transformer
        parser.add_argument(
            "--activation-fn",
            type=str,
            choices=utils.get_available_activation_fns(),
            help="activation function to use",
        )
        parser.add_argument(
            "--dropout", type=float, metavar="D", help="dropout probability"
        )
        parser.add_argument(
            "--attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for attention weights",
        )
        parser.add_argument(
            "--activation-dropout",
            "--relu-dropout",
            type=float,
            metavar="D",
            help="dropout probability after activation in FFN.",
        )
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--encoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--encoder-layers", type=int, metavar="N", help="num encoder layers"
        )
        parser.add_argument(
            "--encoder-attention-heads",
            type=int,
            metavar="N",
            help="num encoder attention heads",
        )
        parser.add_argument(
            "--encoder-normalize-before",
            action="store_true",
            help="apply layernorm before each encoder block",
        )
        parser.add_argument(
            "--decoder-normalize-before",
            action="store_true",
            help="apply layernorm before each decoder block",
        )
        parser.add_argument(
            "--decoder-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension",
        )
        parser.add_argument(
            "--decoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--decoder-layers", type=int, metavar="N", help="num decoder layers"
        )
        parser.add_argument(
            "--decoder-attention-heads",
            type=int,
            metavar="N",
            help="num decoder attention heads",
        )
        parser.add_argument(
            "--reduction-factor",
            type=int,
            help="reduction factor for decoder",
        )
        parser.add_argument(
            "--spk-embed-dim",
            type=int,
            help="speaker embedding dimension",
        )
        parser.add_argument(
            "--layernorm-embedding",
            action="store_true",
            help="add layernorm to embedding",
        )
        parser.add_argument(
            "--load-pretrained-encoder-from",
            type=str,
            metavar="STR",
            help="model to take encoder weights from (for initialization)",
        )
        parser.add_argument(
            '--freeze-encoder-updates',
            type=int,
            help='number of steps to freeze encoder before finetune'
        )
        parser.add_argument(
            '--freeze-decoder-updates',
            type=int,
            help='number of steps to freeze decoder before finetune'
        )
        parser.add_argument(
            '--no-freeze-encoder-layer',
            type=str,
            help='which encoder layer is not frozen during finetune'
        )
        parser.add_argument(
            "--share-input-output-embed",
            action="store_true",
            help="share decoder input and output embeddings",
        )
        parser.add_argument(
            "--share-ctc-embed",
            action="store_true",
            help="share ctc embed and decoder embed",
        )
        parser.add_argument(
            "--encoder-sliding-window-attn",
            default=None,
            type=int,
            help="If not None but an even number, set sliding window attention to encoder's attn_mask, e.g., 4, 10, and 20",
        )

        # Convolutional subsampler
        parser.add_argument(
            "--encoder-speech-prenet",
            default="conv",
            type=str,
            choices=["conv", "linear"],
            help="The type of encoder speech prenet, e.g., conv or linear."
        )
        parser.add_argument(
            "--conv-kernel-sizes",
            default="5,5",
            type=str,
            help="Comma-separated kernel sizes of the convolution layers in the encoder speech prenet."
        )
        parser.add_argument(
            "--conv-channels",
            default=1024,
            type=int,
            help="The channels of encoder speech prenet."
        )
        parser.add_argument(
            "--subsample-stride",
            default="2,2",
            type=str,
            help="The subsample stride for conv1dsubsample."
        )
        parser.add_argument(
            "--spk-embed-integration-type",
            type=str,
            choices=["pre", "add"],
            help="speaker embedding integration type"
        )
        parser.add_argument(
            "--dprenet-dropout-rate",
            default=0.5,
            type=float,
            help="The dropout rate of decoder speech prenet."
        )

        ## SE
        parser.add_argument(
            "--se-predict",
            default=None,
            choices=["masking", "target", "delta"],
            # The two halves are joined with "+": keep the trailing space.
            help="If set, source speech inputs decoder to predict the masking/target/delta of corresponding inputs. "
               + "masking is [0, 1], target is predicted output, delta is difference between inputs and outputs",
        )
        parser.add_argument(
            "--se-decoder-input",
            type=str,
            default="previous_target",
            choices=["previous_target", "source"],
        )

        ## SID
        parser.add_argument(
            "--modules-filter",
            default=None,
            type=str,
            help="Remove unused modules for, e.g., SID.",
        )
        parser.add_argument(
            "--sid-pad-prenet",
            action="store_true",
            help="If set, the text decoder embedding table is truncated to 3 entries (just enough to include <pad> at index 2).",
        )
        parser.add_argument(
            "--encoder-attn-branch",
            type=str,
            default="identity,full",
            help="encoder attention branch sliding window, e.g., 'identity,0,2,4,full'",
        )
        parser.add_argument(
            "--encoder-block-branch",
            type=str,
            help="average the output of encoder, e.g., '4,5,6'",
        )
        parser.add_argument(
            "--sid-encoder-cls",
            default=None,
            choices=["encoder"],
            help="If set, add cls vector to the encoder input, e.g., constant vector.",
        )
        parser.add_argument(
            "--sid-shuffle-encoder-input",
            action="store_true",
            help="If set, shuffle encoder input in time.",
        )
        parser.add_argument(
            "--sid-decoder-speaker",
            action="store_true",
            help="If set, apply speaker decoder as transformer decoder.",
        )
        parser.add_argument(
            "--sid-decoder-attn-dim",
            default=128,
            type=int,
            help="Attention dimension in attentive statistics pooling of speaker decoder.",
        )
        parser.add_argument(
            "--sid-t5-postnet",
            action="store_true",
            help="If set, apply TextDecoderPostnet as speaker classification.",
        )
        parser.add_argument(
            "--sid-embed-dim",
            default=128,
            type=int,
            help="Embedding dimension in speaker postnet for speaker identification if embed postnet.",
        )
        parser.add_argument(
            "--sid-pooling-layer",
            default="decoder",
            type=str,
            choices=["decoder-las", "decoder", "encoder", "encoder-cls", "encoder-speaker"],
            help="Which output (decoder or encoder) is pooled over the temporal dimension for SID.",
        )
        parser.add_argument(
            "--sid-no-pooling-bn",
            action="store_true",
            help="If set, do not use batchnorm in attention pooling.",
        )
        parser.add_argument(
            "--sid-no-embed-postnet",
            action="store_true",
            help="If set, no layer between decoder output and classification layer.",
        )
        parser.add_argument(
            "--sid-normalize-postnet",
            action="store_true",
            help="If set, normalize input and weight in postnet/classifier.",
        )
        parser.add_argument(
            "--sid-softmax-type",
            default="softmax",
            choices=["softmax", "amsoftmax", "aamsoftmax"],
            help="If using amsoftmax or aamsoftmax, the target should be given.",
        )
        parser.add_argument(
            "--softmax-scale",
            default=1.0,
            type=float,
            help="Scale for AMSoftmax or AAMSoftmax.",
        )
        parser.add_argument(
            "--softmax-margin",
            default=0.0,
            type=float,
            help="Margin for AMSoftmax or AAMSoftmax.",
        )
        parser.add_argument(
            "--softmax-easy-margin",
            action="store_true",
            help="Enable easy margin for AAMSoftmax.",
        )
        parser.add_argument(
            "--encoder-layerdrop",
            type=float,
            metavar="D",
            help="LayerDrop probability for encoder",
        )
        parser.add_argument(
            "--decoder-layerdrop",
            type=float,
            metavar="D",
            help="LayerDrop probability for decoder",
        )

        ## Hubert
        parser.add_argument(
            '--feature-grad-mult',
            type=float,
            help='multiply feature extractor var grads by this'
        )
        parser.add_argument(
            '--logit-temp',
            type=float,
            help='temperature to divide logits by'
        )
        parser.add_argument(
            '--final-dim',
            type=int,
            help="project final representations and targets to this many "
            "dimensions. set to encoder_embed_dim if <= 0"
        )

        # mask
        parser.add_argument(
            '--hubert-mask-length',
            type=int,
            help='mask length'
        )
        parser.add_argument(
            '--mask-prob',
            type=float,
            help='probability of replacing a token with mask'
        )
        parser.add_argument(
            "--mask-selection",
            choices=["static", "uniform", "normal", "poisson"],
            help="how to choose mask length",
        )
        parser.add_argument(
            '--mask-other',
            type=float,
            help="secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indices"
        )
        parser.add_argument(
            '--mask-min-space',
            type=int,
            help='min space between spans (if no overlap is enabled)'
        )

        # channel masking
        parser.add_argument(
            '--mask-channel-length',
            type=int,
            help='length of the mask for features (channels)'
        )
        parser.add_argument(
            '--mask-channel-prob',
            type=float,
            help="probability of replacing a feature with 0"
        )
        parser.add_argument(
            "--mask-channel-selection",
            choices=["static", "uniform", "normal", "poisson"],
            help="how to choose mask length for channel masking",
        )
        parser.add_argument(
            '--mask-channel-other',
            type=float,
            help="secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indices"
        )
        parser.add_argument(
            '--mask-channel-min-space',
            type=int,
            help='min space between spans (if no overlap is enabled)'
        )

        # abs positional embeddings
        parser.add_argument(
            '--conv-pos',
            type=int,
            help='number of filters for convolutional positional embeddings'
        )
        parser.add_argument(
            '--conv-pos-groups',
            type=int,
            help='number of groups for convolutional positional embedding'
        )

        # codebook related
        parser.add_argument(
            "--use-codebook",
            action="store_true",
            help="whether to use codebook",
        )
        parser.add_argument(
            "--codebook-prob",
            type=float,
            help="fraction of encoder time steps replaced by their quantized codes",
        )
        parser.add_argument(
            "--latent-vars",
            type=int,
            help="number of latent variables V in each group of the codebook",
        )
        parser.add_argument(
            "--latent-groups",
            type=int,
            help="number of groups G of latent variables in the codebook",
        )
        parser.add_argument(
            "--latent-dim",
            type=int,
            # __init__ falls back to encoder_embed_dim when this is <= 0.
            help="if > 0, uses this dimensionality for latent variables. "
            "otherwise uses encoder_embed_dim",
        )
        parser.add_argument(
            "--latent-temp",
            type=literal_eval,
            help="temperature for latent variable sampling. "
            "can be tuple of 3 values (start, end, decay)",
        )
        parser.add_argument(
            "--quantizer-depth",
            type=int,
            help="number of quantizer layers",
        )
        parser.add_argument(
            "--quantizer-factor",
            type=int,
            help="dimensionality increase for inner quantizer layers (if depth > 1)",
        )
        parser.add_argument(
            "--get-code-distribution",
            action='store_true',
            help="whether to get the code distribution (for test)",
        )

        # relative pos enc
        parser.add_argument(
            "--relative-position-embedding",
            action='store_true',
            help="whether to use relative position embedding",
        )
        parser.add_argument(
            "--num-buckets",
            type=int,
            default=320,
            help="num of buckets for relative position embedding",
        )
        parser.add_argument(
            "--max-distance",
            type=int,
            default=1280,
            help="max distance for relative position embedding",
        )
        parser.add_argument(
            "--encoder-max-relative-position",
            type=int,
            help="max distance for relative position embedding in encoder",
        )
        parser.add_argument(
            "--decoder-max-relative-position",
            type=int,
            help="max distance for relative position embedding in decoder",
        )

        # hubert feature extractor
        parser.add_argument(
            "--conv-feature-layers",
            type=str,
            help="string describing convolutional feature extraction "
            "layers in form of a python list that contains "
            "[(dim, kernel_size, stride), ...]",
        )
        parser.add_argument(
            "--conv-bias",
            action='store_true',
            help="include bias in conv encoder",
        )
        parser.add_argument(
            "--extractor-mode",
            choices=["default", "layer_norm"],
            help="mode for feature extractor. default has a single group "
            "norm with d groups in the first conv block, whereas layer_norm "
            "has layer norms in every block (meant to use with normalize=True)"
        )

        # others
        parser.add_argument(
            "--bert-init",
            action='store_true',
            help="initialize parameters with BERT-style init (init_bert_params)",
        )
        parser.add_argument(
            "--unb-enc-layer",
            type=int,
            default=-1,
            help="which layer's output is used as the input of decoder",
        )

    # Encoder, Decoder
    @classmethod
    def build_encoder(cls, args, dictionary=None, embed_tokens=None):
        """Instantiate the Transformer encoder.

        ``dictionary`` and ``embed_tokens`` are forwarded unchanged; both are
        left as ``None`` when ``build_model`` finds no text dictionary.
        """
        transformer_encoder = TransformerEncoder(args, dictionary, embed_tokens)
        return transformer_encoder

    @classmethod
    def build_decoder(cls, args):
        """Instantiate the Transformer decoder from ``args``."""
        transformer_decoder = TransformerDecoder(args)
        return transformer_decoder

    # Encoder Prenet
    @classmethod
    def build_text_encoder_prenet(cls, embed_tokens, args):
        """Create the encoder pre-net that embeds text tokens via ``embed_tokens``."""
        prenet = TextEncoderPrenet(embed_tokens, args)
        return prenet

    @classmethod
    def build_speech_encoder_prenet(cls, args):
        """Create the encoder pre-net for speech input, configured by ``args``."""
        prenet = SpeechEncoderPrenet(args)
        return prenet

    # Decoder Prenet
    @classmethod
    def build_text_decoder_prenet(cls, embed_tokens, args):
        """Create the decoder pre-net for text targets, sharing ``embed_tokens``."""
        prenet = TextDecoderPrenet(embed_tokens, args)
        return prenet

    @classmethod
    def build_speech_decoder_prenet(cls, odim, args):
        """Create the decoder pre-net for speech targets of feature size ``odim``."""
        prenet = SpeechDecoderPrenet(odim, args)
        return prenet

    # Decoder Postnet
    @classmethod
    def build_text_decoder_postnet(cls, embed_tokens, dictionary, args):
        """Create the text output head from ``embed_tokens`` and ``dictionary``."""
        postnet = TextDecoderPostnet(embed_tokens, dictionary, args)
        return postnet

    @classmethod
    def build_speaker_decoder_postnet(cls, embed_dim, class_num, args):
        """Create the speaker-classification head.

        ``embed_dim`` is the speaker embedding size and ``class_num`` the
        number of output classes (``build_model`` passes the text dictionary
        size).
        """
        postnet = SpeakerDecoderPostnet(embed_dim, class_num, args)
        return postnet

    @classmethod
    def build_speech_decoder_postnet(cls, odim, args):
        """Create the speech output head producing ``odim``-sized features."""
        postnet = SpeechDecoderPostnet(odim, args)
        return postnet

    @classmethod
    def build_speech_encoder_postnet(cls, dictionaries, args):
        """Create the HuBERT-style encoder head over the given target ``dictionaries``."""
        postnet = SpeechEncoderPostnet(dictionaries, args)
        return postnet

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance from ``args`` and the task dictionaries.

        Fills in missing arguments via ``base_architecture``, builds the text
        embedding tables, the encoder/decoder and every modality pre-/post-net,
        then passes them all to the constructor. The construction order below
        is deliberate: module constructors may draw from the torch RNG while
        initializing weights, so reordering them would change initial weights.
        """
        # make sure all arguments are present in older models
        base_architecture(args)

        def build_embedding(dictionary, embed_dim, max_num_embeddings=None):
            # Table size is the dictionary size, optionally capped by an int limit.
            size = len(dictionary)
            if isinstance(max_num_embeddings, int):
                size = min(size, max_num_embeddings)
            return Embedding(size, embed_dim, dictionary.pad())

        # With a pad-only SID prenet, 3 entries suffice (<pad> sits at index 2).
        max_num_embeddings = 3 if getattr(args, "sid_pad_prenet", False) else None

        text_dict = task.dicts["text"]
        text_decoder_embed_tokens = build_embedding(
            text_dict, args.decoder_embed_dim, max_num_embeddings
        )
        text_encoder_embed_tokens = (
            text_decoder_embed_tokens
            if args.share_input_output_embed
            else build_embedding(text_dict, args.encoder_embed_dim)
        )

        speech_odim = args.speech_odim
        encoder = (
            cls.build_encoder(args, text_dict, text_encoder_embed_tokens)
            if "text" in task.dicts
            else cls.build_encoder(args)
        )
        decoder = cls.build_decoder(args)

        text_encoder_prenet = cls.build_text_encoder_prenet(text_encoder_embed_tokens, args)
        speech_encoder_prenet = cls.build_speech_encoder_prenet(args)

        text_decoder_prenet = cls.build_text_decoder_prenet(text_decoder_embed_tokens, args)
        # "decoder-las" SID pooling reuses an encoder-style speech prenet as
        # the decoder's speech prenet.
        use_las_prenet = getattr(args, "sid_pooling_layer", None) == "decoder-las"
        speech_decoder_prenet = (
            cls.build_speech_encoder_prenet(args)
            if use_las_prenet
            else cls.build_speech_decoder_prenet(speech_odim, args)
        )

        text_decoder_postnet = cls.build_text_decoder_postnet(text_decoder_embed_tokens, text_dict, args)
        speech_decoder_postnet = cls.build_speech_decoder_postnet(speech_odim, args)

        # A dedicated speaker head exists only for the s2c task, and only when
        # the text postnet is not reused for classification (sid_t5_postnet).
        needs_speaker_head = (
            not getattr(args, "sid_t5_postnet", False) and task.t5_task == "s2c"
        )
        speaker_decoder_postnet = (
            cls.build_speaker_decoder_postnet(args.sid_embed_dim, len(text_dict), args)
            if needs_speaker_head
            else None
        )

        speech_encoder_postnet = (
            cls.build_speech_encoder_postnet(task.dicts["hubert"], args)
            if "hubert" in task.dicts
            else None
        )

        return cls(
            args,
            encoder, decoder,
            text_encoder_prenet, speech_encoder_prenet,
            text_decoder_prenet, speech_decoder_prenet,
            text_decoder_postnet, speech_decoder_postnet,
            speaker_decoder_postnet, speech_encoder_postnet,
        )

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Return (log-)probabilities from ``net_output``, tagged as batch-first.

        Delegates to ``get_normalized_probs_scriptable`` (presumably inherited
        from the fairseq base model) and sets ``batch_first = True`` on the
        resulting tensor; presumably read downstream by fairseq consumers.
        """
        normalized = self.get_normalized_probs_scriptable(
            net_output, log_probs, sample
        )
        setattr(normalized, "batch_first", True)
        return normalized

    def get_normalized_probs_for_ctc(self, net_output, log_probs):
        """Normalize the CTC logits stored in ``net_output``.

        Takes ``net_output["encoder_out_for_ctc"][0]``, casts it to float32 and
        applies a log-softmax (``log_probs=True``) or softmax over the last
        dimension.
        """
        ctc_logits = net_output["encoder_out_for_ctc"][0]
        normalize = utils.log_softmax if log_probs else utils.softmax
        return normalize(ctc_logits.float(), dim=-1)

    def get_logits(self, net_output, is_masked=True):
        """Collect the HuBERT prediction logits as float32 tensors.

        Reads ``net_output["logit_m_list"]`` (masked positions) when
        ``is_masked`` is true, else ``net_output["logit_u_list"]``; ``None``
        entries are skipped.
        """
        key = "logit_m_list" if is_masked else "logit_u_list"
        return [logit.float() for logit in net_output[key] if logit is not None]

    def get_targets(self, sample, net_output, is_masked=True):
        """Return the training targets matching ``net_output``.

        Without HuBERT logits (no ``"logit_m_list"`` key) this is simply
        ``sample["target"]``. Otherwise one all-zero ``long`` vector is built
        per logits tensor, sized by its first dimension — presumably because
        the positive candidate sits at index 0 of each row (TODO confirm
        against the HuBERT postnet).
        """
        if "logit_m_list" not in net_output:
            return sample["target"]
        return [
            logits.new_zeros(logits.size(0), dtype=torch.long)
            for logits in self.get_logits(net_output, is_masked)
        ]

    def get_extra_losses(self, net_output):
        """Gather auxiliary losses present in ``net_output``.

        Returns ``(losses, names)`` as two parallel lists, in this order:

        * ``"features_pen"`` — the feature penalty, taken as-is;
        * ``"prob_perplexity"`` — the codebook diversity loss
          ``(num_vars - prob_perplexity) / num_vars``.

        Entries whose key is missing from ``net_output`` are omitted.
        """
        collected = []

        if "features_pen" in net_output:
            collected.append(("features_pen", net_output["features_pen"]))

        if "prob_perplexity" in net_output:
            num_vars = net_output["num_vars"]
            diversity = (num_vars - net_output["prob_perplexity"]) / num_vars
            collected.append(("prob_perplexity", diversity))

        extra_losses = [loss for _, loss in collected]
        names = [name for name, _ in collected]
        return extra_losses, names

    def forward(self, source=None, src_tokens=None, src_lengths=None, prev_output_tokens=None, tgt_lengths=None, spkembs=None, target_list=None, task_name=None, padding_mask=None, only_hubert=False, only_ctc=False, feature_only=False, tgt_enc_layer=None, mask=True):
        """
        The forward method inherited from the base class has a **kwargs
        argument in its input, which is not supported in torchscript. This
        method overwrites the forward method definition without **kwargs.
        """
        assert source is not None or src_tokens is not None
        # padding_mask is not none only when input is waveform
        if source is None and padding_mask is None and not feature_only:
            input_type = 'text'
        else:
            input_type = 'speech'

        if prev_output_tokens is not None and len(prev_output_tokens.size()) == 2:
            output_type = 'text'
            codebook_out = {}
        else:
            output_type = 'speech'

        if task_name is not None and task_name == "s2c":
            if target_list is not None and target_list.size(1) == 1 and not getattr(self.args, "sid_t5_postnet", False):
                sid_target = F.one_hot(target_list.squeeze(1), num_classes=self.speaker_decoder_postnet.class_num)
            else:
                sid_target = None
            target_list = None

        # Encoder Prenet
        if input_type == 'text':
            encoder_input, encoder_padding_mask = self.text_encoder_prenet(src_tokens)
        else:
            if target_list is not None:
                encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, require_feat_pen=True, target_list=target_list, padding_mask=padding_mask, mask=mask)
                encoder_input, features_pen, mask_indices, target_list = encoder_input
            else:
                encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, padding_mask=padding_mask, mask=self.training)
                # shuffle a batch of inputs of encoder
                if self.training and hasattr(self.args, "sid_shuffle_encoder_input") and getattr(self.args, "sid_shuffle_encoder_input", False):
                    shuffle_index = torch.randperm(encoder_padding_mask.size(1), device=encoder_padding_mask.device)
                    encoder_input = torch.index_select(encoder_input, 1, shuffle_index)
                    encoder_padding_mask = torch.index_select(encoder_padding_mask, 1, shuffle_index)
                if getattr(self.args, "sid_encoder_cls", None) == "encoder":
                    prev_output_tokens = torch.zeros_like(prev_output_tokens)
                    encoder_input, encoder_padding_mask = self._integrate_with_speaker_cls(prev_output_tokens, encoder_input, encoder_padding_mask)

        # Encoder: T x B x C
        encoder_output = self.encoder(encoder_input, encoder_padding_mask, tgt_layer=tgt_enc_layer)

        if task_name is not None and task_name == 'speech_pretrain' and feature_only:
            return encoder_output["encoder_out"][0].transpose(0, 1)

        if task_name is not None and task_name == 's2c':
            if self.args.sid_pooling_layer == "encoder":
                return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1).mean(1), sid_target), None
            elif self.args.sid_pooling_layer == "encoder-cls":
                return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1)[:,0], sid_target), None
            elif self.args.sid_pooling_layer == "encoder-speaker" or getattr(self.args, "sid_decoder_speaker", False):
                return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1), sid_target), None

        if target_list is not None:
            hubert_results = self.hubert_layer(
                encoder_output["encoder_out"][0].transpose(0, 1), 
                encoder_padding_mask, 
                mask_indices, 
                target_list
            )

            hubert_results['features_pen'] = features_pen

        if "decoder_input" in encoder_output and encoder_output["decoder_input"][0] is not None:
            # Change the encoder output to decoder input once set unb-enc-layer
            encoder_output["encoder_out"] = encoder_output["decoder_input"]

        if self.use_codebook:
            q = self.quantizer(encoder_output["encoder_out"][0].transpose(0, 1))

            # q["x"]: B x T x C
            # Sample indexs according to the codebook prob
            random_idx = torch.randperm(q["x"].size(1))[:int(q["x"].size(1) * self.codebook_prob)]
            # Make weight for q
            q_w = q["x"].new_zeros(q["x"].size(1))
            q_w[random_idx] = 1.0
            # Combine quantized codes and encoder output
            encoder_output["encoder_out"][0] = (
                q_w.view(-1, 1) * q["x"] + (- q_w + 1).view(-1, 1) * encoder_output["encoder_out"][0].transpose(0, 1)
            ).transpose(0, 1)

            # encoder_output["encoder_out"][0] = q["x"].transpose(0, 1)
            if output_type == 'speech':
                hubert_results["prob_perplexity"] = q["prob_perplexity"]
                hubert_results["code_perplexity"] = q["code_perplexity"]
                hubert_results["num_vars"] = q["num_vars"]
                hubert_results["temp"] = q["temp"]
            elif output_type == 'text':
                codebook_out["prob_perplexity"] = q["prob_perplexity"]
                codebook_out["code_perplexity"] = q["code_perplexity"]
                codebook_out["num_vars"] = q["num_vars"]
                codebook_out["temp"] = q["temp"]

        if only_hubert and target_list is not None:
            return hubert_results, None
        
        if only_ctc and task_name is not None and task_name == "s2t":
            return None, encoder_output
        elif not self.training and prev_output_tokens is None and task_name == "s2t" and task_name is not None:
            return encoder_output

        # Decoder Prenet
        if output_type == 'text':
            # _ is the incremental state
            prev_output_tokens, tgt_mask, _ = self.text_decoder_prenet(prev_output_tokens)
            if task_name is not None and task_name == 's2c':
                prev_output_tokens = torch.zeros_like(prev_output_tokens)
        else:
            # integrate speaker embedding
            if self.spk_embed_integration_type == "pre" and self.spk_embed_dim is not None:
                # Decoder Prenet
                prev_output_tokens, tgt_mask = self.speech_decoder_prenet(prev_output_tokens, tgt_lengths, spkembs)
            else:
                if self.spk_embed_dim is not None:
                    encoder_output["encoder_out"] = [self._integrate_with_spk_embed(
                        encoder_output["encoder_out"][0].transpose(0, 1), spkembs
                    ).transpose(0, 1)]

                prev_output_tokens, tgt_mask = self.speech_decoder_prenet(prev_output_tokens, tgt_lengths)

        # BART Sequence Classification: cat <pad> + feature before decoder
        if task_name is not None and task_name == 's2c' and self.args.sid_pooling_layer == "decoder-las":
            decoder_feat_input, decoder_feat_mask = self.speech_decoder_prenet(src_tokens, src_lengths)
            prev_output_tokens, tgt_mask = self._integrate_with_speaker_cls((prev_output_tokens, tgt_mask), decoder_feat_input, decoder_feat_mask, cls_first=False)
        
        # SE predict masking to corresponding inputs and source speech replaces the prev_output_tokens as the input of decoder
        if task_name is not None and task_name == "s2s" and getattr(self.args, "se_decoder_input", "previous_target") == "source":
            prev_output_tokens, tgt_mask = self.speech_decoder_prenet(src_tokens, src_lengths)

        # Decoder
        decoder_output, extra = self.decoder(prev_output_tokens, tgt_mask, encoder_output, 
                                             full_context_alignment=getattr(self.args, "decoder_full_context_alignment", False), 
                                             alignment_layer=(-1 if target_list is None and output_type == 'speech' else None))
        # Decoder Postnet
        if task_name is not None and task_name == 's2c':
            if not getattr(self.args, "sid_t5_postnet", False):
                if self.args.sid_pooling_layer == "decoder":
                    return self.speaker_decoder_postnet(decoder_output.mean(1), sid_target), None
                elif self.args.sid_pooling_layer == "decoder-las":
                    indices = (tgt_mask.eq(False).float().sum(1) - 1.0).type(torch.int64)
                    indices = indices.unsqueeze(1).unsqueeze(2).expand(-1, -1, decoder_output.size(2))
                    return self.speaker_decoder_postnet(decoder_output.gather(1, indices), sid_target), None
            else:
                return (self.text_decoder_postnet(decoder_output), None), encoder_output

        # SE predict: masking, target, delta. Ensure reduction factor 1
        if task_name is not None and task_name == 's2s' and getattr(self.args, "se_predict", None) is not None:
            assert self.reduction_factor == 1, f"{self.reduction_factor} != 1"
            before_outs, after_outs, logits = self.speech_decoder_postnet(decoder_output)
            se_predict = getattr(self.args, "se_predict")
            if se_predict == "masking":
                before_outs = torch.sigmoid(before_outs) * src_tokens
                after_outs = torch.sigmoid(after_outs) * src_tokens
                return before_outs, after_outs, logits, extra['attn'][0]
            elif se_predict == "target":
                return before_outs, after_outs, logits, extra['attn'][0]
            elif se_predict == "delta":
                before_outs = before_outs - src_tokens
                after_outs = after_outs - src_tokens
                return before_outs, after_outs, logits, extra['attn'][0]
            else:
                raise ValueError(f"{se_predict} not in [masking, target, delta]")

        if task_name is not None and task_name == 's2t':
            #return self.text_decoder_postnet(decoder_output), None
            return (self.text_decoder_postnet(decoder_output), None), encoder_output
        if output_type == 'text':
            return (self.text_decoder_postnet(decoder_output), None), codebook_out, encoder_output
        else:
            if target_list is not None:
                return hubert_results, (self.speech_decoder_postnet(decoder_output) + (extra['attn'][0],))
            else:
                return self.speech_decoder_postnet(decoder_output) + (extra['attn'][0],)

    def _integrate_with_speaker_cls(self, pad_input, encoder_input, encoder_padding_mask=None, cls_first=True):
        """
        encoder_input: [B, T, C]
        encoder_padding_mask: [B, T]
        """
        if hasattr(self, "text_decoder_prenet"):
            if isinstance(pad_input, tuple):
                repeat_cls_vector, repeat_cls_mask = pad_input
            else:
                repeat_cls_vector, repeat_cls_mask, _ = self.text_decoder_prenet(pad_input)

            if encoder_padding_mask is not None:
                bsz = encoder_input.size(0)
                tsz = encoder_input.size(1)
                encoder_padding_mask = encoder_input.new_zeros((bsz, tsz)) == 1.0
            if repeat_cls_mask is None:
                mask_size = (encoder_padding_mask.size(0), 1)
                mask_type = encoder_padding_mask.dtype
                repeat_cls_mask = encoder_padding_mask.new_zeros(mask_size) == 1.0
            ret_encoder_padding_mask = torch.cat([repeat_cls_mask, encoder_padding_mask], dim=1)

            if cls_first:
                ret_encoder_input = torch.cat([repeat_cls_vector, encoder_input], dim=1)
            else:
                ret_encoder_input = torch.cat([encoder_input, encoder_input[:,-1:,:]], dim=1)
                mask_size = (encoder_padding_mask.size(0), 1)
                mask_type = encoder_padding_mask.dtype
                repeat_cls_mask_ = encoder_padding_mask.new_ones(mask_size) == 1.0
                encoder_padding_mask_ = torch.cat([encoder_padding_mask, repeat_cls_mask_], dim=1)
                indices = encoder_padding_mask.eq(False).float().sum(1).type(torch.int64).unsqueeze(1)
                indices_mask = torch.zeros_like(ret_encoder_padding_mask).scatter(1, indices, 1.0)
                ret_encoder_input = ret_encoder_input * (1.0 - encoder_padding_mask_.type(ret_encoder_input.dtype).unsqueeze(2)) \
                    + repeat_cls_vector * indices_mask.type(repeat_cls_vector.dtype).unsqueeze(2)
            
        return ret_encoder_input, ret_encoder_padding_mask

    def _integrate_with_spk_embed(self, hs, spembs):
        """Integrate speaker embedding with hidden states.
        Args:
            hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
            spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).
        Returns:
            Tensor: Batch of integrated hidden state sequences (B, Tmax, adim)
        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spembs = self.projection(F.normalize(spembs))
            hs = hs + spembs.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds and then apply projection
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = self.projection(torch.cat([hs, spembs], dim=-1))
        else:
            raise NotImplementedError("support only add or concat.")

        return hs

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg=None,
        args=None,
    ):
        """NOT STRICT Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.

        Each registered child module receives the entries whose keys start
        with ``"<child name>."`` (that leading prefix stripped once) and loads
        them non-strictly, so missing or unexpected keys are ignored. Child
        slots registered as ``None`` are skipped. ``strict``, ``model_cfg`` and
        ``args`` are accepted for interface compatibility but are not used.

        If the checkpoint's text output projection has a different vocabulary
        size than this model's, the dictionary-dependent entries (keys starting
        with ``encoder.proj``, ``text_encoder_prenet``, ``text_decoder_prenet``
        or ``text_decoder_postnet``) are removed from ``state_dict`` in place,
        so those modules keep their current weights. The check is skipped when
        either the model or the checkpoint lacks that projection (e.g. after
        :meth:`prune_modules`).

        Returns:
            self
        """
        # self.prune_modules(model_cfg.modules_filter)
        ckpt_proj_key = "text_decoder_postnet.output_projection.weight"
        text_postnet = getattr(self, "text_decoder_postnet", None)
        if text_postnet is not None and ckpt_proj_key in state_dict:
            model_dict_size = text_postnet.output_projection.out_features
            ckpt_dict_size = state_dict[ckpt_proj_key].size(0)
            if model_dict_size != ckpt_dict_size:
                # reset dictionary-related modules, such as embedding table and encoder ctc embed
                logger.warning(f"not equal dictionary between model and checkpoint: {model_dict_size} vs {ckpt_dict_size}")
                logger.info(f"reset model dictionary with size of {model_dict_size}")
                dict_prefixes = (
                    "encoder.proj", "text_encoder_prenet", "text_decoder_prenet", "text_decoder_postnet"
                )
                removed_keys = [key for key in state_dict.keys() if key.startswith(dict_prefixes)]
                for key in removed_keys:
                    state_dict.pop(key, None)
                    logger.info(f"removed loaded checkpoint: {key}")

        for name, module in self._modules.items():
            if module is None:
                # slot registered but emptied; nothing to load into
                continue
            prefix = f"{name}."
            # strip only the leading prefix; str.replace would also remove
            # inner occurrences such as "encoder.encoder.weight"
            m_state_dict = {
                key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)
            }
            module.load_state_dict(m_state_dict, False)
        return self

    def prune_modules(self, modules_filter=None):
        """Prune unused modules for specific tasks.

        Args:
            modules_filter: ``"s2c"``, ``"s2s"``, ``"t2s"`` or ``"s3prl"``.
                ``None`` or any other value leaves the model untouched.

        Every recognised filter sets ``self.encoder.proj`` to None (when
        present) and removes ``projection`` and ``quantizer``; the remaining
        removals are filter-specific. Attributes that are already missing are
        silently skipped.
        """
        def drop(owner, *names):
            # Delete each listed attribute that exists on ``owner``.
            for name in names:
                if hasattr(owner, name):
                    delattr(owner, name)

        def drop_common():
            # Shared by all filters: encoder proj, speaker projection, quantizer.
            if hasattr(self.encoder, "proj"):
                self.encoder.proj = None
            drop(self, "projection", "quantizer")

        def drop_decoder_stack():
            # Remove the decoder's dropout, layer stack and final layer norm.
            drop(self.decoder, "dropout_module", "layers", "layer_norm")

        if modules_filter == "s2c":
            drop(self, "text_encoder_prenet")
            keep_speech_prenet = getattr(self.args, "sid_pooling_layer", None) == "decoder-las"
            if hasattr(self, "speech_decoder_prenet") and not keep_speech_prenet:
                del self.speech_decoder_prenet
            drop(self, "speech_decoder_postnet", "text_decoder_postnet", "speech_encoder_postnet")
            drop_common()
            pooling = getattr(self.args, "sid_pooling_layer", "decoder")
            if pooling.startswith("encoder") or getattr(self.args, "sid_decoder_speaker", False):
                drop_decoder_stack()
                drop(self, "text_decoder_prenet")
        elif modules_filter == "s2s":
            drop(
                self,
                "speaker_decoder_postnet",
                "text_encoder_prenet",
                "text_decoder_prenet",
                "text_decoder_postnet",
                "speech_encoder_postnet",
            )
            drop_common()
        elif modules_filter == "t2s":
            drop(
                self,
                "speaker_decoder_postnet",
                "speech_encoder_prenet",
                "text_decoder_prenet",
                "text_decoder_postnet",
                "speech_encoder_postnet",
            )
            drop_common()
        elif modules_filter == "s3prl":
            # remain the encoder and the pre/post net
            drop_decoder_stack()
            drop(
                self,
                "speaker_decoder_postnet",
                "text_decoder_prenet",
                "text_decoder_postnet",
                "speech_decoder_prenet",
                "speech_decoder_postnet",
                "speech_encoder_postnet",
            )
            drop_common()

    def forward_encoder_torchscript(self, net_input: Dict[str, Tensor]):
        """TorchScript-friendly entry point for the speech encoder.

        In eager mode the whole ``net_input`` dict is handed to
        :meth:`forward_encoder_non_torchscript`. Under ``torch.jit.script``
        only ``source`` and ``padding_mask`` are read from it and passed to
        :meth:`forward_encoder`.
        """
        if not torch.jit.is_scripting():
            return self.forward_encoder_non_torchscript(net_input)
        else:
            return self.forward_encoder(
                source=net_input["source"],
                padding_mask=net_input["padding_mask"],
            )

    @torch.jit.unused
    def forward_encoder_non_torchscript(self, net_input: Dict[str, Tensor]):
        """Eager-mode encoder call.

        Drops the ``prev_output_tokens`` and ``task_name`` entries from
        ``net_input`` and passes everything else to :meth:`forward_encoder`
        as keyword arguments.
        """
        skipped_keys = ("prev_output_tokens", "task_name")
        encoder_kwargs = {name: value for name, value in net_input.items() if name not in skipped_keys}
        return self.forward_encoder(**encoder_kwargs)

    def forward_encoder(self, source, padding_mask=None):
        """Encode speech: speech encoder prenet (with ``mask=False``), then the shared encoder.

        Returns whatever ``self.encoder`` returns for the prenet output.
        """
        prenet_out, prenet_padding_mask = self.speech_encoder_prenet(
            source, padding_mask=padding_mask, mask=False
        )
        return self.encoder(prenet_out, prenet_padding_mask)

    def forward_text_encoder(self, src_tokens):
        """Encode text: text encoder prenet, then the shared encoder.

        Returns whatever ``self.encoder`` returns for the prenet output.
        """
        embedded, embedded_padding_mask = self.text_encoder_prenet(src_tokens)
        return self.encoder(embedded, embedded_padding_mask)

    def forward_decoder(self, tokens, encoder_out, incremental_state):
        """Run one text-decoding pass: prenet, decoder, text postnet.

        The prenet receives and returns ``incremental_state``; the returned
        state is what the decoder sees.

        Returns:
            ``(postnet_output, extra)`` where ``extra`` is the decoder's second output.
        """
        decoder_in, target_mask, incremental_state = self.text_decoder_prenet(tokens, incremental_state)

        hidden, extra = self.decoder(
            decoder_in,
            target_mask,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
        )

        postnet_output = self.text_decoder_postnet(hidden)
        return postnet_output, extra

    def set_num_updates(self, num_updates):
        """Record the current count of parameter updates.

        The value is first passed to the parent class implementation and is
        then stored on this model as ``self.num_updates``.
        """
        super().set_num_updates(num_updates)
        # keep a local copy for code in this class that reads self.num_updates
        self.num_updates = num_updates

    def generate_class(self, source, prev_output_tokens, **kwargs):
        """Predict one class index per utterance in ``source``.

        The speech is encoded with :meth:`forward_encoder`; the decoder input is
        a zero tensor shaped like the embedded ``prev_output_tokens`` (acting
        as [CLS]); the decoder output is averaged over dim 1 (presumably time
        -- confirm) and scored by ``speaker_decoder_postnet``.

        Args:
            source: speech input for :meth:`forward_encoder`.
            prev_output_tokens: tokens run through ``text_decoder_prenet`` only
                to obtain the decoder input shape and mask; the embedded
                values are replaced by zeros.
            **kwargs: optional ``padding_mask`` forwarded to
                :meth:`forward_encoder` (defaults to None); other keys are
                ignored.

        Returns:
            Tensor of predicted indices: argmax over dim 1 of the postnet scores.
        """
        # padding_mask is optional, matching forward_encoder's own default
        encoder_out = self.forward_encoder(source, padding_mask=kwargs.get("padding_mask"))

        prev_output_tokens, tgt_mask, _ = self.text_decoder_prenet(prev_output_tokens, {})
        prev_output_tokens = torch.zeros_like(prev_output_tokens)  # s2c use zero vector as [CLS]

        decoder_output, _ = self.decoder(
            prev_output_tokens,
            tgt_mask,
            encoder_out=encoder_out,
        )

        # mean pooling, as in the "decoder" sid_pooling_layer branch of forward
        class_scores, _ = self.speaker_decoder_postnet(decoder_output.mean(1))
        return class_scores.argmax(1)

    def generate_speech(self, source=None, src_tokens=None, spkembs=None, **kwargs):
        """Autoregressively generate acoustic frames for a single utterance.

        ``source`` (speech, via :meth:`forward_encoder`) is encoded when given;
        otherwise ``src_tokens`` (text, via :meth:`forward_text_encoder`). The
        batch size must be 1.

        Keyword Args:
            threshold (float): stop once any stop probability of a step is
                >= this value (default 0.5).
            minlenratio (float): minimum number of decoder steps, as a ratio of
                the encoder length divided by ``reduction_factor`` (default 0.0).
            maxlenratio (float): maximum number of decoder steps on the same
                scale (default 20.0 for text input, 10.0 for speech input).
            padding_mask: optional padding mask for speech input (default None).

        Returns:
            ``(outs, probs, attn)``: generated frames (L, odim) after the
            optional postnet, the concatenated per-step stop probabilities, and
            the per-step decoder attention stacks concatenated along dim 2.
        """
        assert source is not None or src_tokens is not None

        threshold = kwargs.get("threshold", 0.5)
        # Each length ratio has its own key, so passing ``threshold`` does not
        # also change the minimum/maximum output length.
        minlenratio = kwargs.get("minlenratio", 0.0)

        if source is None:
            assert src_tokens.size(0) == 1
            encoder_out = self.forward_text_encoder(src_tokens)
            maxlenratio = kwargs.get("maxlenratio", 20.0)
        else:
            assert source.size(0) == 1
            encoder_out = self.forward_encoder(source, padding_mask=kwargs.get("padding_mask"))
            maxlenratio = kwargs.get("maxlenratio", 10.0)

        if spkembs is not None and self.spk_embed_integration_type != "pre":
            encoder_out["encoder_out"] = [self._integrate_with_spk_embed(
                encoder_out["encoder_out"][0].transpose(0, 1), spkembs
            ).transpose(0, 1)]
            spkembs = None

        # encoder_out["encoder_out"][0] is T x B x C, so size(0) is the encoder length
        enc_len = encoder_out["encoder_out"][0].size(0)
        maxlen = int(enc_len * maxlenratio / self.reduction_factor)
        minlen = int(enc_len * minlenratio / self.reduction_factor)

        idx = 0
        ys = encoder_out["encoder_out"][0].new_zeros(1, 1, self.speech_decoder_postnet.odim)
        outs, probs = [], []

        # forward decoder step-by-step
        if isinstance(self.decoder, FairseqIncrementalDecoder):
            incremental_states = {}
        else:
            incremental_states = None
        attns = []
        while True:
            # update index
            idx += 1
            # calculate output and stop prob at idx-th step
            decoder_in, _ = self.speech_decoder_prenet(ys, spkembs=spkembs)
            z, extra = self.decoder(decoder_in[:,-1:], None, encoder_out, incremental_states, alignment_layer=-1)
            outs += [self.speech_decoder_postnet.feat_out(z[0, -1]).view(self.reduction_factor, self.speech_decoder_postnet.odim)]  # [(r, odim), ...]
            probs += [torch.sigmoid(self.speech_decoder_postnet.prob_out(z[0, -1]))]  # [(r), ...]

            # update next inputs
            ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.speech_decoder_postnet.odim)), dim=1)  # (1, idx + 1, odim)
            attns.append(torch.stack([att_l[0] for att_l in extra['attn'][0]], dim=0))
            # check whether to finish generation
            if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
                # check mininum length
                if idx < minlen:
                    continue
                outs = (torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2))  # (L, odim) -> (1, L, odim) -> (1, odim, L)
                if self.speech_decoder_postnet.postnet is not None:
                    outs = outs + self.speech_decoder_postnet.postnet(outs)  # (1, odim, L)
                outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
                probs = torch.cat(probs, dim=0)
                attn = torch.cat(attns, dim=2)
                break

        # outs holds idx * reduction_factor frames, so compare decoder steps,
        # not frames, against the step limit.
        if idx >= maxlen:
            logger.warning("output length reaches maximum length")
        return outs, probs, attn


@register_model_architecture(model_name="artst_transformer", arch_name="artst_transformer")
def base_architecture(args):
    """Fill in default hyper-parameters for the ``artst_transformer`` architecture.

    Every attribute is set with ``getattr(args, name, default)``, so values
    already present on ``args`` are never overwritten. Some defaults are
    derived from earlier ones (e.g. decoder dimensions from encoder
    dimensions, dropout variants from ``dropout``), so assignment order matters.
    """
    # Transformer
    args.bert_init = getattr(args, "bert_init", False)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 768 * 4)
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
    args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    # decoder_layerdrop was previously assigned twice; only this first
    # assignment ever took effect, so the redundant one is dropped.
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
    args.max_text_positions = getattr(args, "max_text_positions", DEFAULT_MAX_TEXT_POSITIONS)
    args.max_speech_positions = getattr(args, "max_speech_positions", DEFAULT_MAX_SPEECH_POSITIONS)

    # Espnet related, including prenet, postnet
    args.eprenet_conv_layers = getattr(args, "eprenet_conv_layers", 0)
    args.eprenet_conv_filts = getattr(args, "eprenet_conv_filts", 0)
    args.eprenet_conv_chans = getattr(args, "eprenet_conv_chans", 0)
    args.use_batch_norm = getattr(args, "use_batch_norm", True)
    args.eprenet_dropout_rate = getattr(args, "eprenet_dropout_rate", 0.0)
    args.enc_use_scaled_pos_enc = getattr(args, "enc_use_scaled_pos_enc", True)
    args.dec_use_scaled_pos_enc = getattr(args, "dec_use_scaled_pos_enc", True)
    args.postnet_layers = getattr(args, "postnet_layers", 5)
    args.postnet_chans = getattr(args, "postnet_chans", 256)
    args.postnet_filts = getattr(args, "postnet_filts", 5)
    args.postnet_dropout_rate = getattr(args, "postnet_dropout_rate", 0.5)
    args.dprenet_dropout_rate = getattr(args, "dprenet_dropout_rate", 0.5)
    args.dprenet_layers = getattr(args, "dprenet_layers", 2)
    args.dprenet_units = getattr(args, "dprenet_units", 256)
    args.initial_encoder_alpha = getattr(args, "initial_encoder_alpha", 1.0)
    args.initial_decoder_alpha = getattr(args, "initial_decoder_alpha", 1.0)
    args.spk_embed_integration_type = getattr(args, "spk_embed_integration_type", "pre")
    args.spk_embed_dim = getattr(args, "spk_embed_dim", 512)
    args.encoder_reduction_factor = getattr(args, "encoder_reduction_factor", 1)
    args.reduction_factor = getattr(args, "reduction_factor", 2)
    args.transformer_enc_positional_dropout_rate = getattr(args, "transformer_enc_positional_dropout_rate", 0.1)
    args.transformer_dec_positional_dropout_rate = getattr(args, "transformer_dec_positional_dropout_rate", 0.1)
    args.layer_norm_eps = getattr(args, "layer_norm_eps", 1e-5)
    args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
    # Convolutional subsampler
    args.encoder_speech_prenet = getattr(args, "encoder_speech_prenet", "conv")
    args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
    args.conv_channels = getattr(args, "conv_channels", 1024)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)

    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.share_input_output_embed = getattr(args, "share_input_output_embed", False)
    args.share_ctc_embed = getattr(args, "share_ctc_embed", False)
    args.freeze_encoder_updates = getattr(args, "freeze_encoder_updates", 0)
    args.freeze_decoder_updates = getattr(args, "freeze_decoder_updates", 0)
    args.no_freeze_encoder_layer = getattr(args, "no_freeze_encoder_layer", None)

    ## sid
    args.sid_embed_dim = getattr(args, "sid_embed_dim", 128)
    args.sid_pooling_layer = getattr(args, "sid_pooling_layer", "decoder")
    args.softmax_scale = getattr(args, "softmax_scale", 1)
    args.softmax_margin = getattr(args, "softmax_margin", 0)
    args.softmax_easy_margin = getattr(args, "softmax_easy_margin", False)
    args.modules_filter = getattr(args, "modules_filter", None)

    ## Hubert
    args.conv_pos = getattr(args, "conv_pos", 128)
    args.conv_pos_groups = getattr(args, "conv_pos_groups", 16)
    args.target_glu = getattr(args, "target_glu", False)
    args.logit_temp = getattr(args, "logit_temp", 0.1)
    args.final_dim = getattr(args, "final_dim", 256)
    args.untie_final_proj = getattr(args, "untie_final_proj", True)
    args.feature_grad_mult = getattr(args, "feature_grad_mult", 0.1)
    args.use_sent_enc_layer = getattr(args, "use_sent_enc_layer", True)
    # hubert feature extractor
    args.extractor_mode = getattr(args, "extractor_mode", "default")
    args.conv_feature_layers = getattr(args, "conv_feature_layers", "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2")
    args.conv_bias = getattr(args, "conv_bias", False)
    # mask
    args.hubert_mask_length = getattr(args, "hubert_mask_length", 10)
    args.mask_prob = getattr(args, "mask_prob", 0.0)
    args.mask_selection = getattr(args, "mask_selection", "static")
    args.mask_other = getattr(args, "mask_other", 0)
    args.no_mask_overlap = getattr(args, "no_mask_overlap", False)
    args.mask_min_space = getattr(args, "mask_min_space", 1)
    # channel mask
    args.mask_channel_length = getattr(args, "mask_channel_length", 10)
    args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.0)
    args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
    args.mask_channel_other = getattr(args, "mask_channel_other", 0)
    args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False)
    args.mask_channel_min_space = getattr(args, "mask_channel_min_space", 1)
    # loss computation
    args.skip_masked = getattr(args, "skip_masked", False)
    args.skip_nomask = getattr(args, "skip_nomask", False)
    # conv Pos
    args.use_conv_pos = getattr(args, "use_conv_pos", False)
    args.use_sinc_pos = getattr(args, "use_sinc_pos", False)

    # codebook
    args.use_codebook = getattr(args, "use_codebook", False)
    args.latent_vars = getattr(args, "latent_vars", 100)
    args.latent_groups = getattr(args, "latent_groups", 2)
    args.latent_dim = getattr(args, "latent_dim", 0)
    args.latent_temp = getattr(args, "latent_temp", (2, 0.5, 0.999995))
    args.quantizer_depth = getattr(args, "quantizer_depth", 1)
    args.quantizer_factor = getattr(args, "quantizer_factor", 3)
    args.codebook_prob = getattr(args, "codebook_prob", 0.5)

    # Relative pos embed
    args.relative_position_embedding = getattr(args, "relative_position_embedding", False)
    args.num_buckets = getattr(args, "num_buckets", 320)
    args.max_distance = getattr(args, "max_distance", 1280)
    args.encoder_max_relative_position = getattr(args, "encoder_max_relative_position", 160)
    args.decoder_max_relative_position = getattr(args, "decoder_max_relative_position", 160)

@register_model_architecture("artst_transformer", "artst_transformer_base")
def artst_transformer_base(args):
    """Base ArTST configuration layered on top of :func:`base_architecture`.

    Each (name, default) pair below is applied only when ``args`` does not
    already define that attribute; the remaining defaults then come from
    :func:`base_architecture`.
    """
    defaults = (
        ("use_conv_pos", True),
        ("use_sinc_pos", True),
        ("layernorm_embedding", False),
        ("encoder_normalize_before", False),
        ("decoder_normalize_before", False),
        ("layer_norm_first", False),
        ("relative_position_embedding", True),
        ("dropout", 0.1),
        ("activation_dropout", 0.0),
        ("attention_dropout", 0.1),
        ("encoder_layerdrop", 0.05),
        ("decoder_layerdrop", 0.05),
        ("mask_prob", 0.80),
    )
    for name, default in defaults:
        setattr(args, name, getattr(args, name, default))
    base_architecture(args)

@register_model_architecture("artst_transformer", "artst_transformer_large")
def artst_transformer_large(args):
    """Large ArTST configuration (24 encoder layers, 1024-dim, no dropout).

    Each (name, default) pair below is applied only when ``args`` does not
    already define that attribute; everything else is filled in by
    :func:`base_architecture`.
    """
    defaults = (
        ("use_conv_pos", True),
        ("use_sinc_pos", True),
        ("encoder_normalize_before", False),
        ("decoder_normalize_before", True),
        ("layer_norm_first", True),
        ("relative_position_embedding", True),
        ("dropout", 0.0),
        ("activation_dropout", 0.0),
        ("attention_dropout", 0.0),
        ("encoder_layerdrop", 0.0),
        ("decoder_layerdrop", 0.0),
        ("encoder_embed_dim", 1024),
        ("encoder_layers", 24),
        ("decoder_layers", 6),
        ("encoder_ffn_embed_dim", 4096),
        ("encoder_attention_heads", 16),
        ("decoder_attention_heads", 16),
        ("feature_grad_mult", 1.0),
        ("extractor_mode", "layer_norm"),
        ("final_dim", 768),
        ("mask_prob", 0.80),
    )
    for name, default in defaults:
        setattr(args, name, getattr(args, name, default))
    base_architecture(args)

@register_model_architecture("artst_transformer", "artst_transformer_base_asr")
def artst_transformer_base_asr(args):
    """Base ArTST configuration tuned for ASR fine-tuning.

    Each (name, default) pair below is applied only when ``args`` does not
    already define that attribute; everything else is filled in by
    :func:`base_architecture`.
    """
    defaults = (
        ("use_conv_pos", True),
        ("use_sinc_pos", True),
        ("encoder_normalize_before", False),
        ("decoder_normalize_before", False),
        ("layer_norm_first", False),
        ("relative_position_embedding", True),
        ("dropout", 0.1),
        ("activation_dropout", 0.1),
        ("attention_dropout", 0.1),
        ("feature_grad_mult", 0.0),
        ("encoder_layerdrop", 0.1),
        ("decoder_layerdrop", 0.1),
        ("mask_prob", 0.75),
        ("mask_selection", "static"),
        ("mask_channel_length", 64),
        ("mask_channel_prob", 0.5),
        ("mask_channel_selection", "static"),
        ("max_text_positions", 600),
    )
    for name, default in defaults:
        setattr(args, name, getattr(args, name, default))
    base_architecture(args)