diff --git "a/modelling_afrolid.py" "b/modelling_afrolid.py"
new file mode 100644
--- /dev/null
+++ "b/modelling_afrolid.py"
@@ -0,0 +1,2890 @@
+import math
+import sys
+import uuid
+from dataclasses import dataclass, field, asdict
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.onnx.operators
+from torch import Tensor
+
+from transformers.modeling_utils import PreTrainedModel
+from transformers import AutoConfig, AutoModel, AutoModelForSequenceClassification
+from .configuration_afrolid import AfroLidConfig
+
+
+def quant_noise(module, p, block_size):
+    """
+    Wraps modules and applies quantization noise to the weights for
+    subsequent quantization with Iterative Product Quantization as
+    described in "Training with Quantization Noise for Extreme Model Compression"
+
+    Args:
+        - module: nn.Module
+        - p: amount of Quantization Noise
+        - block_size: size of the blocks for subsequent quantization with iPQ
+
+    Remarks:
+        - Module weights must have the right sizes wrt the block size
+        - Only Linear, Embedding and Conv2d modules are supported for the moment
+        - For more detail on how to quantize by blocks with convolutional weights,
+          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
+        - We implement the simplest form of noise here as stated in the paper
+          which consists in randomly dropping blocks
+    """
+
+    # if no quantization noise, don't register hook
+    if p <= 0:
+        return module
+
+    # supported modules
+    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
+
+    # test whether module.weight has the right sizes wrt block_size
+    is_conv = module.weight.ndim == 4
+
+    # 2D matrix
+    if not is_conv:
+        assert (
+            module.weight.size(1) % block_size == 0
+        ), "Input features must be a multiple of block sizes"
+
+    # 4D matrix
+    else:
+        # 1x1 convolutions
+        if module.kernel_size == (1, 1):
+            assert (
+                module.in_channels % block_size == 0
+            ), "Input channels must be a multiple of block sizes"
+        # regular convolutions
+        else:
+            k = module.kernel_size[0] * module.kernel_size[1]
+            assert k % block_size == 0, "Kernel size must be a multiple of block size"
+
+    def _forward_pre_hook(mod, input):
+        # no noise for evaluation
+        if mod.training:
+            if not is_conv:
+                # gather weight and sizes
+                weight = mod.weight
+                in_features = weight.size(1)
+                out_features = weight.size(0)
+
+                # split weight matrix into blocks and randomly drop selected blocks
+                mask = torch.zeros(
+                    in_features // block_size * out_features, device=weight.device
+                )
+                mask.bernoulli_(p)
+                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
+
+            else:
+                # gather weight and sizes
+                weight = mod.weight
+                in_channels = mod.in_channels
+                out_channels = mod.out_channels
+
+                # split weight matrix into blocks and randomly drop selected blocks
+                if mod.kernel_size == (1, 1):
+                    mask = torch.zeros(
+                        int(in_channels // block_size * out_channels),
+                        device=weight.device,
+                    )
+                    mask.bernoulli_(p)
+                    mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
+                else:
+                    mask = torch.zeros(
+                        weight.size(0), weight.size(1), device=weight.device
+                    )
+                    mask.bernoulli_(p)
+                    mask = (
+                        mask.unsqueeze(2)
+                        .unsqueeze(3)
+                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
+                    )
+
+            # scale weights and apply mask
+            mask = mask.to(
+                torch.bool
+            )  # x.bool() is not currently supported in TorchScript
+            s = 1 / (1 - 
p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + +def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): + # if torch.jit.is_scripting() or torch.jit.is_tracing(): + # export = True + # if not export and torch.cuda.is_available() and has_fused_layernorm: + # return FusedLayerNorm(normalized_shape, eps, elementwise_affine) + return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) + +class LayerDropModuleList(nn.ModuleList): + """ + A LayerDrop implementation based on :class:`torch.nn.ModuleList`. + + We refresh the choice of which layers to drop every time we iterate + over the LayerDropModuleList instance. During evaluation we always + iterate over all layers. + + Usage:: + + layers = LayerDropList(p=0.5, modules=[layer1, layer2, layer3]) + for layer in layers: # this might iterate over layers 1 and 3 + x = layer(x) + for layer in layers: # this might iterate over all layers + x = layer(x) + for layer in layers: # this might not iterate over any layers + x = layer(x) + + Args: + p (float): probability of dropping out each layer + modules (iterable, optional): an iterable of modules to add + """ + + def __init__(self, p, modules=None): + super().__init__(modules) + self.p = p + + def __iter__(self): + dropout_probs = torch.empty(len(self)).uniform_() + for i, m in enumerate(super().__iter__()): + if not self.training or (dropout_probs[i] > self.p): + yield m + +from typing import List, Callable +from typing import Dict +import warnings + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + +def deprecation_warning(message, stacklevel=3): + # don't use DeprecationWarning, since it's ignored by default + warnings.warn(message, stacklevel=stacklevel) + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + +def relu_squared(x: torch.Tensor): + return F.relu(x).pow(2) + +def get_activation_fn(activation: str) -> Callable: + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "relu_squared": + return relu_squared + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + deprecation_warning( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "swish": + return torch.nn.SiLU + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +class FairseqDropout(nn.Module): + def __init__(self, p, module_name=None): + super().__init__() + self.p = p + self.module_name = module_name + self.apply_during_inference = False + + def forward(self, x, inplace: bool = False): + if self.p > 0 and (self.training or self.apply_during_inference): + return F.dropout(x, p=self.p, training=True, inplace=inplace) + else: + return x + + +class TransformerEncoderLayerBase(nn.Module): + + """Encoder layer block. + + In the original paper each operation (multi-head attention or FFN) is + postprocessed with: `dropout -> add residual -> layernorm`. 
In the + tensor2tensor code they suggest that learning is more robust when + preprocessing each layer with layernorm and postprocessing with: + `dropout -> add residual`. We default to the approach in the paper, but the + tensor2tensor approach can be enabled by setting + *cfg.encoder.normalize_before* to ``True``. + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + + def __init__(self, cfg, return_fc=False): + super().__init__() + self.cfg = cfg + self.return_fc = return_fc + self.embed_dim = cfg.encoder.embed_dim + self.quant_noise = cfg.quant_noise.pq + self.quant_noise_block_size = cfg.quant_noise.pq_block_size + self.self_attn = self.build_self_attention(self.embed_dim, cfg) + self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) + self.dropout_module = FairseqDropout( + cfg.dropout, module_name=self.__class__.__name__ + ) + self.activation_fn = get_activation_fn(activation=cfg.activation_fn) + activation_dropout_p = cfg.activation_dropout + if activation_dropout_p == 0: + # for backwards compatibility with models that use cfg.relu_dropout + activation_dropout_p = cfg.relu_dropout or 0 + self.activation_dropout_module = FairseqDropout( + float(activation_dropout_p), module_name=self.__class__.__name__ + ) + self.normalize_before = cfg.encoder.normalize_before + self.fc1 = self.build_fc1( + self.embed_dim, + cfg.encoder.ffn_embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + self.fc2 = self.build_fc2( + cfg.encoder.ffn_embed_dim, + self.embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + + self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) + + self.num_heads = cfg.encoder.attention_heads + self.load_to_BT = False + self.ever_training = False + # For BT, we need continuous mem + self.in_proj_weight = torch.nn.Parameter( + torch.zeros( + self.self_attn.q_proj.weight.shape[0] * 3, + self.self_attn.q_proj.weight.shape[1], + ) + ) + self.in_proj_bias = torch.nn.Parameter( + torch.zeros(self.self_attn.q_proj.bias.shape[0] * 3) + ) + self.out_proj_weight = torch.nn.Parameter( + torch.zeros(self.self_attn.out_proj.weight.shape) + ) + self.out_proj_bias = torch.nn.Parameter( + torch.zeros(self.self_attn.out_proj.bias.shape) + ) + self.fc1_weight = torch.nn.Parameter(torch.zeros(self.fc1.weight.shape)) + self.fc1_bias = torch.nn.Parameter(torch.zeros(self.fc1.bias.shape)) + self.fc2_weight = torch.nn.Parameter(torch.zeros(self.fc2.weight.shape)) + self.fc2_bias = torch.nn.Parameter(torch.zeros(self.fc2.bias.shape)) + + if ( + self.activation_fn is torch.nn.functional.relu + or isinstance(self.activation_fn, torch.nn.ReLU) + or self.activation_fn == "relu" + ): + self.activation_relu_or_gelu = 1 + elif ( + self.activation_fn is torch.nn.functional.gelu + or isinstance(self.activation_fn, torch.nn.GELU) + or self.activation_fn == "gelu" + ): + self.activation_relu_or_gelu = 2 + else: + self.activation_relu_or_gelu = 0 + # Batch first can not be justified but needs user to make sure + self.can_use_fastpath = None + self.cfg_checkpoint_activations = self.cfg.checkpoint_activations + # torch version check + # make sure BT version is >=1.12.0 + self.BT_version = False + if "fb" in torch.__version__: + self.BT_version = True + else: + if "+" in torch.__version__: + self.torch_version = torch.__version__.split("+")[0] + else: + self.torch_version = torch.__version__ + + self.torch_version = self.torch_version.split(".") + self.int_version = ( + int(self.torch_version[0]) * 1000 + + int(self.torch_version[1]) * 
10 + + int(self.torch_version[2]) + ) + if len(self.torch_version) == 3: + if self.int_version >= 1120: + self.BT_version = True + elif len(self.torch_version) == 4: + if self.int_version >= 1130: + self.BT_version = True + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.load_to_BT = True + + old_name = prefix + "self_attn." + q_proj_weight = state_dict[old_name + "q_proj.weight"] + k_proj_weight = state_dict[old_name + "k_proj.weight"] + v_proj_weight = state_dict[old_name + "v_proj.weight"] + q_proj_bias = state_dict[old_name + "q_proj.bias"] + k_proj_bias = state_dict[old_name + "k_proj.bias"] + v_proj_bias = state_dict[old_name + "v_proj.bias"] + + new_name = prefix + state_dict[new_name + "in_proj_weight"] = torch.cat( + (q_proj_weight, k_proj_weight, v_proj_weight), dim=0 + ) + state_dict[new_name + "in_proj_bias"] = torch.cat( + (q_proj_bias, k_proj_bias, v_proj_bias), dim=0 + ) + state_dict[new_name + "out_proj_weight"] = state_dict[ + old_name + "out_proj.weight" + ] + state_dict[new_name + "out_proj_bias"] = state_dict[old_name + "out_proj.bias"] + state_dict[new_name + "fc1_weight"] = state_dict[prefix + "fc1.weight"] + state_dict[new_name + "fc1_bias"] = state_dict[prefix + "fc1.bias"] + state_dict[new_name + "fc2_weight"] = state_dict[prefix + "fc2.weight"] + state_dict[new_name + "fc2_bias"] = state_dict[prefix + "fc2.bias"] + super(TransformerEncoderLayerBase, self)._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise( + nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size + ) + + def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise( + nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size + ) + + def _get_fc_rank(self, remove_num: int) -> List[int]: + f1_filter_param = [] + for i in range(self.fc1.out_features): + f1_filter_param.append( + torch.sum(torch.abs(self.fc1.weight[i])) + + torch.sum(torch.abs(self.fc2.weight[:, i])) + + torch.abs(self.fc1.bias[i]) + ) + return sorted( + range(len(f1_filter_param)), key=lambda k: f1_filter_param[k], reverse=False + )[0:remove_num] + + def _prune_fc_layer(self, remove_index: List[int]): + new_fc1_weight = [] + new_fc1_bias = [] + for i in range(self.fc1.out_features): + if i not in remove_index: + new_fc1_weight.append(self.fc1.weight[i]) + new_fc1_bias.append(self.fc1.bias[i]) + + new_fc1_weight = torch.stack(new_fc1_weight).detach() + new_fc1_weight.requires_grad = True + + new_fc1_bias = torch.stack(new_fc1_bias).detach() + new_fc1_bias.requires_grad = True + + self.fc1 = quant_noise( + nn.Linear(self.fc1.in_features, self.fc1.out_features - len(remove_index)), + p=self.quant_noise, + block_size=self.quant_noise_block_size, + ) + self.fc1.weight = torch.nn.Parameter(new_fc1_weight) + self.fc1.bias = torch.nn.Parameter(new_fc1_bias) + + new_fc2_weight = [] + new_fc2_bias = [] + for i in range(self.fc2.in_features): + if i not in remove_index: + new_fc2_weight.append(self.fc2.weight[:, i]) + new_fc2_bias = self.fc2.bias.detach() + + new_fc2_weight = torch.stack(new_fc2_weight, dim=-1).detach() + new_fc2_weight.requires_grad = True + + new_fc2_bias = self.fc2.bias.detach() + new_fc2_bias.requires_grad = True + + self.fc2 = quant_noise( + nn.Linear(self.fc2.in_features - len(remove_index), 
self.fc2.out_features), + p=self.quant_noise, + block_size=self.quant_noise_block_size, + ) + self.fc2.weight = torch.nn.Parameter(new_fc2_weight) + self.fc2.bias = torch.nn.Parameter(new_fc2_bias) + + def build_self_attention(self, embed_dim, cfg): + return MultiheadAttention( + embed_dim, + cfg.encoder.attention_heads, + dropout=cfg.attention_dropout, + self_attention=True, + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + xformers_att_config=cfg.encoder.xformers_att_config, + ) + + def residual_connection(self, x, residual): + return residual + x + + def upgrade_state_dict_named(self, state_dict, name): + """ + Rename layer norm states from `...layer_norms.0.weight` to + `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to + `...final_layer_norm.weight` + """ + layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"} + for old, new in layer_norm_map.items(): + for m in ("weight", "bias"): + k = "{}.layer_norms.{}.{}".format(name, old, m) + if k in state_dict: + state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k] + del state_dict[k] + + def forward( + self, + x, + encoder_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor] = None, + ): + """ + Args: + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (ByteTensor): binary ByteTensor of shape + `(batch, seq_len)` where padding elements are indicated by ``1``. + attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`, + where `tgt_len` is the length of output and `src_len` is the + length of input, though here both are equal to `seq_len`. + `attn_mask[tgt_i, src_j] = 1` means that when calculating the + embedding for `tgt_i`, we exclude (mask out) `src_j`. This is + useful for strided self-attention. 
+ + Returns: + encoded output of shape `(seq_len, batch, embed_dim)` + """ + # anything in original attn_mask = 1, becomes -1e8 + # anything in original attn_mask = 0, becomes 0 + # Note that we cannot use -inf here, because at some edge cases, + # the attention weight (before softmax) for some padded element in query + # will become -inf, which results in NaN in model parameters + + if self.training: + self.ever_training = True + + if ( + self.BT_version + and x.dim() == 3 + and self.load_to_BT + and not self.return_fc + and self.can_use_fastpath + and not self.training + and not self.ever_training + and not self.cfg_checkpoint_activations + ): + # assume is Batch first and nested tensor + output = torch._transformer_encoder_layer_fwd( + x, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.activation_relu_or_gelu == 2, + False, # norm_first, currently not supported + self.self_attn_layer_norm.eps, + self.self_attn_layer_norm.weight, + self.self_attn_layer_norm.bias, + self.final_layer_norm.weight, + self.final_layer_norm.bias, + self.fc1_weight, + self.fc1_bias, + self.fc2_weight, + self.fc2_bias, + encoder_padding_mask if encoder_padding_mask is not None else attn_mask, + ) + return output + + else: + if attn_mask is not None: + attn_mask = attn_mask.masked_fill( + attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4 + ) + + residual = x + if self.normalize_before: + x = self.self_attn_layer_norm(x) + x, _ = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=encoder_padding_mask, + need_weights=False, + attn_mask=attn_mask, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + x = self.fc2(x) + + fc_result = x + + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.final_layer_norm(x) + + if self.return_fc and not torch.jit.is_scripting(): + return x, fc_result + return x + +def safe_getattr(obj, k, default=None): + """Returns obj[k] if it exists and is not None, otherwise returns default.""" + from omegaconf import OmegaConf + + if OmegaConf.is_config(obj): + return obj[k] if k in obj and obj[k] is not None else default + + return getattr(obj, k, default) + +class TransformerDecoderLayerBase(nn.Module): + """Decoder layer block. + + In the original paper each operation (multi-head attention, encoder + attention or FFN) is postprocessed with: `dropout -> add residual -> + layernorm`. In the tensor2tensor code they suggest that learning is more + robust when preprocessing each layer with layernorm and postprocessing with: + `dropout -> add residual`. We default to the approach in the paper, but the + tensor2tensor approach can be enabled by setting + *cfg.decoder.normalize_before* to ``True``. + + Args: + args (argparse.Namespace): parsed command-line arguments + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). 
+ """ + + def __init__( + self, cfg, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False + ): #embed_dim, num_heads, ff_dim, dropout + super().__init__() + self.embed_dim = cfg.decoder.embed_dim + self.dropout_module = FairseqDropout( + cfg.dropout, module_name=self.__class__.__name__ + ) + self.quant_noise = cfg.quant_noise.pq + self.quant_noise_block_size = cfg.quant_noise.pq_block_size + + self.cross_self_attention = cfg.cross_self_attention + + self.self_attn = self.build_self_attention( + self.embed_dim, + cfg, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + self.attn_ln = ( + LayerNorm(self.embed_dim) + if safe_getattr(cfg, "scale_attn", False) + else None + ) + self.nh = self.self_attn.num_heads + self.head_dim = self.self_attn.head_dim + scale_heads = safe_getattr(cfg, "scale_heads", False) + self.c_attn = ( + nn.Parameter(torch.ones((self.nh,)), requires_grad=True) + if scale_heads + else None + ) + + self.activation_fn = get_activation_fn(activation=cfg.activation_fn) + activation_dropout_p = cfg.activation_dropout + if activation_dropout_p == 0: + # for backwards compatibility with models that use cfg.relu_dropout + activation_dropout_p = cfg.relu_dropout or 0 + self.activation_dropout_module = FairseqDropout( + float(activation_dropout_p), module_name=self.__class__.__name__ + ) + self.normalize_before = cfg.decoder.normalize_before + + self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) + + if no_encoder_attn: + self.encoder_attn = None + self.encoder_attn_layer_norm = None + else: + self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg) + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) + + self.ffn_layernorm = ( + LayerNorm(cfg.decoder.ffn_embed_dim) + if safe_getattr(cfg, "scale_fc", False) + else None + ) + self.w_resid = ( + nn.Parameter( + torch.ones( + self.embed_dim, + ), + requires_grad=True, + ) + if safe_getattr(cfg, "scale_resids", False) + else None + ) + + self.fc1 = self.build_fc1( + self.embed_dim, + cfg.decoder.ffn_embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + self.fc2 = self.build_fc2( + cfg.decoder.ffn_embed_dim, + self.embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + + self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) + self.need_attn = True + + self.onnx_trace = False + + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) + + def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) + + def build_self_attention( + self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False + ): + return MultiheadAttention( + embed_dim, + cfg.decoder.attention_heads, + dropout=cfg.attention_dropout, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + self_attention=not cfg.cross_self_attention, + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + xformers_att_config=cfg.decoder.xformers_att_config, + ) + + def build_encoder_attention(self, embed_dim, cfg): + return MultiheadAttention( + embed_dim, + cfg.decoder.attention_heads, + kdim=cfg.encoder.embed_dim, + vdim=cfg.encoder.embed_dim, + dropout=cfg.attention_dropout, + encoder_decoder_attention=True, + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + xformers_att_config=cfg.encoder.xformers_att_config, + ) + + def 
prepare_for_onnx_export_(self): + self.onnx_trace = True + + def residual_connection(self, x, residual): + return residual + x + + def forward( + self, + x, + encoder_out: Optional[torch.Tensor] = None, + encoder_padding_mask: Optional[torch.Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + prev_self_attn_state: Optional[List[torch.Tensor]] = None, + prev_attn_state: Optional[List[torch.Tensor]] = None, + self_attn_mask: Optional[torch.Tensor] = None, + self_attn_padding_mask: Optional[torch.Tensor] = None, + need_attn: bool = False, + need_head_weights: bool = False, + ): + """ + Args: + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (ByteTensor, optional): binary + ByteTensor of shape `(batch, src_len)` where padding + elements are indicated by ``1``. + need_attn (bool, optional): return attention weights + need_head_weights (bool, optional): return attention weights + for each head (default: return average over heads). + + Returns: + encoded output of shape `(seq_len, batch, embed_dim)` + """ + if need_head_weights: + need_attn = True + + residual = x + if self.normalize_before: + x = self.self_attn_layer_norm(x) + if prev_self_attn_state is not None: + prev_key, prev_value = prev_self_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_self_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] + assert incremental_state is not None + self.self_attn._set_input_buffer(incremental_state, saved_state) + _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state) + if self.cross_self_attention and not ( + incremental_state is not None + and _self_attn_input_buffer is not None + and "prev_key" in _self_attn_input_buffer + ): + if self_attn_mask is not None: + assert encoder_out is not None + self_attn_mask = torch.cat( + (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1 + ) + if self_attn_padding_mask is not None: + if encoder_padding_mask is None: + assert encoder_out is not None + encoder_padding_mask = self_attn_padding_mask.new_zeros( + encoder_out.size(1), encoder_out.size(0) + ) + self_attn_padding_mask = torch.cat( + (encoder_padding_mask, self_attn_padding_mask), dim=1 + ) + assert encoder_out is not None + y = torch.cat((encoder_out, x), dim=0) + else: + y = x + + x, attn = self.self_attn( + query=x, + key=y, + value=y, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state, + need_weights=False, + attn_mask=self_attn_mask, + ) + if self.c_attn is not None: + tgt_len, bsz = x.size(0), x.size(1) + x = x.view(tgt_len, bsz, self.nh, self.head_dim) + x = torch.einsum("tbhd,h->tbhd", x, self.c_attn) + x = x.reshape(tgt_len, bsz, self.embed_dim) + if self.attn_ln is not None: + x = self.attn_ln(x) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + if self.encoder_attn is not None and encoder_out is not None: + residual = x + if self.normalize_before: + x = self.encoder_attn_layer_norm(x) + if prev_attn_state is not None: + prev_key, prev_value = prev_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_attn_state[2] + assert incremental_state is not None + 
self.encoder_attn._set_input_buffer(incremental_state, saved_state) + + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + need_weights=need_attn or (not self.training and self.need_attn), + need_head_weights=need_head_weights, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.encoder_attn_layer_norm(x) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + if self.ffn_layernorm is not None: + x = self.ffn_layernorm(x) + x = self.fc2(x) + x = self.dropout_module(x) + if self.w_resid is not None: + residual = torch.mul(self.w_resid, residual) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.final_layer_norm(x) + if self.onnx_trace and incremental_state is not None: + saved_state = self.self_attn._get_input_buffer(incremental_state) + assert saved_state is not None + if self_attn_padding_mask is not None: + self_attn_state = [ + saved_state["prev_key"], + saved_state["prev_value"], + saved_state["prev_key_padding_mask"], + ] + else: + self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]] + return x, attn, self_attn_state + return x, attn, None + + def make_generation_fast_(self, need_attn: bool = False, **kwargs): + self.need_attn = need_attn + + +import torch +import torch.nn as nn +import math +from typing import Optional, Dict, List, Any +from torch import Tensor + + +def make_positions(tensor, padding_idx: int, onnx_trace: bool = False): + """Replace non-padding symbols with their position numbers. + + Position numbers begin at padding_idx+1. Padding symbols are ignored. + """ + # The series of casts and type-conversions here are carefully + # balanced to both work with ONNX export and XLA. In particular XLA + # prefers ints, cumsum defaults to output longs, and ONNX doesn't know + # how to handle the dtype kwarg in cumsum. + mask = tensor.ne(padding_idx).int() + return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx + +class SinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length. + + Padding symbols are ignored. + """ + + def __init__(self, embedding_dim, padding_idx, init_size=1024): + super().__init__() + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx if padding_idx is not None else 0 + self.weights = SinusoidalPositionalEmbedding.get_embedding( + init_size, embedding_dim, padding_idx + ) + self.onnx_trace = False + self.register_buffer("_float_tensor", torch.FloatTensor(1)) + self.max_positions = int(1e5) + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + @staticmethod + def get_embedding( + num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None + ): + """Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
+ """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze( + 1 + ) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view( + num_embeddings, -1 + ) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + return emb + + def forward( + self, + input, + incremental_state: Optional[Any] = None, + timestep: Optional[Tensor] = None, + positions: Optional[Any] = None, + ): + """Input is expected to be of size [bsz x seqlen].""" + bspair = torch.onnx.operators.shape_as_tensor(input) + bsz, seq_len = bspair[0], bspair[1] + max_pos = self.padding_idx + 1 + seq_len + if self.weights is None or max_pos > self.weights.size(0): + # recompute/expand embeddings if needed + self.weights = SinusoidalPositionalEmbedding.get_embedding( + max_pos, self.embedding_dim, self.padding_idx + ) + self.weights = self.weights.to(self._float_tensor) + + if incremental_state is not None: + # positions is the same for every token when decoding a single step + pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len + if self.onnx_trace: + return ( + self.weights.index_select(index=self.padding_idx + pos, dim=0) + .unsqueeze(1) + .repeat(bsz, 1, 1) + ) + return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) + + positions = make_positions( + input, self.padding_idx, onnx_trace=self.onnx_trace + ) + if self.onnx_trace: + flat_embeddings = self.weights.detach().index_select(0, positions.view(-1)) + embedding_shape = torch.cat( + (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long)) + ) + embeddings = torch.onnx.operators.reshape_from_tensor_shape( + flat_embeddings, embedding_shape + ) + return embeddings + return ( + self.weights.index_select(0, positions.view(-1)) + .view(bsz, seq_len, -1) + .detach() + ) + + +class TransformerEncoderBase(nn.Module): + def __init__(self, cfg, dictionary, embed_tokens, return_fc=False): + super().__init__() + self.cfg = cfg + self.dictionary = dictionary + self.return_fc = return_fc + self.register_buffer('version', torch.Tensor([3])) + + self.dropout_module = FairseqDropout(cfg.dropout) + self.encoder_layerdrop = cfg.encoder.layerdrop + + embed_dim = embed_tokens.embedding_dim + self.padding_idx = embed_tokens.padding_idx + self.max_source_positions = cfg.max_source_positions + + self.embed_tokens = embed_tokens + self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim) + + self.embed_positions = ( + SinusoidalPositionalEmbedding( + embed_dim, self.padding_idx, cfg.max_source_positions + self.padding_idx + 1 + ) if not cfg.no_token_positional_embeddings else None + ) + + # self.layernorm_embedding = ( + # nn.LayerNorm(embed_dim) if cfg.layernorm_embedding else None + # ) + + if cfg.layernorm_embedding: + self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export) + else: + self.layernorm_embedding = None + + if not cfg.adaptive_input and cfg.quant_noise.pq > 0: + self.quant_noise = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=False), + cfg.quant_noise.pq, + cfg.quant_noise.pq_block_size, + ) + else: + self.quant_noise = None + + if self.encoder_layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.encoder_layerdrop) + else: + self.layers = nn.ModuleList([]) + + self.layers.extend( + [self.build_encoder_layer(cfg) for i in 
range(cfg.encoder.layers)] + ) + + self.num_layers = len(self.layers) + + if cfg.encoder.normalize_before: + self.layer_norm = LayerNorm(embed_dim, export=cfg.export) + else: + self.layer_norm = None + + def build_encoder_layer(self, cfg): + layer = TransformerEncoderLayerBase( + cfg, return_fc=self.return_fc + ) + checkpoint = cfg.checkpoint_activations + # if checkpoint: + # offload_to_cpu = cfg.offload_activations + # layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) + # if we are checkpointing, enforce that FSDP always wraps the + # checkpointed layer, regardless of layer size + min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0 + # layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) + return layer + + def forward_embedding( + self, src_tokens, token_embedding: Optional[torch.Tensor] = None): + # embed tokens and positions + if token_embedding is None: + token_embedding = self.embed_tokens(src_tokens) + x = embed = self.embed_scale * token_embedding + if self.embed_positions is not None: + x = embed + self.embed_positions(src_tokens) + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + x = self.dropout_module(x) + if self.quant_noise is not None: + x = self.quant_noise(x) + return x, embed + + def max_positions(self): + """Maximum input length supported by the encoder.""" + if self.embed_positions is None: + return self.max_source_positions + return min(self.max_source_positions, self.embed_positions.max_positions) + + def forward(self, src_tokens, src_lengths: Optional[torch.Tensor] = None, token_embeddings: Optional[torch.Tensor] = None, return_all_hiddens: bool = False): + encoder_padding_mask = src_tokens.eq(self.padding_idx) + # encoder_padding_mask = src_tokens.device.type == "xla" or encoder_padding_mask.any() + + has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any() + + x, encoder_embedding = self.forward_embedding(src_tokens) + + if has_pads: + x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) + + x = x.transpose(0, 1) # B x T x C -> T x B x C + + encoder_states = [] if return_all_hiddens else None + fc_results = [] + + if return_all_hiddens: + encoder_states.append(x) + + encoder_padding_mask = encoder_padding_mask if has_pads else None + for layer in self.layers: + x = layer(x, encoder_padding_mask = encoder_padding_mask) + + if isinstance(x, tuple) and len(x) ==2: + x, fc_result = x + else: + fc_result = None + + if return_all_hiddens: + assert encoder_states is not None + encoder_states.append(x) + fc_results.append(fc_result) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + src_lengths = ( + src_tokens.ne(self.padding_idx) + .sum(dim=1, dtype=torch.int32) + .reshape(-1, 1) + .contiguous() + ) + + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask], # B x T + "encoder_embedding": [encoder_embedding], # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "fc_results": fc_results, # List[T x B x C] + "src_tokens": [], + "src_lengths": [src_lengths], + } + +import torch.nn as nn +import torch +import sys +import torch.distributed as dist +# from fairseq import utils +# from fairseq.distributed import utils as distributed_utils +# from fairseq.modules.layer_norm import LayerNorm + +_MODEL_PARALLEL_GROUP = None +# Data parallel group that the current rank belongs to. 
+_DATA_PARALLEL_GROUP = None
+_USE_XLA = False
+# Whether Megatron-style model parallelism is in use. This standalone module never
+# initializes it, so it stays False; get_data_parallel_group() below relies on it.
+_USE_MEGATRON = False
+
+def use_xla():
+    global _USE_XLA
+    return _USE_XLA
+
+def get_global_world_size():
+    if use_xla():
+        # xm (torch_xla) is not imported here; this branch is unreachable because
+        # use_xla() always returns False in this module.
+        return xm.xrt_world_size()
+    elif torch.distributed.is_initialized():
+        return torch.distributed.get_world_size()
+    else:
+        return 1
+
+def get_global_rank():
+    if use_xla():
+        return xm.get_ordinal()
+    elif torch.distributed.is_initialized():
+        return torch.distributed.get_rank()
+    else:
+        return 0
+
+def _find_my_group_index(grouped_ranks):
+    my_rank = get_global_rank()
+    for i, group in enumerate(grouped_ranks):
+        if my_rank in group:
+            return i
+    raise RuntimeError
+
+def _find_my_group(grouped_ranks):
+    index = _find_my_group_index(grouped_ranks)
+    return grouped_ranks[index]
+
+def new_groups(grouped_ranks: List[List[int]]):
+    if use_xla():
+        return ("tpu", grouped_ranks)
+    else:
+        groups = [dist.new_group(g) for g in grouped_ranks]
+        my_group_idx = _find_my_group_index(grouped_ranks)
+        return groups[my_group_idx]
+
+def get_global_group():
+    if use_xla():
+        return new_groups([list(range(get_global_world_size()))])
+    elif torch.distributed.is_initialized():
+        if not hasattr(get_global_group, "_global_group"):
+            # ideally we could use torch.distributed.group.WORLD, but it seems
+            # to cause random NCCL hangs in some cases
+            get_global_group._global_group = dist.new_group()
+        return get_global_group._global_group
+    else:
+        return None
+
+def get_world_size(group):
+    if use_xla():
+        assert group[0] == "tpu"
+        my_group = _find_my_group(group[1])
+        return len(my_group)
+    elif torch.distributed.is_initialized():
+        return dist.get_world_size(group=group)
+    else:
+        return 1
+
+def get_rank(group):
+    if use_xla():
+        assert group[0] == "tpu"
+        my_group = _find_my_group(group[1])
+        return my_group.index(get_global_rank())
+    else:
+        return dist.get_rank(group=group)
+
+def mpu_get_data_parallel_group():
+    """Get the data parallel group the caller rank belongs to."""
+    assert _DATA_PARALLEL_GROUP is not None, \
+        'data parallel group is not initialized'
+    return _DATA_PARALLEL_GROUP
+
+def get_data_parallel_group():
+    """Get the data parallel group the caller rank belongs to."""
+    global _USE_MEGATRON
+    if _USE_MEGATRON:
+        return mpu_get_data_parallel_group()
+    else:
+        return get_global_group()
+
+def get_data_parallel_rank():
+    """Return my rank for the data parallel group."""
+    return get_rank(get_data_parallel_group())
+
+def get_data_parallel_world_size():
+    """Return world 
size for the data parallel group.""" + return get_world_size(get_data_parallel_group()) + +class BaseSublayer(nn.Module): + def __init__(self, args): + super().__init__() + self.activation_fn = get_activation_fn( + activation=getattr(args, "activation_fn", "relu") or "relu" + ) + self.norm = LayerNorm(args.decoder_embed_dim, export=False) + self.ff1 = torch.nn.Linear(args.decoder_embed_dim, args.decoder_ffn_embed_dim) + self.ff2 = torch.nn.Linear(args.decoder_ffn_embed_dim, args.decoder_embed_dim) + self.ff2.weight.data.zero_() + + def forward(self, xs): + return xs + self.ff2(self.activation_fn(self.ff1(self.norm(xs)))) + +class BaseLayer(nn.Module): + def __init__(self, args): + super().__init__() + self.num_workers = get_data_parallel_world_size() + expert_centroids = torch.empty(self.num_workers, args.decoder_embed_dim) + torch.nn.init.orthogonal_(expert_centroids, gain=0.1) + self.register_parameter( + "expert_centroids", torch.nn.Parameter(expert_centroids) + ) + self.expert_network = nn.Sequential( + *([BaseSublayer(args) for _ in range(args.base_sublayers)]) + ) + self.expert_id = get_data_parallel_rank() + self.shuffle = args.base_shuffle + self.cpp = self.load_assignment() + + # Add a special attribute to the expert parameters, so we know not to sync their gradients + for param in self.expert_network.parameters(): + param.expert = True + + def forward(self, input_features, *args, **kwargs): + features = input_features.reshape(-1, input_features.size(-1)) + is_training = input_features.requires_grad + + if self.shuffle and is_training: + # Send each token to a random worker, to break correlations within the batch + shuffle_sort = torch.randperm(features.size(0), device=features.device) + features = All2All.apply(features[shuffle_sort]) + + with torch.no_grad(): + # Compute similarity of each token to each expert, for routing + token_expert_affinities = features.matmul( + self.expert_centroids.transpose(0, 1) + ) + + # Compute which token goes to which expert + sort_by_expert, input_splits, output_splits = ( + self.balanced_assignment(token_expert_affinities) + if is_training + else self.greedy_assignment(token_expert_affinities) + ) + # Swap these tokens for the right ones for our expert + routed_features = All2All.apply( + features[sort_by_expert], output_splits, input_splits + ) + + if routed_features.size(0) > 0: + # Mix in the expert network based on how appropriate it is for these tokens + alpha = torch.sigmoid( + routed_features.mv(self.expert_centroids[self.expert_id]) + ).unsqueeze(1) + routed_features = ( + alpha * self.expert_network(routed_features) + + (1 - alpha) * routed_features + ) + # Return to original worker and ordering + result = All2All.apply(routed_features, input_splits, output_splits)[ + self.inverse_sort(sort_by_expert) + ] + + if self.shuffle and is_training: + # Undo shuffling + result = All2All.apply(result)[self.inverse_sort(shuffle_sort)] + + # Return additional Nones for compatibility with TransformerDecoderLayer + return result.view(input_features.size()), None, None + + def inverse_sort(self, order): + # Creates an index that undoes a sort: xs==xs[order][inverse_sort(order)] + return torch.empty_like(order).scatter_( + 0, order, torch.arange(0, order.size(0), device=order.device) + ) + + def balanced_assignment(self, scores): + ok = scores.isfinite() + if not ok.all(): + # NaNs here can break the assignment algorithm + scores[~ok] = scores[ok].min() + return self.cpp.balanced_assignment(scores), None, None + + # Assigns each token to the top k 
experts + def greedy_assignment(self, scores, k=1): + token_to_workers = torch.topk(scores, dim=1, k=k, largest=True).indices.view(-1) + token_to_workers, sort_ordering = torch.sort(token_to_workers) + worker2token = sort_ordering // k + + # Find how many tokens we're sending to each other worker (being careful for sending 0 tokens to some workers) + output_splits = torch.zeros( + (self.num_workers,), dtype=torch.long, device=scores.device + ) + workers, counts = torch.unique_consecutive(token_to_workers, return_counts=True) + output_splits[workers] = counts + # Tell other workers how many tokens to expect from us + input_splits = All2All.apply(output_splits) + return worker2token, input_splits.tolist(), output_splits.tolist() + + def load_assignment(self): + try: + from fairseq import libbase + + return libbase + + except ImportError as e: + sys.stderr.write( + "ERROR: missing libbase. run `python setup.py build_ext --inplace`\n" + ) + raise e + + + +class TransformerDecoderBase(nn.Module): + """ + Transformer decoder implemented using PyTorch's nn.Module. + + Args: + vocab_size (int): Size of the vocabulary. + embed_dim (int): Dimension of the embeddings. + num_layers (int): Number of Transformer decoder layers. + num_heads (int): Number of attention heads. + ff_dim (int): Dimension of feed-forward layers. + dropout (float): Dropout probability. + max_target_positions (int): Maximum target sequence length. + padding_idx (int): Index for the padding token. + share_input_output_embed (bool): Whether to share input/output embeddings. + """ + + def __init__( + self, + cfg, + dictionary, + embed_tokens, + no_encoder_attn=False, + output_projection=None, + ): + super().__init__() + self.register_buffer("version", torch.Tensor([3])) + self._future_mask = torch.empty(0) + ################ + self.dropout_module = FairseqDropout( + cfg.dropout, module_name="TransformerDecoder") + self.decoder_layerdrop = cfg.decoder.layerdrop + self.share_input_output_embed = cfg.share_decoder_input_output_embed + + input_embed_dim = embed_tokens.embedding_dim + embed_dim = cfg.decoder.embed_dim + self.embed_dim = embed_dim + self.output_embed_dim = cfg.decoder.output_dim + + self.padding_idx = embed_tokens.padding_idx + self.max_target_positions = cfg.max_target_positions + + self.embed_tokens = embed_tokens + + self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt( + embed_dim) + + if cfg.quant_noise.pq > 0: + self.quant_noise = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=False), + cfg.quant_noise.pq, + cfg.quant_noise.pq_block_size, + ) + else: + self.quant_noise = None + + self.project_in_dim = ( + nn.Linear(input_embed_dim, embed_dim, bias=False) + if embed_dim != input_embed_dim + else None + ) + self.embed_positions = ( + SinusoidalPositionalEmbedding( + embed_dim, self.padding_idx, cfg.max_target_positions + self.padding_idx + 1 + ) + if not cfg.no_token_positional_embeddings + else None + ) + if cfg.layernorm_embedding: + self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export) + else: + self.layernorm_embedding = None + + self.cross_self_attention = cfg.cross_self_attention + + if self.decoder_layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.decoder_layerdrop) + else: + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + self.build_decoder_layer(cfg, no_encoder_attn) + for _ in range(cfg.decoder.layers) + ] + ) + self.num_layers = len(self.layers) + + if cfg.decoder.normalize_before and not cfg.no_decoder_final_norm: + self.layer_norm = LayerNorm(embed_dim, 
export=cfg.export) + else: + self.layer_norm = None + + self.project_out_dim = ( + nn.Linear(embed_dim, self.output_embed_dim, bias=False) + if embed_dim != self.output_embed_dim and not cfg.tie_adaptive_weights + else None + ) + + self.adaptive_softmax = None + self.output_projection = output_projection + if self.output_projection is None: + self.build_output_projection(cfg, dictionary, embed_tokens) + ################ + + + def build_output_projection(self, cfg, dictionary, embed_tokens): + if self.share_input_output_embed: + self.output_projection = nn.Linear( + self.embed_tokens.weight.shape[1], + self.embed_tokens.weight.shape[0], + bias=False, + ) + self.output_projection.weight = self.embed_tokens.weight + else: + self.output_projection = nn.Linear( + self.output_embed_dim, len(dictionary), bias=False + ) + nn.init.normal_( + self.output_projection.weight, mean=0, std=self.output_embed_dim**-0.5 + ) + num_base_layers = cfg.base_layers + for i in range(num_base_layers): + self.layers.insert( + ((i + 1) * cfg.decoder.layers) // (num_base_layers + 1), + BaseLayer(cfg), + ) + + + def build_decoder_layer(self, cfg, no_encoder_attn=False): + layer = TransformerDecoderLayerBase(cfg, no_encoder_attn) + checkpoint = cfg.checkpoint_activations + if checkpoint: + offload_to_cpu = cfg.offload_activations + # layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) + # if we are checkpointing, enforce that FSDP always wraps the + # checkpointed layer, regardless of layer size + min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0 + # layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) + return layer + + def forward( + self, + prev_output_tokens: Tensor, + encoder_out: Optional[Tensor] = None, + src_padding_mask: Optional[Tensor] = None, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + features_only: bool = False, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + """ + Args: + prev_output_tokens (Tensor): Previous output tokens of shape (batch, tgt_len). + encoder_out (Tensor, optional): Encoder outputs (batch, src_len, embed_dim). + src_padding_mask (Tensor, optional): Padding mask for the encoder inputs. + + Returns: + Tensor: Decoder output of shape (batch, tgt_len, vocab_size). 
+ """ + bs, slen = prev_output_tokens.size() + if alignment_layer is None: + alignment_layer = self.num_layers - 1 + + enc: Optional[Tensor] = None + padding_mask: Optional[Tensor] = None + + if encoder_out is not None and len(encoder_out["encoder_out"]) > 0: + enc = encoder_out["encoder_out"][0] + if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0: + padding_mask = encoder_out["encoder_padding_mask"][0] + + # embed positions + positions = None + if self.embed_positions is not None: + positions = self.embed_positions( + prev_output_tokens, incremental_state=incremental_state + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + + # Prevent torchscript exporting issue for dynamic quant embedding + prev_output_tokens = prev_output_tokens.contiguous() + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.quant_noise is not None: + x = self.quant_noise(x) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + self_attn_padding_mask: Optional[Tensor] = None + if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): + self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) + + + + # Embed tokens and positions + # positions = torch.arange(prev_output_tokens.size(1), device=prev_output_tokens.device).unsqueeze(0) + # x = self.embed_tokens(prev_output_tokens) + self.embed_positions(positions) + # x = self.dropout(x) + # decoder layers + attn: Optional[Tensor] = None + inner_states: List[Optional[Tensor]] = [x] + + for idx, layer in enumerate(self.layers): + if incremental_state is None and not full_context_alignment: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + + x, layer_attn, _ = layer( + x, + enc, + padding_mask, + incremental_state, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_attn=bool((idx == alignment_layer)), + need_head_weights=bool((idx == alignment_layer)), + ) + inner_states.append(x) + if layer_attn is not None and idx == alignment_layer: + attn = layer_attn.float().to(x) + + if attn is not None: + if alignment_heads is not None: + attn = attn[:alignment_heads] + + # average probabilities over heads + attn = attn.mean(dim=0) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + + if not features_only: + x = self.output_layer(x) + + return x, {"attn": [attn], "inner_states": inner_states} + + def output_layer(self, features): + """Project features to the vocabulary size.""" + if self.adaptive_softmax is None: + # project back to size of vocabulary + return self.output_projection(features) + else: + return features + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embed_positions is None: + return self.max_target_positions + return min(self.max_target_positions, self.embed_positions.max_positions) + + def fill_with_neg_inf(self, t): + """FP16-compatible function that fills a tensor with -inf.""" + return t.float().fill_(float("-inf")).type_as(t) + + def buffered_future_mask(self, tensor): + dim = 
tensor.size(0) + # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround. + if ( + self._future_mask.size(0) == 0 + or (not self._future_mask.device == tensor.device) + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu( + self.fill_with_neg_inf(torch.zeros([dim, dim])), 1 + ) + self._future_mask = self._future_mask.to(tensor) + return self._future_mask[:dim, :dim] + + +class FairseqIncrementalState(object): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.init_incremental_state() + + def init_incremental_state(self): + self._incremental_state_id = str(uuid.uuid4()) + + def _get_full_incremental_state_key(self, key: str) -> str: + return "{}.{}".format(self._incremental_state_id, key) + + def get_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + ) -> Optional[Dict[str, Optional[Tensor]]]: + """Helper for getting incremental state for an nn.Module.""" + full_key = self._get_full_incremental_state_key(key) + if incremental_state is None or full_key not in incremental_state: + return None + return incremental_state[full_key] + + def set_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + value: Dict[str, Optional[Tensor]], + ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]: + """Helper for setting incremental state for an nn.Module.""" + if incremental_state is not None: + full_key = self._get_full_incremental_state_key(key) + incremental_state[full_key] = value + return incremental_state + + +def with_incremental_state(cls): + cls.__bases__ = (FairseqIncrementalState,) + tuple( + b for b in cls.__bases__ if b != FairseqIncrementalState + ) + return cls + +def eval_str_dict(x, type=dict): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + return x + +def softmax(x, dim: int, onnx_trace: bool = False): + if onnx_trace: + return F.softmax(x.float(), dim=dim) + else: + return F.softmax(x, dim=dim, dtype=torch.float32) + + + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn import Parameter + +try: + from xformers.components.attention import build_attention + from xformers.components.attention.utils import maybe_merge_masks + + _xformers_available = True +except ImportError: + _xformers_available = False + + +# TODO: move this into xformers? +# TODO: uint8 input type should just output a bool +def _mask_for_xformers(mask: Tensor, to_dtype: Optional[torch.dtype] = None): + """ + call to pytorch multihead accepts three mask types: + - ByteTensor where non-zero means to mask + - FloatTensor which is an additive mask + - BoolTensor where True means to mask + xFormers currently accepts boolean and additive maks. For boolean masks + the values have opposite meaning. For a BoolTensor True mean to keep the value. + """ + float_types = [torch.float, torch.float16] + # If an input mask is a float it is an additive mask. Otherwise it is either uint8 or bool. + additive = mask.dtype in float_types + # If to_dype is not specified, keep same dtype as mask. 
+ to_dtype = mask.dtype if to_dtype is None else to_dtype + to_additive = to_dtype in float_types + + if additive: + if to_additive: + return mask.to(to_dtype) + mask = mask < 0 + + if to_additive: + # return additive mask + new_mask = torch.zeros_like(mask, dtype=to_dtype) + new_mask = new_mask.masked_fill_(mask, -float("inf")) + return new_mask + + # In xFormers True is value to keep rather than value to mask + mask = ~mask.to(torch.bool) + mask = mask.to(to_dtype) + return mask + +def softmax(x, dim: int, onnx_trace: bool = False): + if onnx_trace: + return F.softmax(x.float(), dim=dim) + else: + return F.softmax(x, dim=dim, dtype=torch.float32) + +def eval_str_dict(x, type=dict): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + return x + + +@with_incremental_state +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + # TODO: pass in config rather than string. + # config defined in xformers.components.attention.AttentionConfig + xformers_att_config: Optional[str] = None, + xformers_blocksparse_layout: Optional[ + torch.Tensor + ] = None, # This should be part of the config + xformers_blocksparse_blocksize: Optional[ + int + ] = 16, # This should be part of the config + ): + super().__init__() + + xformers_att_config = eval_str_dict(xformers_att_config) + self.use_xformers = xformers_att_config is not None + if self.use_xformers and not _xformers_available: + raise ImportError("\n\n Please install xFormers.") + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim**-0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + self.k_proj = quant_noise( + nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + self.beam_size = 1 + self.reset_parameters() + + if self.use_xformers: + xformers_att_config["dropout"] = xformers_att_config.get("dropout", dropout) + xformers_att_config["num_heads"] = xformers_att_config.get( + "num_heads", num_heads + ) + + if xformers_blocksparse_layout is not None: + # Could be part of a single config passed only once + xformers_att_config["block_size"] = 
xformers_blocksparse_blocksize + xformers_att_config["layout"] = xformers_blocksparse_layout + xformers_att_config["name"] = "blocksparse" + + self.attention = build_attention(xformers_att_config) + + self.onnx_trace = False + self.skip_embed_dim_check = False + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + + def _get_reserve_head_index(self, num_heads_to_keep: int): + k_proj_heads_norm = [] + q_proj_heads_norm = [] + v_proj_heads_norm = [] + + for i in range(self.num_heads): + start_idx = i * self.head_dim + end_idx = (i + 1) * self.head_dim + k_proj_heads_norm.append( + torch.sum( + torch.abs( + self.k_proj.weight[ + start_idx:end_idx, + ] + ) + ).tolist() + + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist() + ) + q_proj_heads_norm.append( + torch.sum( + torch.abs( + self.q_proj.weight[ + start_idx:end_idx, + ] + ) + ).tolist() + + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist() + ) + v_proj_heads_norm.append( + torch.sum( + torch.abs( + self.v_proj.weight[ + start_idx:end_idx, + ] + ) + ).tolist() + + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist() + ) + + heads_norm = [] + for i in range(self.num_heads): + heads_norm.append( + k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i] + ) + + sorted_head_index = sorted( + range(self.num_heads), key=lambda k: heads_norm[k], reverse=True + ) + reserve_head_index = [] + for i in range(num_heads_to_keep): + start = sorted_head_index[i] * self.head_dim + end = (sorted_head_index[i] + 1) * self.head_dim + reserve_head_index.append((start, end)) + return reserve_head_index + + def _adaptive_prune_heads(self, reserve_head_index: List[Tuple[int, int]]): + new_q_weight = [] + new_q_bias = [] + new_k_weight = [] + new_k_bias = [] + new_v_weight = [] + new_v_bias = [] + new_out_proj_weight = [] + + for ele in reserve_head_index: + start_idx, end_idx = ele + new_q_weight.append( + self.q_proj.weight[ + start_idx:end_idx, + ] + ) + new_q_bias.append(self.q_proj.bias[start_idx:end_idx]) + + new_k_weight.append( + self.k_proj.weight[ + start_idx:end_idx, + ] + ) + + new_k_bias.append(self.k_proj.bias[start_idx:end_idx]) + + new_v_weight.append( + self.v_proj.weight[ + start_idx:end_idx, + ] + ) + new_v_bias.append(self.v_proj.bias[start_idx:end_idx]) + + new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx]) + + new_q_weight = torch.cat(new_q_weight).detach() + new_k_weight = torch.cat(new_k_weight).detach() + new_v_weight = torch.cat(new_v_weight).detach() + new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach() + new_q_weight.requires_grad = True + new_k_weight.requires_grad = True + new_v_weight.requires_grad = True + new_out_proj_weight.requires_grad = True + + new_q_bias = 
torch.cat(new_q_bias).detach() + new_q_bias.requires_grad = True + + new_k_bias = torch.cat(new_k_bias).detach() + new_k_bias.requires_grad = True + + new_v_bias = torch.cat(new_v_bias).detach() + new_v_bias.requires_grad = True + + self.q_proj.weight = torch.nn.Parameter(new_q_weight) + self.q_proj.bias = torch.nn.Parameter(new_q_bias) + + self.k_proj.weight = torch.nn.Parameter(new_k_weight) + self.k_proj.bias = torch.nn.Parameter(new_k_bias) + + self.v_proj.weight = torch.nn.Parameter(new_v_weight) + self.v_proj.bias = torch.nn.Parameter(new_v_bias) + + self.out_proj.weight = torch.nn.Parameter(new_out_proj_weight) + + self.num_heads = len(reserve_head_index) + self.embed_dim = self.head_dim * self.num_heads + self.q_proj.out_features = self.embed_dim + self.k_proj.out_features = self.embed_dim + self.v_proj.out_features = self.embed_dim + + def _set_skip_embed_dim_check(self): + self.skip_embed_dim_check = True + + def _pad_masks( + self, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + ) -> Tuple[Optional[Tensor], Optional[Tensor]]: + if attn_mask is not None: + shape = attn_mask.size()[:-1] + torch.Size([1]) + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1) + if key_padding_mask is not None: + shape = key_padding_mask.size()[:-1] + torch.Size([1]) + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(shape), + ], + dim=-1, + ) + return key_padding_mask, attn_mask + + def _add_bias( + self, + k: Tensor, + v: Tensor, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + bsz: int, + ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + assert self.bias_k is not None + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + key_padding_mask, attn_mask = self._pad_masks( + key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + return k, v, key_padding_mask, attn_mask + + def _append_zero_attn( + self, + k: Tensor, + v: Tensor, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:] + k = torch.cat( + [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2 + ) + v = torch.cat( + [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2 + ) + key_padding_mask, attn_mask = self._pad_masks( + key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + return k, v, key_padding_mask, attn_mask + + def _xformers_attn_forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + + tgt_len, bsz, embed_dim = query.size() + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == tgt_len + + if self.self_attention: + key = query + value = query + elif self.encoder_decoder_attention: + value = key + + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + + if self.bias_k is not None: + assert self.bias_v is not None + k, v, attn_mask, key_padding_mask = self._add_bias( + k, v, attn_mask, key_padding_mask, bsz + ) + + def fold_heads(x): + return ( + x.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + def split_heads(x): + return ( + x.contiguous() + 
.view(-1, bsz, self.num_heads, self.head_dim) + .transpose(0, 1) + .transpose(1, 2) + ) + + massage = split_heads if self.attention.requires_head_dimension else fold_heads + q = massage(q) + if k is not None: + k = massage(k) + if v is not None: + v = massage(v) + + if self.add_zero_attn: + k, v, key_padding_mask, attn_mask = self._append_zero_attn( + k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + + kwargs = {} + + if attn_mask is not None and self.attention.supports_attention_mask: + attn_mask = _mask_for_xformers(attn_mask, to_dtype=q.dtype) + kwargs["att_mask"] = attn_mask + + if key_padding_mask is not None: + key_padding_mask = _mask_for_xformers(key_padding_mask, to_dtype=torch.bool) + if not self.attention.requires_separate_masks: + attn_mask = maybe_merge_masks( + attn_mask, + key_padding_mask, + batch_size=bsz, + src_len=k.size(-2), + tgt_len=q.size(-2), + num_heads=self.num_heads, + ) + key_padding_mask = None + kwargs["att_mask"] = attn_mask + if self.attention.supports_key_padding_mask: + kwargs["key_padding_mask"] = key_padding_mask + + y = self.attention(q, k, v, **kwargs) + + y = ( + y.view(bsz, self.num_heads, tgt_len, self.head_dim) + .transpose(1, 2) + .flatten(start_dim=2, end_dim=3) + .transpose(0, 1) + ) + assert list(y.size()) == [tgt_len, bsz, embed_dim] + + # Dropout not needed because already applied in attention. + # It is applied to the attention weights before matmul with v. + y = self.out_proj(y) + + # TODO: support returning attention weights if needed. + return y, None + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + if not self.skip_embed_dim_check: + assert ( + embed_dim == self.embed_dim + ), f"query dim {embed_dim} != {self.embed_dim}" + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert value is not None + assert src_len, key_bsz == value.shape[:2] + + if ( + not self.onnx_trace + and not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. 
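+            # That is why torch.jit.is_scripting() is checked just below. Taken
+            # together, this condition gates the fused fast path (xFormers or
+            # F.multi_head_attention_forward); incremental decoding, static
+            # key/value, ONNX/TorchScript tracing, TPUs and pruned heads all fall
+            # through to the reference implementation further down.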
+ and not torch.jit.is_scripting() + # The Multihead attention implemented in pytorch forces strong dimension check + # for input embedding dimention and K,Q,V projection dimension. + # Since pruning will break the dimension check and it is not easy to modify the pytorch API, + # it is preferred to bypass the pytorch MHA when we need to skip embed_dim_check + and not self.skip_embed_dim_check + ): + assert key is not None and value is not None + + if self.use_xformers: + return self._xformers_attn_forward( + query, key, value, key_padding_mask, need_weights, attn_mask + ) + + else: + return F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + if self.beam_size > 1 and bsz == key.size(1): + # key is [T, bsz*beam_size, C], reduce to [T, bsz, C] + key = key.view(key.size(0), -1, self.beam_size, key.size(2))[ + :, :, 0, : + ] + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.view( + -1, self.beam_size, key_padding_mask.size(1) + )[:, 0, :] + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k, v, attn_mask, key_padding_mask = self._add_bias( + k, v, attn_mask, key_padding_mask, bsz + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + kv_bsz = bsz # need default value for scripting + if k is not None: + kv_bsz = k.size(1) + k = ( + k.contiguous() + .view(-1, kv_bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, kv_bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + kv_bsz = _prev_key.size(0) + prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + assert kv_bsz == _prev_value.size(0) + prev_value = _prev_value.view( + kv_bsz * self.num_heads, -1, 
self.head_dim + ) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=kv_bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view( + kv_bsz, self.num_heads, -1, self.head_dim + ) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == kv_bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k, v, key_padding_mask, attn_mask = self._append_zero_attn( + k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + + if self.encoder_decoder_attention and bsz != kv_bsz: + attn_weights = torch.einsum( + "bxhtd,bhsd->bxhts", + q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]), + k.view((kv_bsz, self.num_heads) + k.size()[1:]), + ) + attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:]) + else: + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + if self.onnx_trace: + attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.view( + kv_bsz, -1, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v + + attn_weights_float = softmax( + attn_weights, dim=-1, onnx_trace=self.onnx_trace + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + if self.encoder_decoder_attention and bsz != kv_bsz: + attn = torch.einsum( + "bxhts,bhsd->bxhtd", + attn_probs.view( + ( + kv_bsz, + -1, + self.num_heads, + ) + + attn_probs.size()[1:] + ), + v.view( + ( + kv_bsz, + self.num_heads, + ) + + v.size()[1:] + ), + ) + attn = attn.reshape((-1,) + attn.size()[-2:]) + else: + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + if 
self.onnx_trace and attn.size(1) == 1: + # when ONNX tracing a single decoder step (sequence length == 1) + # the transpose is a no-op copy before view, thus unnecessary + attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim) + else: + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + new_order: Tensor, + ): + """Reorder buffered internal state (for incremental generation).""" + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + for k in input_buffer.keys(): + input_buffer_k = input_buffer[k] + if input_buffer_k is not None: + if self.encoder_decoder_attention: + if input_buffer_k.size(0) * self.beam_size == new_order.size(0): + return incremental_state + elif self.beam_size > 1: + input_buffer[k] = input_buffer_k.index_select( + 0, + new_order.reshape(-1, self.beam_size)[:, 0] + // self.beam_size, + ) + else: + input_buffer[k] = input_buffer_k.index_select(0, new_order) + else: + input_buffer[k] = input_buffer_k.index_select(0, new_order) + incremental_state = self._set_input_buffer(incremental_state, input_buffer) + return incremental_state + + def set_beam_size(self, beam_size): + """Used for effiecient beamable enc-dec attention""" + self.beam_size = beam_size + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def 
_set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + "." if name != "" else "" + items_to_add = {} + keys_to_remove = [] + for k in state_dict.keys(): + if k.endswith(prefix + "in_proj_weight"): + # in_proj_weight used to be q + k + v with same dimensions + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim] + items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim] + items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :] + + keys_to_remove.append(k) + + k_bias = prefix + "in_proj_bias" + if k_bias in state_dict.keys(): + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim] + items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][ + dim : 2 * dim + ] + items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :] + + keys_to_remove.append(prefix + "in_proj_bias") + + for k in keys_to_remove: + del state_dict[k] + + for key, value in items_to_add.items(): + state_dict[key] = value + +@dataclass +class QuantNoiseConfig: + _name: str = "transformer" + pq: float = 0.0 + pq_block_size: int = 8 + scalar: float = 0.0 + + def to_dict(self): + return asdict(self) + + @classmethod + def from_dict(cls, data): + return cls(**data) + +@dataclass +class EncDecBaseConfig: + _name: str = "transformer" + embed_path: Optional[str] = None + embed_dim: int = 768 + ffn_embed_dim: int = 3072 + layers: int = 12 + attention_heads: int = 12 + normalize_before: bool = False + learned_pos: bool = False + layerdrop: float = 0.0 + layers_to_keep: Optional[list[int]] = None + xformers_att_config: Optional[dict] = None + quant_noise: QuantNoiseConfig = field(default_factory=QuantNoiseConfig) + padding_idx= 1 + vocab_size = 64001 + + +@dataclass +class DecoderConfig(EncDecBaseConfig): + input_dim: int = 768 + output_dim: int = 768 + vocab_size = 528 + +@dataclass +class TransformerConfig: + _name: str = "transformer" + activation_fn: str = "relu" + dropout: float = 0.1 + attention_dropout: float = 0.1 + activation_dropout: float = 0.0 + adaptive_input: bool = False + encoder: EncDecBaseConfig = field(default_factory=EncDecBaseConfig) + max_source_positions: int = 1024 + decoder: DecoderConfig = field(default_factory=DecoderConfig) + max_target_positions: int = 1024 + share_decoder_input_output_embed: bool = True + share_all_embeddings: bool = False + no_token_positional_embeddings: bool = False + adaptive_softmax_cutoff: Optional[list[int]] = None + adaptive_softmax_dropout: float = 0.0 + adaptive_softmax_factor: int = 4 + layernorm_embedding: bool = False + tie_adaptive_weights: bool = False + tie_adaptive_proj: bool = False + no_scale_embedding: bool = False + checkpoint_activations: bool = False + offload_activations: bool = False + no_cross_attention: bool = False + cross_self_attention: bool = False + quant_noise: QuantNoiseConfig = field(default_factory=QuantNoiseConfig) + min_params_to_wrap: int = 100_000_000 + char_inputs: bool = False + relu_dropout: float = 0.0 + base_layers: int = 0 + base_sublayers: int = 1 + base_shuffle: int = 1 + export: bool = False + no_decoder_final_norm: bool = False + +# Example of instantiating the config +main_config = 
TransformerConfig() + + +class TokenEmbedding(nn.Module): + def __init__(self, vocab_size, embed_dim, padding_idx): + super().__init__() + self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx) + self.vocab_size = vocab_size + self.embedding_dim = embed_dim + self.padding_idx = padding_idx + + def forward(self, input_tokens): + return self.embedding(input_tokens) + +# Example Usage +def initialize_embed_tokens(cfg, model='encoder'): + """ + Initialize the embed_tokens layer. + + Args: + cfg: Configuration object + dictionary: Vocabulary dictionary with token-to-index mapping + + Returns: + embed_tokens: Token embedding layer + """ + vocab_size = cfg.encoder.vocab_size if model == 'encoder' else cfg.decoder.vocab_size # Assuming this attribute is added in the config + embed_dim = cfg.encoder.embed_dim # Assuming this attribute is added in the config + padding_idx = cfg.encoder.padding_idx #dictionary.pad() # Fetch the padding index from the dictionary + return TokenEmbedding(vocab_size, embed_dim, padding_idx) + + +class EncoderDecoderModel(nn.Module): + """Standalone Encoder-Decoder model for Fairseq with necessary functionalities.""" + + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.encoder = TransformerEncoderBase(cfg, enc_dictionary, encoder_embedding.embedding) + self.decoder = TransformerDecoderBase(cfg, dec_dictionary, decoder_embedding.embedding) + self.supports_align_args = True + self._is_generation_fast = False + + def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): + """ + Perform a forward pass. + + Args: + src_tokens (LongTensor): Source tokens `(batch, src_len)` + src_lengths (LongTensor): Source lengths `(batch)` + prev_output_tokens (LongTensor): Previous decoder outputs `(batch, tgt_len)` + + Returns: + Tuple: decoder output and additional info + """ + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, + **kwargs) + decoder_out = self.decoder( + prev_output_tokens, encoder_out=encoder_out, **kwargs + ) + return decoder_out + + def forward_decoder(self, prev_output_tokens, **kwargs): + return self.decoder(prev_output_tokens, **kwargs) + + + + def output_layer(self, features, **kwargs): + """Project features to the default output size (typically vocabulary size).""" + return self.decoder.output_layer(features, **kwargs) + + def max_positions(self): + """Maximum length supported by the model.""" + return (self.encoder.max_positions(), self.decoder.max_positions()) + + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return self.decoder.max_positions() + + + + + + +encoder_embedding = initialize_embed_tokens(main_config) +decoder_embedding = initialize_embed_tokens(main_config, 'decoder') +enc_dictionary = [9]* main_config.encoder.vocab_size +dec_dictionary = [9] * main_config.decoder.vocab_size + + +class AfroLidForSequenceClassification(PreTrainedModel): + config_class = AfroLidConfig + base_model_prefix = "transformer" + + def __init__(self, config): + super().__init__(config) + self.cfg = main_config + self.encoder = TransformerEncoderBase(self.cfg, enc_dictionary, encoder_embedding.embedding) + self.decoder = TransformerDecoderBase(self.cfg, dec_dictionary, decoder_embedding.embedding) + self.supports_align_args = True + self._is_generation_fast = False + + def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): + """ + Perform a forward pass. 
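+
+        ``prev_output_tokens`` follows the usual fairseq convention: the target
+        sequence shifted right by one position (teacher forcing during training,
+        or the tokens generated so far during incremental decoding).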
+ + Args: + src_tokens (LongTensor): Source tokens `(batch, src_len)` + src_lengths (LongTensor): Source lengths `(batch)` + prev_output_tokens (LongTensor): Previous decoder outputs `(batch, tgt_len)` + + Returns: + Tuple: decoder output and additional info + """ + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + decoder_out = self.decoder( + prev_output_tokens, encoder_out=encoder_out, **kwargs + ) + return decoder_out + + def forward_decoder(self, prev_output_tokens, **kwargs): + return self.decoder(prev_output_tokens, **kwargs) + + + + def output_layer(self, features, **kwargs): + """Project features to the default output size (typically vocabulary size).""" + return self.decoder.output_layer(features, **kwargs) + + def max_positions(self): + """Maximum length supported by the model.""" + return (self.encoder.max_positions(), self.decoder.max_positions()) + + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return self.decoder.max_positions() + + +config = AfroLidConfig() +afrolid_model = AfroLidForSequenceClassification(config) +AutoConfig.register("afrolid", AfroLidConfig) +AutoModel.register(AfroLidConfig, AfroLidForSequenceClassification) +AutoModelForSequenceClassification.register( + AfroLidConfig, AfroLidForSequenceClassification) \ No newline at end of file
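+
+
+# Minimal usage sketch (illustrative only): the token ids below are placeholders,
+# since tokenization happens outside this file, and a real run would load trained
+# weights (e.g. AutoModelForSequenceClassification.from_pretrained on an AfroLid
+# checkpoint) instead of the randomly initialised `afrolid_model` above.
+#
+#   src_tokens = torch.tensor([[17, 523, 88, 9, 2]])   # (batch, src_len) dummy subword ids
+#   src_lengths = torch.tensor([5])
+#   prev_output_tokens = torch.tensor([[2]])           # eos-only decoder prefix (fairseq style)
+#   with torch.no_grad():
+#       logits, extra = afrolid_model(src_tokens, src_lengths, prev_output_tokens)
+#   predicted_language_id = logits[0, -1].argmax(-1)   # index into the 528-way decoder vocab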