import math from typing import Any, Optional import torch import torch.onnx.operators from torch import nn, Tensor import torch.nn as nn from typing import Optional, Dict, List, Any, Tuple import torch.nn as nn import torch.nn.functional as F import torch import sys import torch.distributed as dist import uuid from dataclasses import dataclass, field, asdict from transformers.modeling_utils import PreTrainedModel from transformers import AutoConfig, AutoModel, AutoModelForSequenceClassification from .configuration_afrolid import AfroLidConfig def quant_noise(module, p, block_size): """ Wraps modules and applies quantization noise to the weights for subsequent quantization with Iterative Product Quantization as described in "Training with Quantization Noise for Extreme Model Compression" Args: - module: nn.Module - p: amount of Quantization Noise - block_size: size of the blocks for subsequent quantization with iPQ Remarks: - Module weights must have the right sizes wrt the block size - Only Linear, Embedding and Conv2d modules are supported for the moment - For more detail on how to quantize by blocks with convolutional weights, see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" - We implement the simplest form of noise here as stated in the paper which consists in randomly dropping blocks """ # if no quantization noise, don't register hook if p <= 0: return module # supported modules assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) # test whether module.weight has the right sizes wrt block_size is_conv = module.weight.ndim == 4 # 2D matrix if not is_conv: assert ( module.weight.size(1) % block_size == 0 ), "Input features must be a multiple of block sizes" # 4D matrix else: # 1x1 convolutions if module.kernel_size == (1, 1): assert ( module.in_channels % block_size == 0 ), "Input channels must be a multiple of block sizes" # regular convolutions else: k = module.kernel_size[0] * module.kernel_size[1] assert k % block_size == 0, "Kernel size must be a multiple of block size" def _forward_pre_hook(mod, input): # no noise for evaluation if mod.training: if not is_conv: # gather weight and sizes weight = mod.weight in_features = weight.size(1) out_features = weight.size(0) # split weight matrix into blocks and randomly drop selected blocks mask = torch.zeros( in_features // block_size * out_features, device=weight.device ) mask.bernoulli_(p) mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) else: # gather weight and sizes weight = mod.weight in_channels = mod.in_channels out_channels = mod.out_channels # split weight matrix into blocks and randomly drop selected blocks if mod.kernel_size == (1, 1): mask = torch.zeros( int(in_channels // block_size * out_channels), device=weight.device, ) mask.bernoulli_(p) mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) else: mask = torch.zeros( weight.size(0), weight.size(1), device=weight.device ) mask.bernoulli_(p) mask = ( mask.unsqueeze(2) .unsqueeze(3) .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) ) # scale weights and apply mask mask = mask.to( torch.bool ) # x.bool() is not currently supported in TorchScript s = 1 / (1 - p) mod.weight.data = s * weight.masked_fill(mask, 0) module.register_forward_pre_hook(_forward_pre_hook) return module def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): # if torch.jit.is_scripting() or torch.jit.is_tracing(): # export = True # if not export and torch.cuda.is_available() and has_fused_layernorm: # return FusedLayerNorm(normalized_shape, eps, elementwise_affine) return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) class LayerDropModuleList(nn.ModuleList): """ A LayerDrop implementation based on :class:`torch.nn.ModuleList`. We refresh the choice of which layers to drop every time we iterate over the LayerDropModuleList instance. During evaluation we always iterate over all layers. Usage:: layers = LayerDropList(p=0.5, modules=[layer1, layer2, layer3]) for layer in layers: # this might iterate over layers 1 and 3 x = layer(x) for layer in layers: # this might iterate over all layers x = layer(x) for layer in layers: # this might not iterate over any layers x = layer(x) Args: p (float): probability of dropping out each layer modules (iterable, optional): an iterable of modules to add """ def __init__(self, p, modules=None): super().__init__(modules) self.p = p def __iter__(self): dropout_probs = torch.empty(len(self)).uniform_() for i, m in enumerate(super().__iter__()): if not self.training or (dropout_probs[i] > self.p): yield m from typing import List, Callable from typing import Dict import warnings def gelu_accurate(x): if not hasattr(gelu_accurate, "_a"): gelu_accurate._a = math.sqrt(2 / math.pi) return ( 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) ) def deprecation_warning(message, stacklevel=3): # don't use DeprecationWarning, since it's ignored by default warnings.warn(message, stacklevel=stacklevel) def gelu(x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.gelu(x.float()).type_as(x) def relu_squared(x: torch.Tensor): return F.relu(x).pow(2) def get_activation_fn(activation: str) -> Callable: """Returns the activation function corresponding to `activation`""" if activation == "relu": return F.relu elif activation == "relu_squared": return relu_squared elif activation == "gelu": return gelu elif activation == "gelu_fast": deprecation_warning( "--activation-fn=gelu_fast has been renamed to gelu_accurate" ) return gelu_accurate elif activation == "gelu_accurate": return gelu_accurate elif activation == "tanh": return torch.tanh elif activation == "linear": return lambda x: x elif activation == "swish": return torch.nn.SiLU else: raise RuntimeError("--activation-fn {} not supported".format(activation)) class FairseqDropout(nn.Module): def __init__(self, p, module_name=None): super().__init__() self.p = p self.module_name = module_name self.apply_during_inference = False def forward(self, x, inplace: bool = False): if self.p > 0 and (self.training or self.apply_during_inference): return F.dropout(x, p=self.p, training=True, inplace=inplace) else: return x class TransformerEncoderLayerBase(nn.Module): """Encoder layer block. In the original paper each operation (multi-head attention or FFN) is postprocessed with: `dropout -> add residual -> layernorm`. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting *cfg.encoder.normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments """ def __init__(self, cfg, return_fc=False): super().__init__() self.cfg = cfg self.return_fc = return_fc self.embed_dim = cfg.encoder.embed_dim self.quant_noise = cfg.quant_noise.pq self.quant_noise_block_size = cfg.quant_noise.pq_block_size self.self_attn = self.build_self_attention(self.embed_dim, cfg) self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) self.dropout_module = FairseqDropout( cfg.dropout, module_name=self.__class__.__name__ ) self.activation_fn = get_activation_fn(activation=cfg.activation_fn) activation_dropout_p = cfg.activation_dropout if activation_dropout_p == 0: # for backwards compatibility with models that use cfg.relu_dropout activation_dropout_p = cfg.relu_dropout or 0 self.activation_dropout_module = FairseqDropout( float(activation_dropout_p), module_name=self.__class__.__name__ ) self.normalize_before = cfg.encoder.normalize_before self.fc1 = self.build_fc1( self.embed_dim, cfg.encoder.ffn_embed_dim, self.quant_noise, self.quant_noise_block_size, ) self.fc2 = self.build_fc2( cfg.encoder.ffn_embed_dim, self.embed_dim, self.quant_noise, self.quant_noise_block_size, ) self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) self.num_heads = cfg.encoder.attention_heads self.load_to_BT = False self.ever_training = False # For BT, we need continuous mem self.in_proj_weight = torch.nn.Parameter( torch.zeros( self.self_attn.q_proj.weight.shape[0] * 3, self.self_attn.q_proj.weight.shape[1], ) ) self.in_proj_bias = torch.nn.Parameter( torch.zeros(self.self_attn.q_proj.bias.shape[0] * 3) ) self.out_proj_weight = torch.nn.Parameter( torch.zeros(self.self_attn.out_proj.weight.shape) ) self.out_proj_bias = torch.nn.Parameter( torch.zeros(self.self_attn.out_proj.bias.shape) ) self.fc1_weight = torch.nn.Parameter(torch.zeros(self.fc1.weight.shape)) self.fc1_bias = torch.nn.Parameter(torch.zeros(self.fc1.bias.shape)) self.fc2_weight = torch.nn.Parameter(torch.zeros(self.fc2.weight.shape)) self.fc2_bias = torch.nn.Parameter(torch.zeros(self.fc2.bias.shape)) if ( self.activation_fn is torch.nn.functional.relu or isinstance(self.activation_fn, torch.nn.ReLU) or self.activation_fn == "relu" ): self.activation_relu_or_gelu = 1 elif ( self.activation_fn is torch.nn.functional.gelu or isinstance(self.activation_fn, torch.nn.GELU) or self.activation_fn == "gelu" ): self.activation_relu_or_gelu = 2 else: self.activation_relu_or_gelu = 0 # Batch first can not be justified but needs user to make sure self.can_use_fastpath = None self.cfg_checkpoint_activations = self.cfg.checkpoint_activations # torch version check # make sure BT version is >=1.12.0 self.BT_version = False if "fb" in torch.__version__: self.BT_version = True else: if "+" in torch.__version__: self.torch_version = torch.__version__.split("+")[0] else: self.torch_version = torch.__version__ self.torch_version = self.torch_version.split(".") self.int_version = ( int(self.torch_version[0]) * 1000 + int(self.torch_version[1]) * 10 + int(self.torch_version[2]) ) if len(self.torch_version) == 3: if self.int_version >= 1120: self.BT_version = True elif len(self.torch_version) == 4: if self.int_version >= 1130: self.BT_version = True def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ): self.load_to_BT = True old_name = prefix + "self_attn." q_proj_weight = state_dict[old_name + "q_proj.weight"] k_proj_weight = state_dict[old_name + "k_proj.weight"] v_proj_weight = state_dict[old_name + "v_proj.weight"] q_proj_bias = state_dict[old_name + "q_proj.bias"] k_proj_bias = state_dict[old_name + "k_proj.bias"] v_proj_bias = state_dict[old_name + "v_proj.bias"] new_name = prefix state_dict[new_name + "in_proj_weight"] = torch.cat( (q_proj_weight, k_proj_weight, v_proj_weight), dim=0 ) state_dict[new_name + "in_proj_bias"] = torch.cat( (q_proj_bias, k_proj_bias, v_proj_bias), dim=0 ) state_dict[new_name + "out_proj_weight"] = state_dict[ old_name + "out_proj.weight" ] state_dict[new_name + "out_proj_bias"] = state_dict[old_name + "out_proj.bias"] state_dict[new_name + "fc1_weight"] = state_dict[prefix + "fc1.weight"] state_dict[new_name + "fc1_bias"] = state_dict[prefix + "fc1.bias"] state_dict[new_name + "fc2_weight"] = state_dict[prefix + "fc2.weight"] state_dict[new_name + "fc2_bias"] = state_dict[prefix + "fc2.bias"] super(TransformerEncoderLayerBase, self)._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ) def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): return quant_noise( nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size ) def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): return quant_noise( nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size ) def _get_fc_rank(self, remove_num: int) -> List[int]: f1_filter_param = [] for i in range(self.fc1.out_features): f1_filter_param.append( torch.sum(torch.abs(self.fc1.weight[i])) + torch.sum(torch.abs(self.fc2.weight[:, i])) + torch.abs(self.fc1.bias[i]) ) return sorted( range(len(f1_filter_param)), key=lambda k: f1_filter_param[k], reverse=False )[0:remove_num] def _prune_fc_layer(self, remove_index: List[int]): new_fc1_weight = [] new_fc1_bias = [] for i in range(self.fc1.out_features): if i not in remove_index: new_fc1_weight.append(self.fc1.weight[i]) new_fc1_bias.append(self.fc1.bias[i]) new_fc1_weight = torch.stack(new_fc1_weight).detach() new_fc1_weight.requires_grad = True new_fc1_bias = torch.stack(new_fc1_bias).detach() new_fc1_bias.requires_grad = True self.fc1 = quant_noise( nn.Linear(self.fc1.in_features, self.fc1.out_features - len(remove_index)), p=self.quant_noise, block_size=self.quant_noise_block_size, ) self.fc1.weight = torch.nn.Parameter(new_fc1_weight) self.fc1.bias = torch.nn.Parameter(new_fc1_bias) new_fc2_weight = [] new_fc2_bias = [] for i in range(self.fc2.in_features): if i not in remove_index: new_fc2_weight.append(self.fc2.weight[:, i]) new_fc2_bias = self.fc2.bias.detach() new_fc2_weight = torch.stack(new_fc2_weight, dim=-1).detach() new_fc2_weight.requires_grad = True new_fc2_bias = self.fc2.bias.detach() new_fc2_bias.requires_grad = True self.fc2 = quant_noise( nn.Linear(self.fc2.in_features - len(remove_index), self.fc2.out_features), p=self.quant_noise, block_size=self.quant_noise_block_size, ) self.fc2.weight = torch.nn.Parameter(new_fc2_weight) self.fc2.bias = torch.nn.Parameter(new_fc2_bias) def build_self_attention(self, embed_dim, cfg): return MultiheadAttention( embed_dim, cfg.encoder.attention_heads, dropout=cfg.attention_dropout, self_attention=True, q_noise=self.quant_noise, qn_block_size=self.quant_noise_block_size, xformers_att_config=cfg.encoder.xformers_att_config, ) def residual_connection(self, x, residual): return residual + x def upgrade_state_dict_named(self, state_dict, name): """ Rename layer norm states from `...layer_norms.0.weight` to `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to `...final_layer_norm.weight` """ layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"} for old, new in layer_norm_map.items(): for m in ("weight", "bias"): k = "{}.layer_norms.{}.{}".format(name, old, m) if k in state_dict: state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k] del state_dict[k] def forward( self, x, encoder_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor] = None, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor): binary ByteTensor of shape `(batch, seq_len)` where padding elements are indicated by ``1``. attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`, where `tgt_len` is the length of output and `src_len` is the length of input, though here both are equal to `seq_len`. `attn_mask[tgt_i, src_j] = 1` means that when calculating the embedding for `tgt_i`, we exclude (mask out) `src_j`. This is useful for strided self-attention. Returns: encoded output of shape `(seq_len, batch, embed_dim)` """ # anything in original attn_mask = 1, becomes -1e8 # anything in original attn_mask = 0, becomes 0 # Note that we cannot use -inf here, because at some edge cases, # the attention weight (before softmax) for some padded element in query # will become -inf, which results in NaN in model parameters if self.training: self.ever_training = True if ( self.BT_version and x.dim() == 3 and self.load_to_BT and not self.return_fc and self.can_use_fastpath and not self.training and not self.ever_training and not self.cfg_checkpoint_activations ): # assume is Batch first and nested tensor output = torch._transformer_encoder_layer_fwd( x, self.embed_dim, self.num_heads, self.in_proj_weight, self.in_proj_bias, self.out_proj_weight, self.out_proj_bias, self.activation_relu_or_gelu == 2, False, # norm_first, currently not supported self.self_attn_layer_norm.eps, self.self_attn_layer_norm.weight, self.self_attn_layer_norm.bias, self.final_layer_norm.weight, self.final_layer_norm.bias, self.fc1_weight, self.fc1_bias, self.fc2_weight, self.fc2_bias, encoder_padding_mask if encoder_padding_mask is not None else attn_mask, ) return output else: if attn_mask is not None: attn_mask = attn_mask.masked_fill( attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4 ) residual = x if self.normalize_before: x = self.self_attn_layer_norm(x) x, _ = self.self_attn( query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=False, attn_mask=attn_mask, ) x = self.dropout_module(x) x = self.residual_connection(x, residual) if not self.normalize_before: x = self.self_attn_layer_norm(x) residual = x if self.normalize_before: x = self.final_layer_norm(x) x = self.activation_fn(self.fc1(x)) x = self.activation_dropout_module(x) x = self.fc2(x) fc_result = x x = self.dropout_module(x) x = self.residual_connection(x, residual) if not self.normalize_before: x = self.final_layer_norm(x) if self.return_fc and not torch.jit.is_scripting(): return x, fc_result return x def safe_getattr(obj, k, default=None): """Returns obj[k] if it exists and is not None, otherwise returns default.""" from omegaconf import OmegaConf if OmegaConf.is_config(obj): return obj[k] if k in obj and obj[k] is not None else default return getattr(obj, k, default) class TransformerDecoderLayerBase(nn.Module): """Decoder layer block. In the original paper each operation (multi-head attention, encoder attention or FFN) is postprocessed with: `dropout -> add residual -> layernorm`. In the tensor2tensor code they suggest that learning is more robust when preprocessing each layer with layernorm and postprocessing with: `dropout -> add residual`. We default to the approach in the paper, but the tensor2tensor approach can be enabled by setting *cfg.decoder.normalize_before* to ``True``. Args: args (argparse.Namespace): parsed command-line arguments no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__( self, cfg, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False ): #embed_dim, num_heads, ff_dim, dropout super().__init__() self.embed_dim = cfg.decoder.embed_dim self.dropout_module = FairseqDropout( cfg.dropout, module_name=self.__class__.__name__ ) self.quant_noise = cfg.quant_noise.pq self.quant_noise_block_size = cfg.quant_noise.pq_block_size self.cross_self_attention = cfg.cross_self_attention self.self_attn = self.build_self_attention( self.embed_dim, cfg, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, ) self.attn_ln = ( LayerNorm(self.embed_dim) if safe_getattr(cfg, "scale_attn", False) else None ) self.nh = self.self_attn.num_heads self.head_dim = self.self_attn.head_dim scale_heads = safe_getattr(cfg, "scale_heads", False) self.c_attn = ( nn.Parameter(torch.ones((self.nh,)), requires_grad=True) if scale_heads else None ) self.activation_fn = get_activation_fn(activation=cfg.activation_fn) activation_dropout_p = cfg.activation_dropout if activation_dropout_p == 0: # for backwards compatibility with models that use cfg.relu_dropout activation_dropout_p = cfg.relu_dropout or 0 self.activation_dropout_module = FairseqDropout( float(activation_dropout_p), module_name=self.__class__.__name__ ) self.normalize_before = cfg.decoder.normalize_before self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) self.ffn_layernorm = ( LayerNorm(cfg.decoder.ffn_embed_dim) if safe_getattr(cfg, "scale_fc", False) else None ) self.w_resid = ( nn.Parameter( torch.ones( self.embed_dim, ), requires_grad=True, ) if safe_getattr(cfg, "scale_resids", False) else None ) self.fc1 = self.build_fc1( self.embed_dim, cfg.decoder.ffn_embed_dim, self.quant_noise, self.quant_noise_block_size, ) self.fc2 = self.build_fc2( cfg.decoder.ffn_embed_dim, self.embed_dim, self.quant_noise, self.quant_noise_block_size, ) self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) self.need_attn = True self.onnx_trace = False def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) def build_self_attention( self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False ): return MultiheadAttention( embed_dim, cfg.decoder.attention_heads, dropout=cfg.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=not cfg.cross_self_attention, q_noise=self.quant_noise, qn_block_size=self.quant_noise_block_size, xformers_att_config=cfg.decoder.xformers_att_config, ) def build_encoder_attention(self, embed_dim, cfg): return MultiheadAttention( embed_dim, cfg.decoder.attention_heads, kdim=cfg.encoder.embed_dim, vdim=cfg.encoder.embed_dim, dropout=cfg.attention_dropout, encoder_decoder_attention=True, q_noise=self.quant_noise, qn_block_size=self.quant_noise_block_size, xformers_att_config=cfg.encoder.xformers_att_config, ) def prepare_for_onnx_export_(self): self.onnx_trace = True def residual_connection(self, x, residual): return residual + x def forward( self, x, encoder_out: Optional[torch.Tensor] = None, encoder_padding_mask: Optional[torch.Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, prev_self_attn_state: Optional[List[torch.Tensor]] = None, prev_attn_state: Optional[List[torch.Tensor]] = None, self_attn_mask: Optional[torch.Tensor] = None, self_attn_padding_mask: Optional[torch.Tensor] = None, need_attn: bool = False, need_head_weights: bool = False, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor, optional): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. need_attn (bool, optional): return attention weights need_head_weights (bool, optional): return attention weights for each head (default: return average over heads). Returns: encoded output of shape `(seq_len, batch, embed_dim)` """ if need_head_weights: need_attn = True residual = x if self.normalize_before: x = self.self_attn_layer_norm(x) if prev_self_attn_state is not None: prev_key, prev_value = prev_self_attn_state[:2] saved_state: Dict[str, Optional[Tensor]] = { "prev_key": prev_key, "prev_value": prev_value, } if len(prev_self_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] assert incremental_state is not None self.self_attn._set_input_buffer(incremental_state, saved_state) _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state) if self.cross_self_attention and not ( incremental_state is not None and _self_attn_input_buffer is not None and "prev_key" in _self_attn_input_buffer ): if self_attn_mask is not None: assert encoder_out is not None self_attn_mask = torch.cat( (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1 ) if self_attn_padding_mask is not None: if encoder_padding_mask is None: assert encoder_out is not None encoder_padding_mask = self_attn_padding_mask.new_zeros( encoder_out.size(1), encoder_out.size(0) ) self_attn_padding_mask = torch.cat( (encoder_padding_mask, self_attn_padding_mask), dim=1 ) assert encoder_out is not None y = torch.cat((encoder_out, x), dim=0) else: y = x x, attn = self.self_attn( query=x, key=y, value=y, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, ) if self.c_attn is not None: tgt_len, bsz = x.size(0), x.size(1) x = x.view(tgt_len, bsz, self.nh, self.head_dim) x = torch.einsum("tbhd,h->tbhd", x, self.c_attn) x = x.reshape(tgt_len, bsz, self.embed_dim) if self.attn_ln is not None: x = self.attn_ln(x) x = self.dropout_module(x) x = self.residual_connection(x, residual) if not self.normalize_before: x = self.self_attn_layer_norm(x) if self.encoder_attn is not None and encoder_out is not None: residual = x if self.normalize_before: x = self.encoder_attn_layer_norm(x) if prev_attn_state is not None: prev_key, prev_value = prev_attn_state[:2] saved_state: Dict[str, Optional[Tensor]] = { "prev_key": prev_key, "prev_value": prev_value, } if len(prev_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_attn_state[2] assert incremental_state is not None self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.encoder_attn( query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=need_attn or (not self.training and self.need_attn), need_head_weights=need_head_weights, ) x = self.dropout_module(x) x = self.residual_connection(x, residual) if not self.normalize_before: x = self.encoder_attn_layer_norm(x) residual = x if self.normalize_before: x = self.final_layer_norm(x) x = self.activation_fn(self.fc1(x)) x = self.activation_dropout_module(x) if self.ffn_layernorm is not None: x = self.ffn_layernorm(x) x = self.fc2(x) x = self.dropout_module(x) if self.w_resid is not None: residual = torch.mul(self.w_resid, residual) x = self.residual_connection(x, residual) if not self.normalize_before: x = self.final_layer_norm(x) if self.onnx_trace and incremental_state is not None: saved_state = self.self_attn._get_input_buffer(incremental_state) assert saved_state is not None if self_attn_padding_mask is not None: self_attn_state = [ saved_state["prev_key"], saved_state["prev_value"], saved_state["prev_key_padding_mask"], ] else: self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]] return x, attn, self_attn_state return x, attn, None def make_generation_fast_(self, need_attn: bool = False, **kwargs): self.need_attn = need_attn import torch import torch.nn as nn import math from typing import Optional, Dict, List, Any from torch import Tensor def make_positions(tensor, padding_idx: int, onnx_trace: bool = False): """Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols are ignored. """ # The series of casts and type-conversions here are carefully # balanced to both work with ONNX export and XLA. In particular XLA # prefers ints, cumsum defaults to output longs, and ONNX doesn't know # how to handle the dtype kwarg in cumsum. mask = tensor.ne(padding_idx).int() return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx class SinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length. Padding symbols are ignored. """ def __init__(self, embedding_dim, padding_idx, init_size=1024): super().__init__() self.embedding_dim = embedding_dim self.padding_idx = padding_idx if padding_idx is not None else 0 self.weights = SinusoidalPositionalEmbedding.get_embedding( init_size, embedding_dim, padding_idx ) self.onnx_trace = False self.register_buffer("_float_tensor", torch.FloatTensor(1)) self.max_positions = int(1e5) def prepare_for_onnx_export_(self): self.onnx_trace = True @staticmethod def get_embedding( num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None ): """Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of "Attention Is All You Need". """ half_dim = embedding_dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze( 1 ) * emb.unsqueeze(0) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view( num_embeddings, -1 ) if embedding_dim % 2 == 1: # zero pad emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) if padding_idx is not None: emb[padding_idx, :] = 0 return emb def forward( self, input, incremental_state: Optional[Any] = None, timestep: Optional[Tensor] = None, positions: Optional[Any] = None, ): """Input is expected to be of size [bsz x seqlen].""" bspair = torch.onnx.operators.shape_as_tensor(input) bsz, seq_len = bspair[0], bspair[1] max_pos = self.padding_idx + 1 + seq_len if self.weights is None or max_pos > self.weights.size(0): # recompute/expand embeddings if needed self.weights = SinusoidalPositionalEmbedding.get_embedding( max_pos, self.embedding_dim, self.padding_idx ) self.weights = self.weights.to(self._float_tensor) if incremental_state is not None: # positions is the same for every token when decoding a single step pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len if self.onnx_trace: return ( self.weights.index_select(index=self.padding_idx + pos, dim=0) .unsqueeze(1) .repeat(bsz, 1, 1) ) return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) positions = make_positions( input, self.padding_idx, onnx_trace=self.onnx_trace ) if self.onnx_trace: flat_embeddings = self.weights.detach().index_select(0, positions.view(-1)) embedding_shape = torch.cat( (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long)) ) embeddings = torch.onnx.operators.reshape_from_tensor_shape( flat_embeddings, embedding_shape ) return embeddings return ( self.weights.index_select(0, positions.view(-1)) .view(bsz, seq_len, -1) .detach() ) class TransformerEncoderBase(nn.Module): def __init__(self, cfg, dictionary, embed_tokens, return_fc=False): super().__init__() self.cfg = cfg self.dictionary = dictionary self.return_fc = return_fc self.register_buffer('version', torch.Tensor([3])) self.dropout_module = FairseqDropout(cfg.dropout) self.encoder_layerdrop = cfg.encoder.layerdrop embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx self.max_source_positions = cfg.max_source_positions self.embed_tokens = embed_tokens self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim) self.embed_positions = ( SinusoidalPositionalEmbedding( embed_dim, self.padding_idx, cfg.max_source_positions + self.padding_idx + 1 ) if not cfg.no_token_positional_embeddings else None ) # self.layernorm_embedding = ( # nn.LayerNorm(embed_dim) if cfg.layernorm_embedding else None # ) if cfg.layernorm_embedding: self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export) else: self.layernorm_embedding = None if not cfg.adaptive_input and cfg.quant_noise.pq > 0: self.quant_noise = quant_noise( nn.Linear(embed_dim, embed_dim, bias=False), cfg.quant_noise.pq, cfg.quant_noise.pq_block_size, ) else: self.quant_noise = None if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend( [self.build_encoder_layer(cfg) for i in range(cfg.encoder.layers)] ) self.num_layers = len(self.layers) if cfg.encoder.normalize_before: self.layer_norm = LayerNorm(embed_dim, export=cfg.export) else: self.layer_norm = None def build_encoder_layer(self, cfg): layer = TransformerEncoderLayerBase( cfg, return_fc=self.return_fc ) checkpoint = cfg.checkpoint_activations # if checkpoint: # offload_to_cpu = cfg.offload_activations # layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) # if we are checkpointing, enforce that FSDP always wraps the # checkpointed layer, regardless of layer size min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0 # layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) return layer def forward_embedding( self, src_tokens, token_embedding: Optional[torch.Tensor] = None): # embed tokens and positions if token_embedding is None: token_embedding = self.embed_tokens(src_tokens) x = embed = self.embed_scale * token_embedding if self.embed_positions is not None: x = embed + self.embed_positions(src_tokens) if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = self.dropout_module(x) if self.quant_noise is not None: x = self.quant_noise(x) return x, embed def max_positions(self): """Maximum input length supported by the encoder.""" if self.embed_positions is None: return self.max_source_positions return min(self.max_source_positions, self.embed_positions.max_positions) def forward(self, src_tokens, src_lengths: Optional[torch.Tensor] = None, token_embeddings: Optional[torch.Tensor] = None, return_all_hiddens: bool = False): encoder_padding_mask = src_tokens.eq(self.padding_idx) # encoder_padding_mask = src_tokens.device.type == "xla" or encoder_padding_mask.any() has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any() x, encoder_embedding = self.forward_embedding(src_tokens) if has_pads: x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) x = x.transpose(0, 1) # B x T x C -> T x B x C encoder_states = [] if return_all_hiddens else None fc_results = [] if return_all_hiddens: encoder_states.append(x) encoder_padding_mask = encoder_padding_mask if has_pads else None for layer in self.layers: x = layer(x, encoder_padding_mask = encoder_padding_mask) if isinstance(x, tuple) and len(x) ==2: x, fc_result = x else: fc_result = None if return_all_hiddens: assert encoder_states is not None encoder_states.append(x) fc_results.append(fc_result) if self.layer_norm is not None: x = self.layer_norm(x) src_lengths = ( src_tokens.ne(self.padding_idx) .sum(dim=1, dtype=torch.int32) .reshape(-1, 1) .contiguous() ) return { "encoder_out": [x], # T x B x C "encoder_padding_mask": [encoder_padding_mask], # B x T "encoder_embedding": [encoder_embedding], # B x T x C "encoder_states": encoder_states, # List[T x B x C] "fc_results": fc_results, # List[T x B x C] "src_tokens": [], "src_lengths": [src_lengths], } import torch.nn as nn import torch import sys import torch.distributed as dist # from fairseq import utils # from fairseq.distributed import utils as distributed_utils # from fairseq.modules.layer_norm import LayerNorm _MODEL_PARALLEL_GROUP = None # Data parallel group that the current rank belongs to. _DATA_PARALLEL_GROUP = None _USE_XLA = False def use_xla(): global _USE_XLA return _USE_XLA def get_world_size(group): if use_xla(): assert group[0] == "tpu" my_group = _find_my_group(group[1]) return len(my_group) elif torch.distributed.is_initialized(): return dist.get_world_size(group=group) else: return 1 def get_global_world_size(): if use_xla(): return xm.xrt_world_size() elif torch.distributed.is_initialized(): return torch.distributed.get_world_size() else: return 1 def get_global_rank(): if use_xla(): return xm.get_ordinal() elif torch.distributed.is_initialized(): return torch.distributed.get_rank() else: return 0 def new_groups(grouped_ranks: List[List[int]]): if use_xla(): return ("tpu", grouped_ranks) else: groups = [dist.new_group(g) for g in grouped_ranks] my_group_idx = _find_my_group_index(grouped_ranks) return groups[my_group_idx] def get_global_group(): if use_xla(): return new_groups([list(range(get_global_world_size()))]) elif torch.distributed.is_initialized(): if not hasattr(get_global_group, "_global_group"): # ideally we could use torch.distributed.group.WORLD, but it seems # to cause random NCCL hangs in some cases get_global_group._global_group = dist.new_group() return get_global_group._global_group else: return None def get_global_group(): if use_xla(): return new_groups([list(range(get_global_world_size()))]) elif torch.distributed.is_initialized(): if not hasattr(get_global_group, "_global_group"): # ideally we could use torch.distributed.group.WORLD, but it seems # to cause random NCCL hangs in some cases get_global_group._global_group = dist.new_group() return get_global_group._global_group else: return None def _find_my_group_index(grouped_ranks): my_rank = get_global_rank() for i, group in enumerate(grouped_ranks): if my_rank in group: return i raise RuntimeError def _find_my_group(grouped_ranks): index = _find_my_group_index(grouped_ranks) return grouped_ranks[index] def get_global_group(): if use_xla(): return new_groups([list(range(get_global_world_size()))]) elif torch.distributed.is_initialized(): if not hasattr(get_global_group, "_global_group"): # ideally we could use torch.distributed.group.WORLD, but it seems # to cause random NCCL hangs in some cases get_global_group._global_group = dist.new_group() return get_global_group._global_group else: return None def get_world_size(group): if use_xla(): assert group[0] == "tpu" my_group = _find_my_group(group[1]) return len(my_group) elif torch.distributed.is_initialized(): return dist.get_world_size(group=group) else: return 1 def get_rank(group): if use_xla(): assert group[0] == "tpu" my_group = _find_my_group(group[1]) return my_group.index(get_global_rank()) else: return dist.get_rank(group=group) def mpu_get_data_parallel_group(): """Get the data parallel group the caller rank belongs to.""" assert _DATA_PARALLEL_GROUP is not None, \ 'data parallel group is not initialized' return _DATA_PARALLEL_GROUP def get_data_parallel_group(): """Get the data parallel group the caller rank belongs to.""" global _USE_MEGATRON if _USE_MEGATRON: return mpu_get_data_parallel_group() else: return get_global_group() def get_data_parallel_rank(): """Return my rank for the data parallel group.""" return get_rank(get_data_parallel_group()) def get_data_parallel_world_size(): """Return world size for the data parallel group.""" return get_world_size(get_data_parallel_group()) class BaseSublayer(nn.Module): def __init__(self, args): super().__init__() self.activation_fn = get_activation_fn( activation=getattr(args, "activation_fn", "relu") or "relu" ) self.norm = LayerNorm(args.decoder_embed_dim, export=False) self.ff1 = torch.nn.Linear(args.decoder_embed_dim, args.decoder_ffn_embed_dim) self.ff2 = torch.nn.Linear(args.decoder_ffn_embed_dim, args.decoder_embed_dim) self.ff2.weight.data.zero_() def forward(self, xs): return xs + self.ff2(self.activation_fn(self.ff1(self.norm(xs)))) class BaseLayer(nn.Module): def __init__(self, args): super().__init__() self.num_workers = get_data_parallel_world_size() expert_centroids = torch.empty(self.num_workers, args.decoder_embed_dim) torch.nn.init.orthogonal_(expert_centroids, gain=0.1) self.register_parameter( "expert_centroids", torch.nn.Parameter(expert_centroids) ) self.expert_network = nn.Sequential( *([BaseSublayer(args) for _ in range(args.base_sublayers)]) ) self.expert_id = get_data_parallel_rank() self.shuffle = args.base_shuffle self.cpp = self.load_assignment() # Add a special attribute to the expert parameters, so we know not to sync their gradients for param in self.expert_network.parameters(): param.expert = True def forward(self, input_features, *args, **kwargs): features = input_features.reshape(-1, input_features.size(-1)) is_training = input_features.requires_grad if self.shuffle and is_training: # Send each token to a random worker, to break correlations within the batch shuffle_sort = torch.randperm(features.size(0), device=features.device) features = All2All.apply(features[shuffle_sort]) with torch.no_grad(): # Compute similarity of each token to each expert, for routing token_expert_affinities = features.matmul( self.expert_centroids.transpose(0, 1) ) # Compute which token goes to which expert sort_by_expert, input_splits, output_splits = ( self.balanced_assignment(token_expert_affinities) if is_training else self.greedy_assignment(token_expert_affinities) ) # Swap these tokens for the right ones for our expert routed_features = All2All.apply( features[sort_by_expert], output_splits, input_splits ) if routed_features.size(0) > 0: # Mix in the expert network based on how appropriate it is for these tokens alpha = torch.sigmoid( routed_features.mv(self.expert_centroids[self.expert_id]) ).unsqueeze(1) routed_features = ( alpha * self.expert_network(routed_features) + (1 - alpha) * routed_features ) # Return to original worker and ordering result = All2All.apply(routed_features, input_splits, output_splits)[ self.inverse_sort(sort_by_expert) ] if self.shuffle and is_training: # Undo shuffling result = All2All.apply(result)[self.inverse_sort(shuffle_sort)] # Return additional Nones for compatibility with TransformerDecoderLayer return result.view(input_features.size()), None, None def inverse_sort(self, order): # Creates an index that undoes a sort: xs==xs[order][inverse_sort(order)] return torch.empty_like(order).scatter_( 0, order, torch.arange(0, order.size(0), device=order.device) ) def balanced_assignment(self, scores): ok = scores.isfinite() if not ok.all(): # NaNs here can break the assignment algorithm scores[~ok] = scores[ok].min() return self.cpp.balanced_assignment(scores), None, None # Assigns each token to the top k experts def greedy_assignment(self, scores, k=1): token_to_workers = torch.topk(scores, dim=1, k=k, largest=True).indices.view(-1) token_to_workers, sort_ordering = torch.sort(token_to_workers) worker2token = sort_ordering // k # Find how many tokens we're sending to each other worker (being careful for sending 0 tokens to some workers) output_splits = torch.zeros( (self.num_workers,), dtype=torch.long, device=scores.device ) workers, counts = torch.unique_consecutive(token_to_workers, return_counts=True) output_splits[workers] = counts # Tell other workers how many tokens to expect from us input_splits = All2All.apply(output_splits) return worker2token, input_splits.tolist(), output_splits.tolist() def load_assignment(self): try: from fairseq import libbase return libbase except ImportError as e: sys.stderr.write( "ERROR: missing libbase. run `python setup.py build_ext --inplace`\n" ) raise e class TransformerDecoderBase(nn.Module): """ Transformer decoder implemented using PyTorch's nn.Module. Args: vocab_size (int): Size of the vocabulary. embed_dim (int): Dimension of the embeddings. num_layers (int): Number of Transformer decoder layers. num_heads (int): Number of attention heads. ff_dim (int): Dimension of feed-forward layers. dropout (float): Dropout probability. max_target_positions (int): Maximum target sequence length. padding_idx (int): Index for the padding token. share_input_output_embed (bool): Whether to share input/output embeddings. """ def __init__( self, cfg, dictionary, embed_tokens, no_encoder_attn=False, output_projection=None, ): super().__init__() self.register_buffer("version", torch.Tensor([3])) self._future_mask = torch.empty(0) ################ self.dropout_module = FairseqDropout( cfg.dropout, module_name="TransformerDecoder") self.decoder_layerdrop = cfg.decoder.layerdrop self.share_input_output_embed = cfg.share_decoder_input_output_embed input_embed_dim = embed_tokens.embedding_dim embed_dim = cfg.decoder.embed_dim self.embed_dim = embed_dim self.output_embed_dim = cfg.decoder.output_dim self.padding_idx = embed_tokens.padding_idx self.max_target_positions = cfg.max_target_positions self.embed_tokens = embed_tokens self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt( embed_dim) if cfg.quant_noise.pq > 0: self.quant_noise = quant_noise( nn.Linear(embed_dim, embed_dim, bias=False), cfg.quant_noise.pq, cfg.quant_noise.pq_block_size, ) else: self.quant_noise = None self.project_in_dim = ( nn.Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None ) self.embed_positions = ( SinusoidalPositionalEmbedding( embed_dim, self.padding_idx, cfg.max_target_positions + self.padding_idx + 1 ) if not cfg.no_token_positional_embeddings else None ) if cfg.layernorm_embedding: self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export) else: self.layernorm_embedding = None self.cross_self_attention = cfg.cross_self_attention if self.decoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.decoder_layerdrop) else: self.layers = nn.ModuleList([]) self.layers.extend( [ self.build_decoder_layer(cfg, no_encoder_attn) for _ in range(cfg.decoder.layers) ] ) self.num_layers = len(self.layers) if cfg.decoder.normalize_before and not cfg.no_decoder_final_norm: self.layer_norm = LayerNorm(embed_dim, export=cfg.export) else: self.layer_norm = None self.project_out_dim = ( nn.Linear(embed_dim, self.output_embed_dim, bias=False) if embed_dim != self.output_embed_dim and not cfg.tie_adaptive_weights else None ) self.adaptive_softmax = None self.output_projection = output_projection if self.output_projection is None: self.build_output_projection(cfg, dictionary, embed_tokens) ################ def build_output_projection(self, cfg, dictionary, embed_tokens): if self.share_input_output_embed: self.output_projection = nn.Linear( self.embed_tokens.weight.shape[1], self.embed_tokens.weight.shape[0], bias=False, ) self.output_projection.weight = self.embed_tokens.weight else: self.output_projection = nn.Linear( self.output_embed_dim, len(dictionary), bias=False ) nn.init.normal_( self.output_projection.weight, mean=0, std=self.output_embed_dim**-0.5 ) num_base_layers = cfg.base_layers for i in range(num_base_layers): self.layers.insert( ((i + 1) * cfg.decoder.layers) // (num_base_layers + 1), BaseLayer(cfg), ) def build_decoder_layer(self, cfg, no_encoder_attn=False): layer = TransformerDecoderLayerBase(cfg, no_encoder_attn) checkpoint = cfg.checkpoint_activations if checkpoint: offload_to_cpu = cfg.offload_activations # layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) # if we are checkpointing, enforce that FSDP always wraps the # checkpointed layer, regardless of layer size min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0 # layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) return layer def forward( self, prev_output_tokens: Tensor, encoder_out: Optional[Tensor] = None, src_padding_mask: Optional[Tensor] = None, src_lengths: Optional[Any] = None, return_all_hiddens: bool = False, features_only: bool = False, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, ): """ Args: prev_output_tokens (Tensor): Previous output tokens of shape (batch, tgt_len). encoder_out (Tensor, optional): Encoder outputs (batch, src_len, embed_dim). src_padding_mask (Tensor, optional): Padding mask for the encoder inputs. Returns: Tensor: Decoder output of shape (batch, tgt_len, vocab_size). """ bs, slen = prev_output_tokens.size() if alignment_layer is None: alignment_layer = self.num_layers - 1 enc: Optional[Tensor] = None padding_mask: Optional[Tensor] = None if encoder_out is not None and len(encoder_out["encoder_out"]) > 0: enc = encoder_out["encoder_out"][0] if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0: padding_mask = encoder_out["encoder_padding_mask"][0] # embed positions positions = None if self.embed_positions is not None: positions = self.embed_positions( prev_output_tokens, incremental_state=incremental_state ) if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] if positions is not None: positions = positions[:, -1:] # Prevent torchscript exporting issue for dynamic quant embedding prev_output_tokens = prev_output_tokens.contiguous() # embed tokens and positions x = self.embed_scale * self.embed_tokens(prev_output_tokens) if self.quant_noise is not None: x = self.quant_noise(x) if self.project_in_dim is not None: x = self.project_in_dim(x) if positions is not None: x += positions if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = self.dropout_module(x) # B x T x C -> T x B x C x = x.transpose(0, 1) self_attn_padding_mask: Optional[Tensor] = None if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) # Embed tokens and positions # positions = torch.arange(prev_output_tokens.size(1), device=prev_output_tokens.device).unsqueeze(0) # x = self.embed_tokens(prev_output_tokens) + self.embed_positions(positions) # x = self.dropout(x) # decoder layers attn: Optional[Tensor] = None inner_states: List[Optional[Tensor]] = [x] for idx, layer in enumerate(self.layers): if incremental_state is None and not full_context_alignment: self_attn_mask = self.buffered_future_mask(x) else: self_attn_mask = None x, layer_attn, _ = layer( x, enc, padding_mask, incremental_state, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, need_attn=bool((idx == alignment_layer)), need_head_weights=bool((idx == alignment_layer)), ) inner_states.append(x) if layer_attn is not None and idx == alignment_layer: attn = layer_attn.float().to(x) if attn is not None: if alignment_heads is not None: attn = attn[:alignment_heads] # average probabilities over heads attn = attn.mean(dim=0) if self.layer_norm is not None: x = self.layer_norm(x) # T x B x C -> B x T x C x = x.transpose(0, 1) if self.project_out_dim is not None: x = self.project_out_dim(x) if not features_only: x = self.output_layer(x) return x, {"attn": [attn], "inner_states": inner_states} def output_layer(self, features): """Project features to the vocabulary size.""" if self.adaptive_softmax is None: # project back to size of vocabulary return self.output_projection(features) else: return features def max_positions(self): """Maximum output length supported by the decoder.""" if self.embed_positions is None: return self.max_target_positions return min(self.max_target_positions, self.embed_positions.max_positions) def fill_with_neg_inf(self, t): """FP16-compatible function that fills a tensor with -inf.""" return t.float().fill_(float("-inf")).type_as(t) def buffered_future_mask(self, tensor): dim = tensor.size(0) # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround. if ( self._future_mask.size(0) == 0 or (not self._future_mask.device == tensor.device) or self._future_mask.size(0) < dim ): self._future_mask = torch.triu( self.fill_with_neg_inf(torch.zeros([dim, dim])), 1 ) self._future_mask = self._future_mask.to(tensor) return self._future_mask[:dim, :dim] class FairseqIncrementalState(object): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.init_incremental_state() def init_incremental_state(self): self._incremental_state_id = str(uuid.uuid4()) def _get_full_incremental_state_key(self, key: str) -> str: return "{}.{}".format(self._incremental_state_id, key) def get_incremental_state( self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], key: str, ) -> Optional[Dict[str, Optional[Tensor]]]: """Helper for getting incremental state for an nn.Module.""" full_key = self._get_full_incremental_state_key(key) if incremental_state is None or full_key not in incremental_state: return None return incremental_state[full_key] def set_incremental_state( self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], key: str, value: Dict[str, Optional[Tensor]], ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]: """Helper for setting incremental state for an nn.Module.""" if incremental_state is not None: full_key = self._get_full_incremental_state_key(key) incremental_state[full_key] = value return incremental_state def with_incremental_state(cls): cls.__bases__ = (FairseqIncrementalState,) + tuple( b for b in cls.__bases__ if b != FairseqIncrementalState ) return cls def eval_str_dict(x, type=dict): if x is None: return None if isinstance(x, str): x = eval(x) return x def softmax(x, dim: int, onnx_trace: bool = False): if onnx_trace: return F.softmax(x.float(), dim=dim) else: return F.softmax(x, dim=dim, dtype=torch.float32) # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import math from typing import Dict, List, Optional, Tuple import torch import torch.nn.functional as F from torch import Tensor, nn from torch.nn import Parameter try: from xformers.components.attention import build_attention from xformers.components.attention.utils import maybe_merge_masks _xformers_available = True except ImportError: _xformers_available = False # TODO: move this into xformers? # TODO: uint8 input type should just output a bool def _mask_for_xformers(mask: Tensor, to_dtype: Optional[torch.dtype] = None): """ call to pytorch multihead accepts three mask types: - ByteTensor where non-zero means to mask - FloatTensor which is an additive mask - BoolTensor where True means to mask xFormers currently accepts boolean and additive maks. For boolean masks the values have opposite meaning. For a BoolTensor True mean to keep the value. """ float_types = [torch.float, torch.float16] # If an input mask is a float it is an additive mask. Otherwise it is either uint8 or bool. additive = mask.dtype in float_types # If to_dype is not specified, keep same dtype as mask. to_dtype = mask.dtype if to_dtype is None else to_dtype to_additive = to_dtype in float_types if additive: if to_additive: return mask.to(to_dtype) mask = mask < 0 if to_additive: # return additive mask new_mask = torch.zeros_like(mask, dtype=to_dtype) new_mask = new_mask.masked_fill_(mask, -float("inf")) return new_mask # In xFormers True is value to keep rather than value to mask mask = ~mask.to(torch.bool) mask = mask.to(to_dtype) return mask def softmax(x, dim: int, onnx_trace: bool = False): if onnx_trace: return F.softmax(x.float(), dim=dim) else: return F.softmax(x, dim=dim, dtype=torch.float32) def eval_str_dict(x, type=dict): if x is None: return None if isinstance(x, str): x = eval(x) return x @with_incremental_state class MultiheadAttention(nn.Module): """Multi-headed attention. See "Attention Is All You Need" for more details. """ def __init__( self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False, q_noise=0.0, qn_block_size=8, # TODO: pass in config rather than string. # config defined in xformers.components.attention.AttentionConfig xformers_att_config: Optional[str] = None, xformers_blocksparse_layout: Optional[ torch.Tensor ] = None, # This should be part of the config xformers_blocksparse_blocksize: Optional[ int ] = 16, # This should be part of the config ): super().__init__() xformers_att_config = eval_str_dict(xformers_att_config) self.use_xformers = xformers_att_config is not None if self.use_xformers and not _xformers_available: raise ImportError("\n\n Please install xFormers.") self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads self.dropout_module = FairseqDropout( dropout, module_name=self.__class__.__name__ ) self.head_dim = embed_dim // num_heads assert ( self.head_dim * num_heads == self.embed_dim ), "embed_dim must be divisible by num_heads" self.scaling = self.head_dim**-0.5 self.self_attention = self_attention self.encoder_decoder_attention = encoder_decoder_attention assert not self.self_attention or self.qkv_same_dim, ( "Self-attention requires query, key and " "value to be of the same size" ) self.k_proj = quant_noise( nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size ) self.v_proj = quant_noise( nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size ) self.q_proj = quant_noise( nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size ) self.out_proj = quant_noise( nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size ) if add_bias_kv: self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) else: self.bias_k = self.bias_v = None self.add_zero_attn = add_zero_attn self.beam_size = 1 self.reset_parameters() if self.use_xformers: xformers_att_config["dropout"] = xformers_att_config.get("dropout", dropout) xformers_att_config["num_heads"] = xformers_att_config.get( "num_heads", num_heads ) if xformers_blocksparse_layout is not None: # Could be part of a single config passed only once xformers_att_config["block_size"] = xformers_blocksparse_blocksize xformers_att_config["layout"] = xformers_blocksparse_layout xformers_att_config["name"] = "blocksparse" self.attention = build_attention(xformers_att_config) self.onnx_trace = False self.skip_embed_dim_check = False def prepare_for_onnx_export_(self): self.onnx_trace = True def reset_parameters(self): if self.qkv_same_dim: # Empirically observed the convergence to be much better with # the scaled initialization nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) else: nn.init.xavier_uniform_(self.k_proj.weight) nn.init.xavier_uniform_(self.v_proj.weight) nn.init.xavier_uniform_(self.q_proj.weight) nn.init.xavier_uniform_(self.out_proj.weight) if self.out_proj.bias is not None: nn.init.constant_(self.out_proj.bias, 0.0) if self.bias_k is not None: nn.init.xavier_normal_(self.bias_k) if self.bias_v is not None: nn.init.xavier_normal_(self.bias_v) def _get_reserve_head_index(self, num_heads_to_keep: int): k_proj_heads_norm = [] q_proj_heads_norm = [] v_proj_heads_norm = [] for i in range(self.num_heads): start_idx = i * self.head_dim end_idx = (i + 1) * self.head_dim k_proj_heads_norm.append( torch.sum( torch.abs( self.k_proj.weight[ start_idx:end_idx, ] ) ).tolist() + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist() ) q_proj_heads_norm.append( torch.sum( torch.abs( self.q_proj.weight[ start_idx:end_idx, ] ) ).tolist() + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist() ) v_proj_heads_norm.append( torch.sum( torch.abs( self.v_proj.weight[ start_idx:end_idx, ] ) ).tolist() + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist() ) heads_norm = [] for i in range(self.num_heads): heads_norm.append( k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i] ) sorted_head_index = sorted( range(self.num_heads), key=lambda k: heads_norm[k], reverse=True ) reserve_head_index = [] for i in range(num_heads_to_keep): start = sorted_head_index[i] * self.head_dim end = (sorted_head_index[i] + 1) * self.head_dim reserve_head_index.append((start, end)) return reserve_head_index def _adaptive_prune_heads(self, reserve_head_index: List[Tuple[int, int]]): new_q_weight = [] new_q_bias = [] new_k_weight = [] new_k_bias = [] new_v_weight = [] new_v_bias = [] new_out_proj_weight = [] for ele in reserve_head_index: start_idx, end_idx = ele new_q_weight.append( self.q_proj.weight[ start_idx:end_idx, ] ) new_q_bias.append(self.q_proj.bias[start_idx:end_idx]) new_k_weight.append( self.k_proj.weight[ start_idx:end_idx, ] ) new_k_bias.append(self.k_proj.bias[start_idx:end_idx]) new_v_weight.append( self.v_proj.weight[ start_idx:end_idx, ] ) new_v_bias.append(self.v_proj.bias[start_idx:end_idx]) new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx]) new_q_weight = torch.cat(new_q_weight).detach() new_k_weight = torch.cat(new_k_weight).detach() new_v_weight = torch.cat(new_v_weight).detach() new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach() new_q_weight.requires_grad = True new_k_weight.requires_grad = True new_v_weight.requires_grad = True new_out_proj_weight.requires_grad = True new_q_bias = torch.cat(new_q_bias).detach() new_q_bias.requires_grad = True new_k_bias = torch.cat(new_k_bias).detach() new_k_bias.requires_grad = True new_v_bias = torch.cat(new_v_bias).detach() new_v_bias.requires_grad = True self.q_proj.weight = torch.nn.Parameter(new_q_weight) self.q_proj.bias = torch.nn.Parameter(new_q_bias) self.k_proj.weight = torch.nn.Parameter(new_k_weight) self.k_proj.bias = torch.nn.Parameter(new_k_bias) self.v_proj.weight = torch.nn.Parameter(new_v_weight) self.v_proj.bias = torch.nn.Parameter(new_v_bias) self.out_proj.weight = torch.nn.Parameter(new_out_proj_weight) self.num_heads = len(reserve_head_index) self.embed_dim = self.head_dim * self.num_heads self.q_proj.out_features = self.embed_dim self.k_proj.out_features = self.embed_dim self.v_proj.out_features = self.embed_dim def _set_skip_embed_dim_check(self): self.skip_embed_dim_check = True def _pad_masks( self, key_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor], ) -> Tuple[Optional[Tensor], Optional[Tensor]]: if attn_mask is not None: shape = attn_mask.size()[:-1] + torch.Size([1]) attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1) if key_padding_mask is not None: shape = key_padding_mask.size()[:-1] + torch.Size([1]) key_padding_mask = torch.cat( [ key_padding_mask, key_padding_mask.new_zeros(shape), ], dim=-1, ) return key_padding_mask, attn_mask def _add_bias( self, k: Tensor, v: Tensor, key_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor], bsz: int, ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: assert self.bias_k is not None assert self.bias_v is not None k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) key_padding_mask, attn_mask = self._pad_masks( key_padding_mask=key_padding_mask, attn_mask=attn_mask ) return k, v, key_padding_mask, attn_mask def _append_zero_attn( self, k: Tensor, v: Tensor, key_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor], ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:] k = torch.cat( [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2 ) v = torch.cat( [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2 ) key_padding_mask, attn_mask = self._pad_masks( key_padding_mask=key_padding_mask, attn_mask=attn_mask ) return k, v, key_padding_mask, attn_mask def _xformers_attn_forward( self, query, key: Optional[Tensor], value: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor]]: tgt_len, bsz, embed_dim = query.size() if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz assert key_padding_mask.size(1) == tgt_len if self.self_attention: key = query value = query elif self.encoder_decoder_attention: value = key q = self.q_proj(query) k = self.k_proj(key) v = self.v_proj(value) if self.bias_k is not None: assert self.bias_v is not None k, v, attn_mask, key_padding_mask = self._add_bias( k, v, attn_mask, key_padding_mask, bsz ) def fold_heads(x): return ( x.contiguous() .view(-1, bsz * self.num_heads, self.head_dim) .transpose(0, 1) ) def split_heads(x): return ( x.contiguous() .view(-1, bsz, self.num_heads, self.head_dim) .transpose(0, 1) .transpose(1, 2) ) massage = split_heads if self.attention.requires_head_dimension else fold_heads q = massage(q) if k is not None: k = massage(k) if v is not None: v = massage(v) if self.add_zero_attn: k, v, key_padding_mask, attn_mask = self._append_zero_attn( k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask ) kwargs = {} if attn_mask is not None and self.attention.supports_attention_mask: attn_mask = _mask_for_xformers(attn_mask, to_dtype=q.dtype) kwargs["att_mask"] = attn_mask if key_padding_mask is not None: key_padding_mask = _mask_for_xformers(key_padding_mask, to_dtype=torch.bool) if not self.attention.requires_separate_masks: attn_mask = maybe_merge_masks( attn_mask, key_padding_mask, batch_size=bsz, src_len=k.size(-2), tgt_len=q.size(-2), num_heads=self.num_heads, ) key_padding_mask = None kwargs["att_mask"] = attn_mask if self.attention.supports_key_padding_mask: kwargs["key_padding_mask"] = key_padding_mask y = self.attention(q, k, v, **kwargs) y = ( y.view(bsz, self.num_heads, tgt_len, self.head_dim) .transpose(1, 2) .flatten(start_dim=2, end_dim=3) .transpose(0, 1) ) assert list(y.size()) == [tgt_len, bsz, embed_dim] # Dropout not needed because already applied in attention. # It is applied to the attention weights before matmul with v. y = self.out_proj(y) # TODO: support returning attention weights if needed. return y, None def forward( self, query, key: Optional[Tensor], value: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, need_weights: bool = True, static_kv: bool = False, attn_mask: Optional[Tensor] = None, before_softmax: bool = False, need_head_weights: bool = False, ) -> Tuple[Tensor, Optional[Tensor]]: """Input shape: Time x Batch x Channel Args: key_padding_mask (ByteTensor, optional): mask to exclude keys that are pads, of shape `(batch, src_len)`, where padding elements are indicated by 1s. need_weights (bool, optional): return the attention weights, averaged over heads (default: False). attn_mask (ByteTensor, optional): typically used to implement causal attention, where the mask prevents the attention from looking forward in time (default: None). before_softmax (bool, optional): return the raw attention weights and values before the attention softmax. need_head_weights (bool, optional): return the attention weights for each head. Implies *need_weights*. Default: return the average attention weights over all heads. """ if need_head_weights: need_weights = True is_tpu = query.device.type == "xla" tgt_len, bsz, embed_dim = query.size() src_len = tgt_len if not self.skip_embed_dim_check: assert ( embed_dim == self.embed_dim ), f"query dim {embed_dim} != {self.embed_dim}" assert list(query.size()) == [tgt_len, bsz, embed_dim] if key is not None: src_len, key_bsz, _ = key.size() if not torch.jit.is_scripting(): assert value is not None assert src_len, key_bsz == value.shape[:2] if ( not self.onnx_trace and not is_tpu # don't use PyTorch version on TPUs and incremental_state is None and not static_kv # A workaround for quantization to work. Otherwise JIT compilation # treats bias in linear module as method. and not torch.jit.is_scripting() # The Multihead attention implemented in pytorch forces strong dimension check # for input embedding dimention and K,Q,V projection dimension. # Since pruning will break the dimension check and it is not easy to modify the pytorch API, # it is preferred to bypass the pytorch MHA when we need to skip embed_dim_check and not self.skip_embed_dim_check ): assert key is not None and value is not None if self.use_xformers: return self._xformers_attn_forward( query, key, value, key_padding_mask, need_weights, attn_mask ) else: return F.multi_head_attention_forward( query, key, value, self.embed_dim, self.num_heads, torch.empty([0]), torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), self.bias_k, self.bias_v, self.add_zero_attn, self.dropout_module.p, self.out_proj.weight, self.out_proj.bias, self.training or self.dropout_module.apply_during_inference, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight=True, q_proj_weight=self.q_proj.weight, k_proj_weight=self.k_proj.weight, v_proj_weight=self.v_proj.weight, ) if incremental_state is not None: saved_state = self._get_input_buffer(incremental_state) if saved_state is not None and "prev_key" in saved_state: # previous time steps are cached - no need to recompute # key and value if they are static if static_kv: assert self.encoder_decoder_attention and not self.self_attention key = value = None else: saved_state = None if self.self_attention: q = self.q_proj(query) k = self.k_proj(query) v = self.v_proj(query) elif self.encoder_decoder_attention: # encoder-decoder attention q = self.q_proj(query) if key is None: assert value is None k = v = None else: if self.beam_size > 1 and bsz == key.size(1): # key is [T, bsz*beam_size, C], reduce to [T, bsz, C] key = key.view(key.size(0), -1, self.beam_size, key.size(2))[ :, :, 0, : ] if key_padding_mask is not None: key_padding_mask = key_padding_mask.view( -1, self.beam_size, key_padding_mask.size(1) )[:, 0, :] k = self.k_proj(key) v = self.v_proj(key) else: assert key is not None and value is not None q = self.q_proj(query) k = self.k_proj(key) v = self.v_proj(value) q *= self.scaling if self.bias_k is not None: assert self.bias_v is not None k, v, attn_mask, key_padding_mask = self._add_bias( k, v, attn_mask, key_padding_mask, bsz ) q = ( q.contiguous() .view(tgt_len, bsz * self.num_heads, self.head_dim) .transpose(0, 1) ) kv_bsz = bsz # need default value for scripting if k is not None: kv_bsz = k.size(1) k = ( k.contiguous() .view(-1, kv_bsz * self.num_heads, self.head_dim) .transpose(0, 1) ) if v is not None: v = ( v.contiguous() .view(-1, kv_bsz * self.num_heads, self.head_dim) .transpose(0, 1) ) if saved_state is not None: # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) if "prev_key" in saved_state: _prev_key = saved_state["prev_key"] assert _prev_key is not None kv_bsz = _prev_key.size(0) prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim) if static_kv: k = prev_key else: assert k is not None k = torch.cat([prev_key, k], dim=1) src_len = k.size(1) if "prev_value" in saved_state: _prev_value = saved_state["prev_value"] assert _prev_value is not None assert kv_bsz == _prev_value.size(0) prev_value = _prev_value.view( kv_bsz * self.num_heads, -1, self.head_dim ) if static_kv: v = prev_value else: assert v is not None v = torch.cat([prev_value, v], dim=1) prev_key_padding_mask: Optional[Tensor] = None if "prev_key_padding_mask" in saved_state: prev_key_padding_mask = saved_state["prev_key_padding_mask"] assert k is not None and v is not None key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( key_padding_mask=key_padding_mask, prev_key_padding_mask=prev_key_padding_mask, batch_size=kv_bsz, src_len=k.size(1), static_kv=static_kv, ) saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim) saved_state["prev_value"] = v.view( kv_bsz, self.num_heads, -1, self.head_dim ) saved_state["prev_key_padding_mask"] = key_padding_mask # In this branch incremental_state is never None assert incremental_state is not None incremental_state = self._set_input_buffer(incremental_state, saved_state) assert k is not None assert k.size(1) == src_len # This is part of a workaround to get around fork/join parallelism # not supporting Optional types. if key_padding_mask is not None and key_padding_mask.dim() == 0: key_padding_mask = None if key_padding_mask is not None: assert key_padding_mask.size(0) == kv_bsz assert key_padding_mask.size(1) == src_len if self.add_zero_attn: assert v is not None src_len += 1 k, v, key_padding_mask, attn_mask = self._append_zero_attn( k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask ) if self.encoder_decoder_attention and bsz != kv_bsz: attn_weights = torch.einsum( "bxhtd,bhsd->bxhts", q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]), k.view((kv_bsz, self.num_heads) + k.size()[1:]), ) attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:]) else: attn_weights = torch.bmm(q, k.transpose(1, 2)) attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] if attn_mask is not None: attn_mask = attn_mask.unsqueeze(0) if self.onnx_trace: attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) attn_weights += attn_mask if key_padding_mask is not None: # don't attend to padding symbols attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) if not is_tpu: attn_weights = attn_weights.view( kv_bsz, -1, self.num_heads, tgt_len, src_len ) attn_weights = attn_weights.masked_fill( key_padding_mask.unsqueeze(1) .unsqueeze(2) .unsqueeze(3) .to(torch.bool), float("-inf"), ) else: attn_weights = attn_weights.transpose(0, 2) attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) attn_weights = attn_weights.transpose(0, 2) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) if before_softmax: return attn_weights, v attn_weights_float = softmax( attn_weights, dim=-1, onnx_trace=self.onnx_trace ) attn_weights = attn_weights_float.type_as(attn_weights) attn_probs = self.dropout_module(attn_weights) assert v is not None if self.encoder_decoder_attention and bsz != kv_bsz: attn = torch.einsum( "bxhts,bhsd->bxhtd", attn_probs.view( ( kv_bsz, -1, self.num_heads, ) + attn_probs.size()[1:] ), v.view( ( kv_bsz, self.num_heads, ) + v.size()[1:] ), ) attn = attn.reshape((-1,) + attn.size()[-2:]) else: attn = torch.bmm(attn_probs, v) assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] if self.onnx_trace and attn.size(1) == 1: # when ONNX tracing a single decoder step (sequence length == 1) # the transpose is a no-op copy before view, thus unnecessary attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim) else: attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim) attn = self.out_proj(attn) attn_weights: Optional[Tensor] = None if need_weights: attn_weights = attn_weights_float.view( bsz, self.num_heads, tgt_len, src_len ).transpose(1, 0) if not need_head_weights: # average attention weights over heads attn_weights = attn_weights.mean(dim=0) return attn, attn_weights @staticmethod def _append_prev_key_padding_mask( key_padding_mask: Optional[Tensor], prev_key_padding_mask: Optional[Tensor], batch_size: int, src_len: int, static_kv: bool, ) -> Optional[Tensor]: # saved key padding masks have shape (bsz, seq_len) if prev_key_padding_mask is not None and static_kv: new_key_padding_mask = prev_key_padding_mask elif prev_key_padding_mask is not None and key_padding_mask is not None: new_key_padding_mask = torch.cat( [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 ) # During incremental decoding, as the padding token enters and # leaves the frame, there will be a time when prev or current # is None elif prev_key_padding_mask is not None: if src_len > prev_key_padding_mask.size(1): filler = torch.zeros( (batch_size, src_len - prev_key_padding_mask.size(1)), device=prev_key_padding_mask.device, ) new_key_padding_mask = torch.cat( [prev_key_padding_mask.float(), filler.float()], dim=1 ) else: new_key_padding_mask = prev_key_padding_mask.float() elif key_padding_mask is not None: if src_len > key_padding_mask.size(1): filler = torch.zeros( (batch_size, src_len - key_padding_mask.size(1)), device=key_padding_mask.device, ) new_key_padding_mask = torch.cat( [filler.float(), key_padding_mask.float()], dim=1 ) else: new_key_padding_mask = key_padding_mask.float() else: new_key_padding_mask = prev_key_padding_mask return new_key_padding_mask @torch.jit.export def reorder_incremental_state( self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor, ): """Reorder buffered internal state (for incremental generation).""" input_buffer = self._get_input_buffer(incremental_state) if input_buffer is not None: for k in input_buffer.keys(): input_buffer_k = input_buffer[k] if input_buffer_k is not None: if self.encoder_decoder_attention: if input_buffer_k.size(0) * self.beam_size == new_order.size(0): return incremental_state elif self.beam_size > 1: input_buffer[k] = input_buffer_k.index_select( 0, new_order.reshape(-1, self.beam_size)[:, 0] // self.beam_size, ) else: input_buffer[k] = input_buffer_k.index_select(0, new_order) else: input_buffer[k] = input_buffer_k.index_select(0, new_order) incremental_state = self._set_input_buffer(incremental_state, input_buffer) return incremental_state def set_beam_size(self, beam_size): """Used for effiecient beamable enc-dec attention""" self.beam_size = beam_size def _get_input_buffer( self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] ) -> Dict[str, Optional[Tensor]]: result = self.get_incremental_state(incremental_state, "attn_state") if result is not None: return result else: empty_result: Dict[str, Optional[Tensor]] = {} return empty_result def _set_input_buffer( self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], buffer: Dict[str, Optional[Tensor]], ): return self.set_incremental_state(incremental_state, "attn_state", buffer) def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): return attn_weights def upgrade_state_dict_named(self, state_dict, name): prefix = name + "." if name != "" else "" items_to_add = {} keys_to_remove = [] for k in state_dict.keys(): if k.endswith(prefix + "in_proj_weight"): # in_proj_weight used to be q + k + v with same dimensions dim = int(state_dict[k].shape[0] / 3) items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim] items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim] items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :] keys_to_remove.append(k) k_bias = prefix + "in_proj_bias" if k_bias in state_dict.keys(): dim = int(state_dict[k].shape[0] / 3) items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim] items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][ dim : 2 * dim ] items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :] keys_to_remove.append(prefix + "in_proj_bias") for k in keys_to_remove: del state_dict[k] for key, value in items_to_add.items(): state_dict[key] = value @dataclass class QuantNoiseConfig: _name: str = "transformer" pq: float = 0.0 pq_block_size: int = 8 scalar: float = 0.0 def to_dict(self): return asdict(self) @classmethod def from_dict(cls, data): return cls(**data) @dataclass class EncDecBaseConfig: _name: str = "transformer" embed_path: Optional[str] = None embed_dim: int = 768 ffn_embed_dim: int = 3072 layers: int = 12 attention_heads: int = 12 normalize_before: bool = False learned_pos: bool = False layerdrop: float = 0.0 layers_to_keep: Optional[list[int]] = None xformers_att_config: Optional[dict] = None quant_noise: QuantNoiseConfig = field(default_factory=QuantNoiseConfig) padding_idx= 1 vocab_size = 64001 @dataclass class DecoderConfig(EncDecBaseConfig): input_dim: int = 768 output_dim: int = 768 vocab_size = 528 @dataclass class TransformerConfig: _name: str = "transformer" activation_fn: str = "relu" dropout: float = 0.1 attention_dropout: float = 0.1 activation_dropout: float = 0.0 adaptive_input: bool = False encoder: EncDecBaseConfig = field(default_factory=EncDecBaseConfig) max_source_positions: int = 1024 decoder: DecoderConfig = field(default_factory=DecoderConfig) max_target_positions: int = 1024 share_decoder_input_output_embed: bool = True share_all_embeddings: bool = False no_token_positional_embeddings: bool = False adaptive_softmax_cutoff: Optional[list[int]] = None adaptive_softmax_dropout: float = 0.0 adaptive_softmax_factor: int = 4 layernorm_embedding: bool = False tie_adaptive_weights: bool = False tie_adaptive_proj: bool = False no_scale_embedding: bool = False checkpoint_activations: bool = False offload_activations: bool = False no_cross_attention: bool = False cross_self_attention: bool = False quant_noise: QuantNoiseConfig = field(default_factory=QuantNoiseConfig) min_params_to_wrap: int = 100_000_000 char_inputs: bool = False relu_dropout: float = 0.0 base_layers: int = 0 base_sublayers: int = 1 base_shuffle: int = 1 export: bool = False no_decoder_final_norm: bool = False # Example of instantiating the config main_config = TransformerConfig() class TokenEmbedding(nn.Module): def __init__(self, vocab_size, embed_dim, padding_idx): super().__init__() self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx) self.vocab_size = vocab_size self.embedding_dim = embed_dim self.padding_idx = padding_idx def forward(self, input_tokens): return self.embedding(input_tokens) # Example Usage def initialize_embed_tokens(cfg, model='encoder'): """ Initialize the embed_tokens layer. Args: cfg: Configuration object dictionary: Vocabulary dictionary with token-to-index mapping Returns: embed_tokens: Token embedding layer """ vocab_size = cfg.encoder.vocab_size if model == 'encoder' else cfg.decoder.vocab_size # Assuming this attribute is added in the config embed_dim = cfg.encoder.embed_dim # Assuming this attribute is added in the config padding_idx = cfg.encoder.padding_idx #dictionary.pad() # Fetch the padding index from the dictionary return TokenEmbedding(vocab_size, embed_dim, padding_idx) class EncoderDecoderModel(nn.Module): """Standalone Encoder-Decoder model for Fairseq with necessary functionalities.""" def __init__(self, cfg): super().__init__() self.cfg = cfg self.encoder = TransformerEncoderBase(cfg, enc_dictionary, encoder_embedding.embedding) self.decoder = TransformerDecoderBase(cfg, dec_dictionary, decoder_embedding.embedding) self.supports_align_args = True self._is_generation_fast = False def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): """ Perform a forward pass. Args: src_tokens (LongTensor): Source tokens `(batch, src_len)` src_lengths (LongTensor): Source lengths `(batch)` prev_output_tokens (LongTensor): Previous decoder outputs `(batch, tgt_len)` Returns: Tuple: decoder output and additional info """ encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) decoder_out = self.decoder( prev_output_tokens, encoder_out=encoder_out, **kwargs ) return decoder_out def forward_decoder(self, prev_output_tokens, **kwargs): return self.decoder(prev_output_tokens, **kwargs) def output_layer(self, features, **kwargs): """Project features to the default output size (typically vocabulary size).""" return self.decoder.output_layer(features, **kwargs) def max_positions(self): """Maximum length supported by the model.""" return (self.encoder.max_positions(), self.decoder.max_positions()) def max_decoder_positions(self): """Maximum length supported by the decoder.""" return self.decoder.max_positions() encoder_embedding = initialize_embed_tokens(main_config) decoder_embedding = initialize_embed_tokens(main_config, 'decoder') enc_dictionary = [9]* main_config.encoder.vocab_size dec_dictionary = [9] * main_config.decoder.vocab_size class AfroLidForSequenceClassification(PreTrainedModel): config_class = AfroLidConfig base_model_prefix = "transformer" def __init__(self, config): super().__init__(config) self.cfg = main_config self.encoder = TransformerEncoderBase(self.cfg, enc_dictionary, encoder_embedding.embedding) self.decoder = TransformerDecoderBase(self.cfg, dec_dictionary, decoder_embedding.embedding) self.supports_align_args = True self._is_generation_fast = False def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): """ Perform a forward pass. Args: src_tokens (LongTensor): Source tokens `(batch, src_len)` src_lengths (LongTensor): Source lengths `(batch)` prev_output_tokens (LongTensor): Previous decoder outputs `(batch, tgt_len)` Returns: Tuple: decoder output and additional info """ encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) decoder_out = self.decoder( prev_output_tokens, encoder_out=encoder_out, **kwargs ) return decoder_out def forward_decoder(self, prev_output_tokens, **kwargs): return self.decoder(prev_output_tokens, **kwargs) def output_layer(self, features, **kwargs): """Project features to the default output size (typically vocabulary size).""" return self.decoder.output_layer(features, **kwargs) def max_positions(self): """Maximum length supported by the model.""" return (self.encoder.max_positions(), self.decoder.max_positions()) def max_decoder_positions(self): """Maximum length supported by the decoder.""" return self.decoder.max_positions() config = AfroLidConfig() afrolid_model = AfroLidForSequenceClassification(config) AutoConfig.register("afrolid", AfroLidConfig) AutoModel.register(AfroLidConfig, AfroLidForSequenceClassification) AutoModelForSequenceClassification.register( AfroLidConfig, AfroLidForSequenceClassification)