# AfroLid / modelling_afrolid.py
# Upload metadata: author damilojohn, "Upload 2 files", commit 0079cb1 (verified).
import math
from typing import Any, Optional
import torch
import torch.onnx.operators
from torch import nn, Tensor
import torch.nn as nn
from typing import Optional, Dict, List, Any, Tuple
import torch.nn as nn
import torch.nn.functional as F
import torch
import sys
import torch.distributed as dist
import uuid
from dataclasses import dataclass, field, asdict
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoConfig, AutoModel, AutoModelForSequenceClassification
from .configuration_afrolid import AfroLidConfig
def quant_noise(module, p, block_size):
    """Attach quantization noise (iPQ block dropping) to ``module``'s weight.

    Implements the simplest noise from "Training with Quantization Noise for
    Extreme Model Compression": on every forward pass in training mode a random
    subset of weight blocks is zeroed and the remaining weights are rescaled by
    ``1 / (1 - p)``. For block-wise quantization of convolutional weights see
    "And the Bit Goes Down: Revisiting the Quantization of Neural Networks".

    Args:
        module: an ``nn.Linear``, ``nn.Embedding`` or ``nn.Conv2d``.
        p: probability of dropping each block. ``p <= 0`` disables the noise
            and returns ``module`` unchanged without registering a hook.
        block_size: size of the blocks for subsequent iPQ quantization.

    Returns:
        The same ``module`` object, with a forward pre-hook registered when
        ``p > 0``.

    Raises:
        AssertionError: if the module type is unsupported or its weight sizes
            are not multiples of ``block_size``.

    Notes:
        - For non-1x1 convolutions the mask is drawn per
          (out_channel, in_channel) pair and covers the whole kernel.
        - The hook assigns ``mod.weight.data``, so the zeroed/rescaled values
          remain in the parameter after the forward pass.
    """
    if p <= 0:
        return module

    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    is_conv = module.weight.ndim == 4
    if not is_conv:
        # 2D weight: blocks run along the input-feature dimension
        assert (
            module.weight.size(1) % block_size == 0
        ), "Input features must be a multiple of block sizes"
    elif module.kernel_size == (1, 1):
        assert (
            module.in_channels % block_size == 0
        ), "Input channels must be a multiple of block sizes"
    else:
        kernel_area = module.kernel_size[0] * module.kernel_size[1]
        assert (
            kernel_area % block_size == 0
        ), "Kernel size must be a multiple of block size"

    def _sample_drop_mask(mod, weight):
        # One Bernoulli(p) draw per block, expanded to the full weight shape.
        if not is_conv:
            n_in, n_out = weight.size(1), weight.size(0)
            drop = torch.zeros(n_in // block_size * n_out, device=weight.device)
            drop.bernoulli_(p)
            return drop.repeat_interleave(block_size, -1).view(-1, n_in)
        if mod.kernel_size == (1, 1):
            n_in, n_out = mod.in_channels, mod.out_channels
            drop = torch.zeros(
                int(n_in // block_size * n_out),
                device=weight.device,
            )
            drop.bernoulli_(p)
            return drop.repeat_interleave(block_size, -1).view(-1, n_in)
        kernel_h, kernel_w = mod.kernel_size[0], mod.kernel_size[1]
        drop = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
        drop.bernoulli_(p)
        return drop.unsqueeze(2).unsqueeze(3).repeat(1, 1, kernel_h, kernel_w)

    def _forward_pre_hook(mod, input):
        if not mod.training:
            return  # evaluation: weights are used untouched
        weight = mod.weight
        # .to(torch.bool) because x.bool() is not currently supported in TorchScript
        drop_mask = _sample_drop_mask(mod, weight).to(torch.bool)
        rescale = 1 / (1 - p)
        mod.weight.data = rescale * weight.masked_fill(drop_mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    """Build a plain ``torch.nn.LayerNorm``.

    ``export`` is accepted for signature compatibility with fairseq's
    ``LayerNorm`` factory but is ignored; no fused implementation is used.
    """
    norm_cls = torch.nn.LayerNorm
    return norm_cls(normalized_shape, eps, elementwise_affine)
class LayerDropModuleList(nn.ModuleList):
    """A LayerDrop implementation based on :class:`torch.nn.ModuleList`.

    Every iteration over the list draws a fresh uniform sample per layer. In
    training mode a layer is yielded only if its sample exceeds ``p``. In
    evaluation mode every layer is always yielded.

    Usage::

        layers = LayerDropModuleList(p=0.5, modules=[layer1, layer2, layer3])
        for layer in layers:  # this might iterate over layers 1 and 3
            x = layer(x)
        for layer in layers:  # this might iterate over all layers
            x = layer(x)
        for layer in layers:  # this might not iterate over any layers
            x = layer(x)

    Args:
        p (float): probability of dropping out each layer
        modules (iterable, optional): an iterable of modules to add
    """

    def __init__(self, p, modules=None):
        super().__init__(modules)
        self.p = p

    def __iter__(self):
        # One uniform sample per layer is drawn on every iteration, also in
        # eval mode (where it is simply not consulted).
        dropout_probs = torch.empty(len(self)).uniform_()
        for i, m in enumerate(super().__iter__()):
            if not self.training or (dropout_probs[i] > self.p):
                yield m
from typing import List, Callable
from typing import Dict
import warnings
def gelu_accurate(x):
    """Tanh approximation of GELU.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    """
    if not hasattr(gelu_accurate, "_a"):
        # cache sqrt(2 / pi) on the function object after the first call
        gelu_accurate._a = math.sqrt(2 / math.pi)
    inner = gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * x * (1 + torch.tanh(inner))
def deprecation_warning(message, stacklevel=3):
    """Emit ``message`` as a ``UserWarning``.

    ``DeprecationWarning`` is deliberately not used because Python ignores it
    by default.
    """
    warnings.warn(message, UserWarning, stacklevel)
def gelu(x: torch.Tensor) -> torch.Tensor:
    """Exact GELU computed in float32 and cast back to ``x``'s dtype."""
    as_float = x.float()
    activated = F.gelu(as_float)
    return activated.type_as(x)
def relu_squared(x: torch.Tensor):
    """Squared ReLU: ``max(x, 0) ** 2`` elementwise."""
    rectified = F.relu(x)
    return rectified.pow(2)
def get_activation_fn(activation: str) -> Callable:
    """Return the activation function corresponding to ``activation``.

    Supported names: ``relu``, ``relu_squared``, ``gelu``, ``gelu_fast``
    (deprecated alias of ``gelu_accurate``), ``gelu_accurate``, ``tanh``,
    ``linear`` (identity) and ``swish`` (SiLU).

    Raises:
        RuntimeError: if ``activation`` is not one of the names above.
    """
    if activation == "relu":
        return F.relu
    if activation == "relu_squared":
        return relu_squared
    if activation == "gelu":
        return gelu
    if activation == "gelu_fast":
        deprecation_warning(
            "--activation-fn=gelu_fast has been renamed to gelu_accurate"
        )
        return gelu_accurate
    if activation == "gelu_accurate":
        return gelu_accurate
    if activation == "tanh":
        return torch.tanh
    if activation == "linear":
        return lambda x: x
    if activation == "swish":
        # Return the functional form. ``torch.nn.SiLU`` is a module class, so
        # calling it on a tensor would construct a module instead of applying
        # the activation.
        return F.silu
    raise RuntimeError("--activation-fn {} not supported".format(activation))
class FairseqDropout(nn.Module):
    """Dropout that can optionally stay active at inference time.

    Args:
        p: drop probability; ``p <= 0`` makes the module an identity.
        module_name: name of the owning module (stored, not used here).
    """

    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p
        self.module_name = module_name
        # Set to True to keep dropping units while the module is in eval mode.
        self.apply_during_inference = False

    def forward(self, x, inplace: bool = False):
        active = self.training or self.apply_during_inference
        if not (self.p > 0 and active):
            return x
        return F.dropout(x, p=self.p, training=True, inplace=inplace)
class TransformerEncoderLayerBase(nn.Module):
    """Encoder layer block.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *cfg.encoder.normalize_before* to ``True``.

    The layer also keeps contiguous copies of the attention and FFN parameters
    (``in_proj_weight``, ``fc1_weight``, ...). These are filled by
    :meth:`_load_from_state_dict` and let :meth:`forward` call the fused
    ``torch._transformer_encoder_layer_fwd`` kernel ("BT" fast path) when all
    of its preconditions hold.

    Args:
        cfg: model configuration (reads ``cfg.encoder.*``, ``cfg.quant_noise.*``,
            ``cfg.dropout``, ``cfg.activation_fn``, ``cfg.export``, ...).
        return_fc (bool, optional): also return the FFN output (before dropout
            and the residual connection) from :meth:`forward`.
    """

    def __init__(self, cfg, return_fc=False):
        super().__init__()
        self.cfg = cfg
        self.return_fc = return_fc
        self.embed_dim = cfg.encoder.embed_dim
        self.quant_noise = cfg.quant_noise.pq
        self.quant_noise_block_size = cfg.quant_noise.pq_block_size
        self.self_attn = self.build_self_attention(self.embed_dim, cfg)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
        self.dropout_module = FairseqDropout(
            cfg.dropout, module_name=self.__class__.__name__
        )
        self.activation_fn = get_activation_fn(activation=cfg.activation_fn)
        activation_dropout_p = cfg.activation_dropout
        if activation_dropout_p == 0:
            # for backwards compatibility with models that use cfg.relu_dropout
            activation_dropout_p = cfg.relu_dropout or 0
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )
        self.normalize_before = cfg.encoder.normalize_before
        self.fc1 = self.build_fc1(
            self.embed_dim,
            cfg.encoder.ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            cfg.encoder.ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)

        self.num_heads = cfg.encoder.attention_heads
        # load_to_BT is set once _load_from_state_dict has filled the fused
        # copies below. ever_training latches on the first training-mode
        # forward. Both gate the fast path in forward().
        self.load_to_BT = False
        self.ever_training = False
        # For BT, we need continuous mem: fused q/k/v, output-projection and
        # FFN parameters consumed by torch._transformer_encoder_layer_fwd.
        self.in_proj_weight = torch.nn.Parameter(
            torch.zeros(
                self.self_attn.q_proj.weight.shape[0] * 3,
                self.self_attn.q_proj.weight.shape[1],
            )
        )
        self.in_proj_bias = torch.nn.Parameter(
            torch.zeros(self.self_attn.q_proj.bias.shape[0] * 3)
        )
        self.out_proj_weight = torch.nn.Parameter(
            torch.zeros(self.self_attn.out_proj.weight.shape)
        )
        self.out_proj_bias = torch.nn.Parameter(
            torch.zeros(self.self_attn.out_proj.bias.shape)
        )
        self.fc1_weight = torch.nn.Parameter(torch.zeros(self.fc1.weight.shape))
        self.fc1_bias = torch.nn.Parameter(torch.zeros(self.fc1.bias.shape))
        self.fc2_weight = torch.nn.Parameter(torch.zeros(self.fc2.weight.shape))
        self.fc2_bias = torch.nn.Parameter(torch.zeros(self.fc2.bias.shape))

        # Activation code passed to the fused kernel (1 = relu, 2 = gelu,
        # 0 = other). get_activation_fn("gelu") returns this module's own
        # ``gelu`` wrapper rather than F.gelu, so the configured name is
        # checked too; comparing the function object to a string never matches.
        activation_name = cfg.activation_fn
        if (
            self.activation_fn is torch.nn.functional.relu
            or isinstance(self.activation_fn, torch.nn.ReLU)
            or activation_name == "relu"
        ):
            self.activation_relu_or_gelu = 1
        elif (
            self.activation_fn is torch.nn.functional.gelu
            or isinstance(self.activation_fn, torch.nn.GELU)
            or activation_name == "gelu"
        ):
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
        # Batch first can not be justified but needs user to make sure
        self.can_use_fastpath = None
        self.cfg_checkpoint_activations = self.cfg.checkpoint_activations

        # torch version check: make sure BT version is >= 1.12.0 (>= 1.13 for
        # four-component versions, e.g. "1.13.0.dev2022...").
        self.BT_version = False
        if "fb" in torch.__version__:
            self.BT_version = True
        else:
            # drop any local build suffix such as "+cu118" or "+git32f93b1"
            self.torch_version = torch.__version__.split("+")[0].split(".")
            self.int_version = self._version_to_int(self.torch_version)
            if len(self.torch_version) == 3:
                if self.int_version >= 1120:
                    self.BT_version = True
            elif len(self.torch_version) == 4:
                if self.int_version >= 1130:
                    self.BT_version = True

    @staticmethod
    def _version_to_int(parts: List[str]) -> int:
        """Encode ``[major, minor, patch, ...]`` as ``major * 1000 + minor * 10 + patch``.

        Only the leading digits of each component are used, so pre-release
        versions such as ``2.1.0a0`` parse instead of raising ``ValueError``.
        Missing components count as 0.
        """
        import re

        numbers = []
        for part in (list(parts) + ["0", "0", "0"])[:3]:
            match = re.match(r"\d+", part)
            numbers.append(int(match.group()) if match else 0)
        return numbers[0] * 1000 + numbers[1] * 10 + numbers[2]

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Fill the fused fast-path parameters from ``state_dict``, then load normally.

        The q/k/v projections are concatenated into ``in_proj_weight`` /
        ``in_proj_bias``. The output projection and FFN weights are copied to
        their fused counterparts. Raises ``KeyError`` if any of these source
        entries is missing.
        """
        self.load_to_BT = True

        old_name = prefix + "self_attn."
        q_proj_weight = state_dict[old_name + "q_proj.weight"]
        k_proj_weight = state_dict[old_name + "k_proj.weight"]
        v_proj_weight = state_dict[old_name + "v_proj.weight"]
        q_proj_bias = state_dict[old_name + "q_proj.bias"]
        k_proj_bias = state_dict[old_name + "k_proj.bias"]
        v_proj_bias = state_dict[old_name + "v_proj.bias"]

        new_name = prefix
        state_dict[new_name + "in_proj_weight"] = torch.cat(
            (q_proj_weight, k_proj_weight, v_proj_weight), dim=0
        )
        state_dict[new_name + "in_proj_bias"] = torch.cat(
            (q_proj_bias, k_proj_bias, v_proj_bias), dim=0
        )
        state_dict[new_name + "out_proj_weight"] = state_dict[
            old_name + "out_proj.weight"
        ]
        state_dict[new_name + "out_proj_bias"] = state_dict[old_name + "out_proj.bias"]
        state_dict[new_name + "fc1_weight"] = state_dict[prefix + "fc1.weight"]
        state_dict[new_name + "fc1_bias"] = state_dict[prefix + "fc1.bias"]
        state_dict[new_name + "fc2_weight"] = state_dict[prefix + "fc2.weight"]
        state_dict[new_name + "fc2_bias"] = state_dict[prefix + "fc2.bias"]
        super(TransformerEncoderLayerBase, self)._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(
            nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
        )

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(
            nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
        )

    def _get_fc_rank(self, remove_num: int) -> List[int]:
        """Return the ``remove_num`` least important FFN hidden units.

        Importance of unit ``i`` is ``sum|fc1.weight[i]| + sum|fc2.weight[:, i]|
        + |fc1.bias[i]|``. Indices are returned in ascending importance order.
        """
        f1_filter_param = []
        for i in range(self.fc1.out_features):
            f1_filter_param.append(
                torch.sum(torch.abs(self.fc1.weight[i]))
                + torch.sum(torch.abs(self.fc2.weight[:, i]))
                + torch.abs(self.fc1.bias[i])
            )
        return sorted(
            range(len(f1_filter_param)), key=lambda k: f1_filter_param[k], reverse=False
        )[0:remove_num]

    def _prune_fc_layer(self, remove_index: List[int]):
        """Remove the FFN hidden units listed in ``remove_index``.

        Rebuilds ``fc1`` without the corresponding output rows and ``fc2``
        without the corresponding input columns; ``fc2``'s bias is kept as is.
        """
        removed = set(remove_index)  # O(1) membership tests

        kept_fc1 = [i for i in range(self.fc1.out_features) if i not in removed]
        new_fc1_weight = torch.stack([self.fc1.weight[i] for i in kept_fc1]).detach()
        new_fc1_weight.requires_grad = True
        new_fc1_bias = torch.stack([self.fc1.bias[i] for i in kept_fc1]).detach()
        new_fc1_bias.requires_grad = True
        # Size the new layer from the units actually kept, so duplicate or
        # out-of-range entries in remove_index cannot desync out_features.
        self.fc1 = quant_noise(
            nn.Linear(self.fc1.in_features, len(kept_fc1)),
            p=self.quant_noise,
            block_size=self.quant_noise_block_size,
        )
        self.fc1.weight = torch.nn.Parameter(new_fc1_weight)
        self.fc1.bias = torch.nn.Parameter(new_fc1_bias)

        kept_fc2 = [i for i in range(self.fc2.in_features) if i not in removed]
        new_fc2_weight = torch.stack(
            [self.fc2.weight[:, i] for i in kept_fc2], dim=-1
        ).detach()
        new_fc2_weight.requires_grad = True
        new_fc2_bias = self.fc2.bias.detach()
        new_fc2_bias.requires_grad = True
        self.fc2 = quant_noise(
            nn.Linear(len(kept_fc2), self.fc2.out_features),
            p=self.quant_noise,
            block_size=self.quant_noise_block_size,
        )
        self.fc2.weight = torch.nn.Parameter(new_fc2_weight)
        self.fc2.bias = torch.nn.Parameter(new_fc2_bias)

    def build_self_attention(self, embed_dim, cfg):
        return MultiheadAttention(
            embed_dim,
            cfg.encoder.attention_heads,
            dropout=cfg.attention_dropout,
            self_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            xformers_att_config=cfg.encoder.xformers_att_config,
        )

    def residual_connection(self, x, residual):
        return residual + x

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`
        """
        layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
        for old, new in layer_norm_map.items():
            for m in ("weight", "bias"):
                k = "{}.layer_norms.{}.{}".format(name, old, m)
                if k in state_dict:
                    state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
                    del state_dict[k]

    def forward(
        self,
        x,
        encoder_padding_mask: Optional[Tensor],
        attn_mask: Optional[Tensor] = None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, seq_len)` where padding elements are indicated by ``1``.
            attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
                where `tgt_len` is the length of output and `src_len` is the
                length of input, though here both are equal to `seq_len`.
                `attn_mask[tgt_i, src_j] = 1` means that when calculating the
                embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
                useful for strided self-attention.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`, or the tuple
            ``(output, fc_result)`` when ``return_fc`` is set (outside
            TorchScript).
        """
        # anything in original attn_mask = 1, becomes -1e8
        # anything in original attn_mask = 0, becomes 0
        # Note that we cannot use -inf here, because at some edge cases,
        # the attention weight (before softmax) for some padded element in query
        # will become -inf, which results in NaN in model parameters
        if self.training:
            self.ever_training = True

        if (
            self.BT_version
            and x.dim() == 3
            and self.load_to_BT
            and not self.return_fc
            and self.can_use_fastpath
            and not self.training
            and not self.ever_training
            and not self.cfg_checkpoint_activations
        ):
            # assume is Batch first and nested tensor
            output = torch._transformer_encoder_layer_fwd(
                x,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.out_proj_weight,
                self.out_proj_bias,
                self.activation_relu_or_gelu == 2,
                False,  # norm_first, currently not supported
                self.self_attn_layer_norm.eps,
                self.self_attn_layer_norm.weight,
                self.self_attn_layer_norm.bias,
                self.final_layer_norm.weight,
                self.final_layer_norm.bias,
                self.fc1_weight,
                self.fc1_bias,
                self.fc2_weight,
                self.fc2_bias,
                encoder_padding_mask if encoder_padding_mask is not None else attn_mask,
            )
            return output

        else:
            if attn_mask is not None:
                attn_mask = attn_mask.masked_fill(
                    attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4
                )

            residual = x
            if self.normalize_before:
                x = self.self_attn_layer_norm(x)
            x, _ = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=encoder_padding_mask,
                need_weights=False,
                attn_mask=attn_mask,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.self_attn_layer_norm(x)

            residual = x
            if self.normalize_before:
                x = self.final_layer_norm(x)
            x = self.activation_fn(self.fc1(x))
            x = self.activation_dropout_module(x)
            x = self.fc2(x)

            fc_result = x

            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.final_layer_norm(x)

            if self.return_fc and not torch.jit.is_scripting():
                return x, fc_result
            return x
def safe_getattr(obj, k, default=None):
    """Return ``obj[k]`` / ``obj.k`` if it exists and is not None, otherwise ``default``.

    OmegaConf configs are accessed by key; any other object by attribute.
    ``omegaconf`` is optional: if it cannot be imported, no OmegaConf config
    can exist, so plain attribute lookup is used.
    """
    try:
        from omegaconf import OmegaConf
    except ImportError:
        OmegaConf = None

    if OmegaConf is not None and OmegaConf.is_config(obj):
        return obj[k] if k in obj and obj[k] is not None else default

    value = getattr(obj, k, default)
    # Treat an explicit None like a missing attribute, as documented above.
    return default if value is None else value
class TransformerDecoderLayerBase(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *cfg.decoder.normalize_before* to ``True``.

    Args:
        cfg: model configuration (reads ``cfg.decoder.*``, ``cfg.encoder.*``,
            ``cfg.quant_noise.*``, ``cfg.dropout``, ``cfg.activation_fn``, ...).
        no_encoder_attn (bool, optional): if True, no encoder-decoder attention
            sub-layer is built (default: False).
        add_bias_kv (bool, optional): forwarded to the self-attention module.
        add_zero_attn (bool, optional): forwarded to the self-attention module.
    """

    def __init__(
        self, cfg, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):  # embed_dim, num_heads, ff_dim, dropout
        super().__init__()
        self.embed_dim = cfg.decoder.embed_dim
        self.dropout_module = FairseqDropout(
            cfg.dropout, module_name=self.__class__.__name__
        )
        self.quant_noise = cfg.quant_noise.pq
        self.quant_noise_block_size = cfg.quant_noise.pq_block_size

        self.cross_self_attention = cfg.cross_self_attention

        self.self_attn = self.build_self_attention(
            self.embed_dim,
            cfg,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )
        # Optional extras, each enabled by a cfg flag (scale_attn, scale_heads,
        # scale_fc, scale_resids) and left as None otherwise.
        self.attn_ln = (
            LayerNorm(self.embed_dim)
            if safe_getattr(cfg, "scale_attn", False)
            else None
        )
        self.nh = self.self_attn.num_heads
        self.head_dim = self.self_attn.head_dim
        scale_heads = safe_getattr(cfg, "scale_heads", False)
        # learnable per-head multiplier applied to the self-attention output
        self.c_attn = (
            nn.Parameter(torch.ones((self.nh,)), requires_grad=True)
            if scale_heads
            else None
        )

        self.activation_fn = get_activation_fn(activation=cfg.activation_fn)
        activation_dropout_p = cfg.activation_dropout
        if activation_dropout_p == 0:
            # for backwards compatibility with models that use cfg.relu_dropout
            activation_dropout_p = cfg.relu_dropout or 0
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )
        self.normalize_before = cfg.decoder.normalize_before

        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)

        # layer norm applied inside the FFN, between activation dropout and fc2
        self.ffn_layernorm = (
            LayerNorm(cfg.decoder.ffn_embed_dim)
            if safe_getattr(cfg, "scale_fc", False)
            else None
        )
        # learnable elementwise scale applied to the FFN residual branch
        self.w_resid = (
            nn.Parameter(
                torch.ones(
                    self.embed_dim,
                ),
                requires_grad=True,
            )
            if safe_getattr(cfg, "scale_resids", False)
            else None
        )

        self.fc1 = self.build_fc1(
            self.embed_dim,
            cfg.decoder.ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            cfg.decoder.ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )

        self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
        # need_attn: default for returning encoder-attention weights in eval mode
        self.need_attn = True

        self.onnx_trace = False

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_self_attention(
        self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False
    ):
        # With cross_self_attention the keys/values are [encoder_out; x], so the
        # module is not flagged as pure self-attention.
        return MultiheadAttention(
            embed_dim,
            cfg.decoder.attention_heads,
            dropout=cfg.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not cfg.cross_self_attention,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            xformers_att_config=cfg.decoder.xformers_att_config,
        )

    def build_encoder_attention(self, embed_dim, cfg):
        # NOTE(review): uses cfg.encoder.xformers_att_config (not cfg.decoder)
        # for the encoder-decoder attention — confirm this is intended.
        return MultiheadAttention(
            embed_dim,
            cfg.decoder.attention_heads,
            kdim=cfg.encoder.embed_dim,
            vdim=cfg.encoder.embed_dim,
            dropout=cfg.attention_dropout,
            encoder_decoder_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            xformers_att_config=cfg.encoder.xformers_att_config,
        )

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def residual_connection(self, x, residual):
        return residual + x

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            prev_self_attn_state / prev_attn_state (list, optional):
                ``[prev_key, prev_value(, prev_key_padding_mask)]`` written into
                ``incremental_state`` before the corresponding attention runs.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            tuple ``(x, attn, self_attn_state)``: ``x`` has shape
            `(seq_len, batch, embed_dim)`. ``attn`` is the encoder-attention
            result when encoder attention runs, otherwise whatever the
            self-attention returned with ``need_weights=False``.
            ``self_attn_state`` is the saved self-attention key/value list when
            ONNX-tracing with an ``incremental_state``, otherwise ``None``.
        """
        if need_head_weights:
            need_attn = True

        # --- self-attention sub-layer ---
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
        # Cross-self-attention: keys/values are [encoder_out; x], unless cached
        # keys already exist in the incremental state.
        if self.cross_self_attention and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            if self_attn_mask is not None:
                assert encoder_out is not None
                # prepend unmasked (zero) columns for the encoder positions
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        if self.c_attn is not None:
            # scale each head's output by its learnable factor
            tgt_len, bsz = x.size(0), x.size(1)
            x = x.view(tgt_len, bsz, self.nh, self.head_dim)
            x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
            x = x.reshape(tgt_len, bsz, self.embed_dim)
        if self.attn_ln is not None:
            x = self.attn_ln(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # --- encoder-decoder attention sub-layer (only when both exist) ---
        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        # --- feed-forward sub-layer ---
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)

        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        if self.w_resid is not None:
            residual = torch.mul(self.w_resid, residual)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        if self.onnx_trace and incremental_state is not None:
            # ONNX export: also return the cached self-attention key/value state
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        # only updates the eval-mode default for returning attention weights
        self.need_attn = need_attn
import torch
import torch.nn as nn
import math
from typing import Optional, Dict, List, Any
from torch import Tensor
def make_positions(tensor, padding_idx: int, onnx_trace: bool = False):
    """Replace non-padding symbols with their position numbers.

    Positions are counted along dim 1 over non-padding symbols only and start
    at ``padding_idx + 1``. Padding symbols are mapped to ``padding_idx``.
    ``onnx_trace`` is accepted for API compatibility and is not used.
    """
    # Keep the int -> cumsum -> type_as -> long sequence: it is balanced to work
    # with ONNX export (no dtype kwarg in cumsum) and XLA (prefers ints;
    # cumsum defaults to long outputs).
    not_pad = tensor.ne(padding_idx).int()
    running_count = torch.cumsum(not_pad, dim=1).type_as(not_pad)
    return (running_count * not_pad).long() + padding_idx
class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored (they map to an all-zero embedding row). The
    embedding table is a plain attribute (``self.weights``), not a parameter or
    buffer; it is regrown on demand in :meth:`forward`.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx if padding_idx is not None else 0
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size, embedding_dim, padding_idx
        )
        self.onnx_trace = False
        # Dummy buffer that follows the module's device/dtype; forward() uses it
        # to move/cast self.weights via .to().
        self.register_buffer("_float_tensor", torch.FloatTensor(1))
        # Upper bound reported to callers (e.g. the encoder's max_positions()).
        self.max_positions = int(1e5)

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    @staticmethod
    def get_embedding(
        num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
    ):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        Returns a ``(num_embeddings, embedding_dim)`` float tensor with the sin
        half followed by the cos half. The row at ``padding_idx`` (if given)
        is zeroed.
        """
        half_dim = embedding_dim // 2
        # NOTE(review): divides by zero when embedding_dim < 4 (half_dim == 1).
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
            1
        ) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
            num_embeddings, -1
        )
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(
        self,
        input,
        incremental_state: Optional[Any] = None,
        timestep: Optional[Tensor] = None,
        positions: Optional[Any] = None,
    ):
        """Input is expected to be of size [bsz x seqlen].

        ``positions`` is accepted for interface compatibility but is not used:
        positions are recomputed from ``input`` (or, in incremental decoding,
        derived from ``timestep`` / the sequence length).
        """
        # bsz / seq_len are taken via shape_as_tensor so they stay traceable
        # for ONNX export.
        bspair = torch.onnx.operators.shape_as_tensor(input)
        bsz, seq_len = bspair[0], bspair[1]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos, self.embedding_dim, self.padding_idx
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            if self.onnx_trace:
                return (
                    self.weights.index_select(index=self.padding_idx + pos, dim=0)
                    .unsqueeze(1)
                    .repeat(bsz, 1, 1)
                )
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        positions = make_positions(
            input, self.padding_idx, onnx_trace=self.onnx_trace
        )
        if self.onnx_trace:
            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
            embedding_shape = torch.cat(
                (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
            )
            embeddings = torch.onnx.operators.reshape_from_tensor_shape(
                flat_embeddings, embedding_shape
            )
            return embeddings
        # (bsz, seq_len, embedding_dim), detached from autograd
        return (
            self.weights.index_select(0, positions.view(-1))
            .view(bsz, seq_len, -1)
            .detach()
        )
class TransformerEncoderBase(nn.Module):
def __init__(self, cfg, dictionary, embed_tokens, return_fc=False):
super().__init__()
self.cfg = cfg
self.dictionary = dictionary
self.return_fc = return_fc
self.register_buffer('version', torch.Tensor([3]))
self.dropout_module = FairseqDropout(cfg.dropout)
self.encoder_layerdrop = cfg.encoder.layerdrop
embed_dim = embed_tokens.embedding_dim
self.padding_idx = embed_tokens.padding_idx
self.max_source_positions = cfg.max_source_positions
self.embed_tokens = embed_tokens
self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)
self.embed_positions = (
SinusoidalPositionalEmbedding(
embed_dim, self.padding_idx, cfg.max_source_positions + self.padding_idx + 1
) if not cfg.no_token_positional_embeddings else None
)
# self.layernorm_embedding = (
# nn.LayerNorm(embed_dim) if cfg.layernorm_embedding else None
# )
if cfg.layernorm_embedding:
self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
else:
self.layernorm_embedding = None
if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
self.quant_noise = quant_noise(
nn.Linear(embed_dim, embed_dim, bias=False),
cfg.quant_noise.pq,
cfg.quant_noise.pq_block_size,
)
else:
self.quant_noise = None
if self.encoder_layerdrop > 0.0:
self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
else:
self.layers = nn.ModuleList([])
self.layers.extend(
[self.build_encoder_layer(cfg) for i in range(cfg.encoder.layers)]
)
self.num_layers = len(self.layers)
if cfg.encoder.normalize_before:
self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
else:
self.layer_norm = None
def build_encoder_layer(self, cfg):
layer = TransformerEncoderLayerBase(
cfg, return_fc=self.return_fc
)
checkpoint = cfg.checkpoint_activations
# if checkpoint:
# offload_to_cpu = cfg.offload_activations
# layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
# if we are checkpointing, enforce that FSDP always wraps the
# checkpointed layer, regardless of layer size
min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
# layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
return layer
def forward_embedding(
self, src_tokens, token_embedding: Optional[torch.Tensor] = None):
# embed tokens and positions
if token_embedding is None:
token_embedding = self.embed_tokens(src_tokens)
x = embed = self.embed_scale * token_embedding
if self.embed_positions is not None:
x = embed + self.embed_positions(src_tokens)
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
if self.quant_noise is not None:
x = self.quant_noise(x)
return x, embed
def max_positions(self):
    """Maximum input length supported by the encoder.

    This is ``max_source_positions``, further capped by the positional
    embedding's own ``max_positions`` when positional embeddings are enabled.
    """
    limit = self.max_source_positions
    positions = self.embed_positions
    return limit if positions is None else min(limit, positions.max_positions)
def forward(self, src_tokens, src_lengths: Optional[torch.Tensor] = None, token_embeddings: Optional[torch.Tensor] = None, return_all_hiddens: bool = False):
    """Encode a batch of source tokens.

    Args:
        src_tokens (LongTensor): token ids, (batch, src_len).
        src_lengths: not used; lengths are recomputed from the non-padding
            tokens of ``src_tokens``.
        token_embeddings (Tensor, optional): precomputed token embeddings,
            forwarded to ``forward_embedding`` in place of
            ``embed_tokens(src_tokens)``.
        return_all_hiddens (bool): also collect the embedding output and the
            output of every layer.

    Returns:
        dict with keys ``encoder_out`` ([T x B x C]), ``encoder_padding_mask``
        ([B x T mask, or None when there is no padding]),
        ``encoder_embedding`` ([B x T x C]), ``encoder_states``
        (list or None), ``fc_results``, ``src_tokens`` (empty list) and
        ``src_lengths`` ([B x 1 int32]).
    """
    encoder_padding_mask = src_tokens.eq(self.padding_idx)
    # On XLA devices the padding path is always taken.
    has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any()
    # Pass through caller-supplied embeddings (previously they were dropped).
    x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
    if has_pads:
        # Zero out hidden vectors at padding positions.
        x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
    x = x.transpose(0, 1)  # B x T x C -> T x B x C
    encoder_states = [] if return_all_hiddens else None
    fc_results = []
    if return_all_hiddens:
        encoder_states.append(x)
    # Layers (and the returned dict) get None instead of an all-False mask.
    encoder_padding_mask = encoder_padding_mask if has_pads else None
    for layer in self.layers:
        x = layer(x, encoder_padding_mask=encoder_padding_mask)
        # Some layers return an (x, fc_result) pair.
        if isinstance(x, tuple) and len(x) == 2:
            x, fc_result = x
        else:
            fc_result = None
        if return_all_hiddens:
            assert encoder_states is not None
            encoder_states.append(x)
            fc_results.append(fc_result)
    if self.layer_norm is not None:
        x = self.layer_norm(x)
    # The src_lengths argument is ignored; recompute from non-padding tokens.
    src_lengths = (
        src_tokens.ne(self.padding_idx)
        .sum(dim=1, dtype=torch.int32)
        .reshape(-1, 1)
        .contiguous()
    )
    return {
        "encoder_out": [x],  # T x B x C
        "encoder_padding_mask": [encoder_padding_mask],  # B x T
        "encoder_embedding": [encoder_embedding],  # B x T x C
        "encoder_states": encoder_states,  # List[T x B x C]
        "fc_results": fc_results,  # List[T x B x C]
        "src_tokens": [],
        "src_lengths": [src_lengths],
    }
import torch.nn as nn
import torch
import sys
import torch.distributed as dist
# from fairseq import utils
# from fairseq.distributed import utils as distributed_utils
# from fairseq.modules.layer_norm import LayerNorm
# Model parallel group that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
# When True, use_xla() reports True and process groups are ("tpu", ranks) tuples.
_USE_XLA = False
# Read by get_data_parallel_group(); it was never defined before, so that
# function raised NameError. False means "use the global group".
_USE_MEGATRON = False
def use_xla():
    """Return the module-level ``_USE_XLA`` flag."""
    # Reading a module global needs no ``global`` declaration.
    return _USE_XLA
def get_world_size(group):
    """Return the number of ranks in ``group``.

    Under XLA, ``group`` is a ``("tpu", grouped_ranks)`` tuple and the size of
    the caller's own sub-group is returned. Otherwise this queries
    ``torch.distributed``, or returns 1 when it is not initialized.

    NOTE(review): this module defines ``get_world_size`` a second time further
    down with the same body; the later definition is the one bound at import.
    """
    if use_xla():
        assert group[0] == "tpu"
        return len(_find_my_group(group[1]))
    if not torch.distributed.is_initialized():
        return 1
    return dist.get_world_size(group=group)
def get_global_world_size():
    """Return the total number of processes (1 when not running distributed).

    NOTE(review): the XLA branch calls ``xm.xrt_world_size()``, but ``xm``
    does not appear to be imported in this file. Confirm before enabling XLA.
    """
    if use_xla():
        return xm.xrt_world_size()
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size()
def get_global_rank():
    """Return this process's global rank (0 when not running distributed).

    NOTE(review): the XLA branch calls ``xm.get_ordinal()``, but ``xm`` does
    not appear to be imported in this file. Confirm before enabling XLA.
    """
    if use_xla():
        return xm.get_ordinal()
    if not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()
def new_groups(grouped_ranks: List[List[int]]):
    """Create process groups for ``grouped_ranks`` and return the caller's one.

    Under XLA no real group is created; ``("tpu", grouped_ranks)`` is returned.
    Otherwise every rank creates every group (in list order) via
    ``dist.new_group``, and the group containing this rank is returned.
    """
    if use_xla():
        return ("tpu", grouped_ranks)
    created = [dist.new_group(ranks) for ranks in grouped_ranks]
    return created[_find_my_group_index(grouped_ranks)]
def get_global_group():
    """Return a process group spanning all ranks, or None when not distributed.

    The non-XLA group is created once and cached as an attribute on the
    ``get_global_group`` function object.

    NOTE(review): ``get_global_group`` is defined several times in this module
    with the same body; only the last definition is bound at import.
    """
    if use_xla():
        return new_groups([list(range(get_global_world_size()))])
    if not torch.distributed.is_initialized():
        return None
    if not hasattr(get_global_group, "_global_group"):
        # ideally we could use torch.distributed.group.WORLD, but it seems
        # to cause random NCCL hangs in some cases
        get_global_group._global_group = dist.new_group()
    return get_global_group._global_group
def get_global_group():
    """Return a process group spanning all ranks, or None when not distributed.

    The non-XLA group is created once and cached as an attribute on the
    ``get_global_group`` function object.

    NOTE(review): ``get_global_group`` is defined several times in this module
    with the same body; only the last definition is bound at import.
    """
    if use_xla():
        return new_groups([list(range(get_global_world_size()))])
    if not torch.distributed.is_initialized():
        return None
    if not hasattr(get_global_group, "_global_group"):
        # ideally we could use torch.distributed.group.WORLD, but it seems
        # to cause random NCCL hangs in some cases
        get_global_group._global_group = dist.new_group()
    return get_global_group._global_group
def _find_my_group_index(grouped_ranks):
    """Return the index of the first group in ``grouped_ranks`` containing this rank.

    Args:
        grouped_ranks: sequence of rank collections.

    Raises:
        RuntimeError: if this process's global rank is in none of the groups.
            The message now names the rank and the groups; it was previously empty.
    """
    my_rank = get_global_rank()
    for i, group in enumerate(grouped_ranks):
        if my_rank in group:
            return i
    raise RuntimeError(
        f"global rank {my_rank} is not a member of any group in {grouped_ranks}"
    )
def _find_my_group(grouped_ranks):
    """Return the first entry of ``grouped_ranks`` that contains this rank."""
    return grouped_ranks[_find_my_group_index(grouped_ranks)]
def get_global_group():
    """Return a process group spanning all ranks, or None when not distributed.

    The non-XLA group is created once and cached as an attribute on the
    ``get_global_group`` function object.
    """
    if use_xla():
        return new_groups([list(range(get_global_world_size()))])
    if not torch.distributed.is_initialized():
        return None
    if not hasattr(get_global_group, "_global_group"):
        # ideally we could use torch.distributed.group.WORLD, but it seems
        # to cause random NCCL hangs in some cases
        get_global_group._global_group = dist.new_group()
    return get_global_group._global_group
def get_world_size(group):
    """Return the number of ranks in ``group``.

    Under XLA, ``group`` is a ``("tpu", grouped_ranks)`` tuple and the size of
    the caller's own sub-group is returned. Otherwise this queries
    ``torch.distributed``, or returns 1 when it is not initialized.
    """
    if use_xla():
        assert group[0] == "tpu"
        return len(_find_my_group(group[1]))
    if not torch.distributed.is_initialized():
        return 1
    return dist.get_world_size(group=group)
def get_rank(group):
    """Return this process's rank within ``group``.

    Under XLA, ``group`` is a ``("tpu", grouped_ranks)`` tuple and the rank is
    the position of the global rank inside the caller's sub-group.
    """
    if not use_xla():
        return dist.get_rank(group=group)
    assert group[0] == "tpu"
    return _find_my_group(group[1]).index(get_global_rank())
def mpu_get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to.

    Returns the module-level ``_DATA_PARALLEL_GROUP``. It asserts that the
    group has been set.
    """
    group = _DATA_PARALLEL_GROUP
    assert group is not None, 'data parallel group is not initialized'
    return group
def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to.

    Uses ``mpu_get_data_parallel_group()`` when the module flag
    ``_USE_MEGATRON`` is truthy, otherwise ``get_global_group()``.
    """
    # ``_USE_MEGATRON`` was declared ``global`` but never defined, so every call
    # raised NameError. Treat a missing flag as False.
    if globals().get("_USE_MEGATRON", False):
        return mpu_get_data_parallel_group()
    return get_global_group()
def get_data_parallel_rank():
    """Return this process's rank within the data parallel group."""
    dp_group = get_data_parallel_group()
    return get_rank(dp_group)
def get_data_parallel_world_size():
    """Return the number of ranks in the data parallel group."""
    dp_group = get_data_parallel_group()
    return get_world_size(dp_group)
class BaseSublayer(nn.Module):
    """Residual feed-forward block: ``xs + ff2(act(ff1(norm(xs))))``.

    Dimensions come from ``args.decoder_embed_dim`` (model width) and
    ``args.decoder_ffn_embed_dim`` (hidden width). The activation is
    ``args.activation_fn`` and falls back to "relu" when missing or falsy.
    """
    def __init__(self, args):
        super().__init__()
        act_name = getattr(args, "activation_fn", "relu") or "relu"
        self.activation_fn = get_activation_fn(activation=act_name)
        model_dim = args.decoder_embed_dim
        hidden_dim = args.decoder_ffn_embed_dim
        self.norm = LayerNorm(model_dim, export=False)
        self.ff1 = torch.nn.Linear(model_dim, hidden_dim)
        self.ff2 = torch.nn.Linear(hidden_dim, model_dim)
        # Zero ff2's weight: at init the residual branch contributes only ff2's bias.
        self.ff2.weight.data.zero_()
    def forward(self, xs):
        hidden = self.activation_fn(self.ff1(self.norm(xs)))
        return xs + self.ff2(hidden)
class BaseLayer(nn.Module):
    """Expert layer that routes tokens across data-parallel workers.

    Each data-parallel worker holds one local expert (``self.expert_network``,
    a stack of ``args.base_sublayers`` BaseSublayer blocks). There is one row
    of ``expert_centroids`` per worker. Tokens are sent to workers with
    ``All2All``, processed by that worker's expert and sent back.

    NOTE(review): ``All2All`` is not defined in this section, and
    ``load_assignment`` imports ``fairseq.libbase``. Confirm both are
    available before instantiating this layer.
    """
    def __init__(self, args):
        super().__init__()
        # One expert per data-parallel worker.
        self.num_workers = get_data_parallel_world_size()
        expert_centroids = torch.empty(self.num_workers, args.decoder_embed_dim)
        torch.nn.init.orthogonal_(expert_centroids, gain=0.1)
        self.register_parameter(
            "expert_centroids", torch.nn.Parameter(expert_centroids)
        )
        self.expert_network = nn.Sequential(
            *([BaseSublayer(args) for _ in range(args.base_sublayers)])
        )
        # This worker's expert index is its data-parallel rank.
        self.expert_id = get_data_parallel_rank()
        self.shuffle = args.base_shuffle
        # Native helper providing balanced_assignment (see load_assignment).
        self.cpp = self.load_assignment()
        # Add a special attribute to the expert parameters, so we know not to sync their gradients
        for param in self.expert_network.parameters():
            param.expert = True
    def forward(self, input_features, *args, **kwargs):
        """Route tokens to experts, apply the local expert and route them back.

        Args:
            input_features (Tensor): features whose last dim is the model dim;
                all leading dims are flattened into one token axis.
            *args, **kwargs: ignored.

        Returns:
            tuple: (output with ``input_features``' shape, None, None).

        NOTE(review): the ``All2All.apply`` calls below are presumably
        collectives and must be issued in the same order on every worker.
        Keep the statement order intact.
        """
        features = input_features.reshape(-1, input_features.size(-1))
        # "Training" is inferred from requires_grad, not from self.training.
        is_training = input_features.requires_grad
        if self.shuffle and is_training:
            # Send each token to a random worker, to break correlations within the batch
            shuffle_sort = torch.randperm(features.size(0), device=features.device)
            features = All2All.apply(features[shuffle_sort])
        with torch.no_grad():
            # Compute similarity of each token to each expert, for routing
            token_expert_affinities = features.matmul(
                self.expert_centroids.transpose(0, 1)
            )
        # Compute which token goes to which expert
        # (balanced assignment while training, greedy top-1 otherwise).
        sort_by_expert, input_splits, output_splits = (
            self.balanced_assignment(token_expert_affinities)
            if is_training
            else self.greedy_assignment(token_expert_affinities)
        )
        # Swap these tokens for the right ones for our expert
        routed_features = All2All.apply(
            features[sort_by_expert], output_splits, input_splits
        )
        if routed_features.size(0) > 0:
            # Mix in the expert network based on how appropriate it is for these tokens
            # alpha = sigmoid(token . own centroid), one gate value per token.
            alpha = torch.sigmoid(
                routed_features.mv(self.expert_centroids[self.expert_id])
            ).unsqueeze(1)
            routed_features = (
                alpha * self.expert_network(routed_features)
                + (1 - alpha) * routed_features
            )
        # Return to original worker and ordering
        result = All2All.apply(routed_features, input_splits, output_splits)[
            self.inverse_sort(sort_by_expert)
        ]
        if self.shuffle and is_training:
            # Undo shuffling
            result = All2All.apply(result)[self.inverse_sort(shuffle_sort)]
        # Return additional Nones for compatibility with TransformerDecoderLayer
        return result.view(input_features.size()), None, None
    def inverse_sort(self, order):
        # Creates an index that undoes a sort: xs==xs[order][inverse_sort(order)]
        return torch.empty_like(order).scatter_(
            0, order, torch.arange(0, order.size(0), device=order.device)
        )
    def balanced_assignment(self, scores):
        """Assign tokens to workers with the native balanced-assignment helper.

        Non-finite entries of ``scores`` are overwritten in place with the
        smallest finite score first. Returns ``(sort_by_expert, None, None)``.
        NOTE(review): the None splits are presumably read by All2All as
        "equal splits"; confirm against All2All.
        """
        ok = scores.isfinite()
        if not ok.all():
            # NaNs here can break the assignment algorithm
            scores[~ok] = scores[ok].min()
        return self.cpp.balanced_assignment(scores), None, None
    # Assigns each token to the top k experts
    def greedy_assignment(self, scores, k=1):
        """Send each token to its top-``k`` scoring worker(s).

        Returns:
            tuple: (token indices sorted by destination worker,
            per-worker counts to receive as a list, per-worker counts to
            send as a list).
        """
        token_to_workers = torch.topk(scores, dim=1, k=k, largest=True).indices.view(-1)
        token_to_workers, sort_ordering = torch.sort(token_to_workers)
        worker2token = sort_ordering // k
        # Find how many tokens we're sending to each other worker (being careful for sending 0 tokens to some workers)
        output_splits = torch.zeros(
            (self.num_workers,), dtype=torch.long, device=scores.device
        )
        workers, counts = torch.unique_consecutive(token_to_workers, return_counts=True)
        output_splits[workers] = counts
        # Tell other workers how many tokens to expect from us
        input_splits = All2All.apply(output_splits)
        return worker2token, input_splits.tolist(), output_splits.tolist()
    def load_assignment(self):
        """Import and return ``fairseq.libbase``; re-raise ImportError after a hint on stderr."""
        try:
            from fairseq import libbase
            return libbase
        except ImportError as e:
            sys.stderr.write(
                "ERROR: missing libbase. run `python setup.py build_ext --inplace`\n"
            )
            raise e
class TransformerDecoderBase(nn.Module):
    """
    Transformer decoder built from a stack of ``TransformerDecoderLayerBase`` layers.

    Args:
        cfg: decoder configuration. Fields read in this class include
            ``dropout``, ``decoder.layerdrop``, ``decoder.embed_dim``,
            ``decoder.output_dim``, ``decoder.layers``,
            ``decoder.normalize_before``, ``share_decoder_input_output_embed``,
            ``max_target_positions``, ``no_scale_embedding``,
            ``quant_noise.pq``, ``quant_noise.pq_block_size``,
            ``no_token_positional_embeddings``, ``layernorm_embedding``,
            ``export``, ``cross_self_attention``, ``no_decoder_final_norm``,
            ``tie_adaptive_weights``, ``base_layers``,
            ``checkpoint_activations``, ``offload_activations`` and
            ``min_params_to_wrap``.
        dictionary: target dictionary; only ``len(dictionary)`` is used, to
            size a non-shared output projection.
        embed_tokens: token embedding module. Its ``embedding_dim`` and
            ``padding_idx`` are used, plus ``weight`` when input/output
            embeddings are shared.
        no_encoder_attn (bool): passed to every decoder layer.
        output_projection (nn.Module, optional): projection onto the
            vocabulary; built by ``build_output_projection`` when None.
    """
    def __init__(
        self,
        cfg,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
        output_projection=None,
    ):
        super().__init__()
        self.register_buffer("version", torch.Tensor([3]))
        # Causal-mask cache; rebuilt lazily by buffered_future_mask.
        self._future_mask = torch.empty(0)
        ################
        self.dropout_module = FairseqDropout(
            cfg.dropout, module_name="TransformerDecoder")
        self.decoder_layerdrop = cfg.decoder.layerdrop
        self.share_input_output_embed = cfg.share_decoder_input_output_embed
        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = cfg.decoder.embed_dim
        self.embed_dim = embed_dim
        self.output_embed_dim = cfg.decoder.output_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = cfg.max_target_positions
        self.embed_tokens = embed_tokens
        # Token embeddings are scaled by sqrt(embed_dim) unless disabled.
        self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(
            embed_dim)
        if cfg.quant_noise.pq > 0:
            self.quant_noise = quant_noise(
                nn.Linear(embed_dim, embed_dim, bias=False),
                cfg.quant_noise.pq,
                cfg.quant_noise.pq_block_size,
            )
        else:
            self.quant_noise = None
        # Project input embeddings only when their width differs from the model width.
        self.project_in_dim = (
            nn.Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )
        self.embed_positions = (
            SinusoidalPositionalEmbedding(
                embed_dim, self.padding_idx, cfg.max_target_positions + self.padding_idx + 1
            )
            if not cfg.no_token_positional_embeddings
            else None
        )
        if cfg.layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
        else:
            self.layernorm_embedding = None
        self.cross_self_attention = cfg.cross_self_attention
        # With layerdrop > 0, layers live in a LayerDropModuleList (defined elsewhere).
        if self.decoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                self.build_decoder_layer(cfg, no_encoder_attn)
                for _ in range(cfg.decoder.layers)
            ]
        )
        # Counted before build_output_projection may insert BaseLayers.
        self.num_layers = len(self.layers)
        if cfg.decoder.normalize_before and not cfg.no_decoder_final_norm:
            self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
        else:
            self.layer_norm = None
        self.project_out_dim = (
            nn.Linear(embed_dim, self.output_embed_dim, bias=False)
            if embed_dim != self.output_embed_dim and not cfg.tie_adaptive_weights
            else None
        )
        # Adaptive softmax is not supported here; output_layer always projects.
        self.adaptive_softmax = None
        self.output_projection = output_projection
        if self.output_projection is None:
            self.build_output_projection(cfg, dictionary, embed_tokens)
        ################
    def build_output_projection(self, cfg, dictionary, embed_tokens):
        """Create ``self.output_projection`` and insert ``cfg.base_layers`` BaseLayers.

        With shared embeddings, the projection's weight is tied to
        ``self.embed_tokens.weight``. Otherwise it is a fresh bias-free Linear
        onto ``len(dictionary)`` outputs with normal(0, output_embed_dim**-0.5)
        init. BaseLayers are inserted at evenly spaced positions in
        ``self.layers``.
        """
        if self.share_input_output_embed:
            self.output_projection = nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            self.output_projection.weight = self.embed_tokens.weight
        else:
            self.output_projection = nn.Linear(
                self.output_embed_dim, len(dictionary), bias=False
            )
            nn.init.normal_(
                self.output_projection.weight, mean=0, std=self.output_embed_dim**-0.5
            )
        num_base_layers = cfg.base_layers
        for i in range(num_base_layers):
            self.layers.insert(
                ((i + 1) * cfg.decoder.layers) // (num_base_layers + 1),
                BaseLayer(cfg),
            )
    def build_decoder_layer(self, cfg, no_encoder_attn=False):
        """Build one ``TransformerDecoderLayerBase``.

        Checkpoint and FSDP wrapping are commented out in this port, so the
        locals ``offload_to_cpu`` and ``min_params_to_wrap`` are computed but unused.
        """
        layer = TransformerDecoderLayerBase(cfg, no_encoder_attn)
        checkpoint = cfg.checkpoint_activations
        if checkpoint:
            offload_to_cpu = cfg.offload_activations
            # layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
        # layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer
    def forward(
        self,
        prev_output_tokens: Tensor,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        src_padding_mask: Optional[Tensor] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
        features_only: bool = False,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """
        Args:
            prev_output_tokens (Tensor): previous output tokens, (batch, tgt_len).
            encoder_out (dict, optional): encoder output dict. Only the first
                element of its ``"encoder_out"`` and ``"encoder_padding_mask"``
                lists is read, and only when the list is non-empty.
            src_padding_mask, src_lengths, return_all_hiddens: accepted but
                not used by this method.
            features_only (bool): if True, skip ``output_layer``.
            incremental_state (dict, optional): incremental-decoding cache.
                When given, only the last target position is embedded and no
                future mask is built.
            full_context_alignment (bool): if True, no future mask is built.
            alignment_layer (int, optional): index of the layer whose attention
                is returned; defaults to the last layer.
            alignment_heads (int, optional): if set, keep only the first
                ``alignment_heads`` entries of dim 0 of that attention
                (presumably heads) before averaging over dim 0.
        Returns:
            tuple: ``(x, {"attn": [attn], "inner_states": inner_states})``.
            ``x`` is batch-first: vocabulary scores from ``output_layer``, or
            features when ``features_only``.
        """
        # (unused)
        bs, slen = prev_output_tokens.size()
        if alignment_layer is None:
            alignment_layer = self.num_layers - 1
        enc: Optional[Tensor] = None
        padding_mask: Optional[Tensor] = None
        if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
            enc = encoder_out["encoder_out"][0]
        if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
            padding_mask = encoder_out["encoder_padding_mask"][0]
        # embed positions
        positions = None
        if self.embed_positions is not None:
            positions = self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )
        if incremental_state is not None:
            # Incremental decoding: only the newest token is processed.
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]
        # Prevent torchscript exporting issue for dynamic quant embedding
        prev_output_tokens = prev_output_tokens.contiguous()
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)
        if self.quant_noise is not None:
            x = self.quant_noise(x)
        if self.project_in_dim is not None:
            x = self.project_in_dim(x)
        if positions is not None:
            x += positions
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        self_attn_padding_mask: Optional[Tensor] = None
        if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
        # Embed tokens and positions
        # positions = torch.arange(prev_output_tokens.size(1), device=prev_output_tokens.device).unsqueeze(0)
        # x = self.embed_tokens(prev_output_tokens) + self.embed_positions(positions)
        # x = self.dropout(x)
        # decoder layers
        attn: Optional[Tensor] = None
        inner_states: List[Optional[Tensor]] = [x]
        for idx, layer in enumerate(self.layers):
            # Causal mask only for full-sequence (non-incremental) decoding.
            if incremental_state is None and not full_context_alignment:
                self_attn_mask = self.buffered_future_mask(x)
            else:
                self_attn_mask = None
            x, layer_attn, _ = layer(
                x,
                enc,
                padding_mask,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=bool((idx == alignment_layer)),
                need_head_weights=bool((idx == alignment_layer)),
            )
            inner_states.append(x)
            if layer_attn is not None and idx == alignment_layer:
                attn = layer_attn.float().to(x)
        if attn is not None:
            if alignment_heads is not None:
                attn = attn[:alignment_heads]
            # average probabilities over heads
            attn = attn.mean(dim=0)
        if self.layer_norm is not None:
            x = self.layer_norm(x)
        # T x B x C -> B x T x C
        x = x.transpose(0, 1)
        if self.project_out_dim is not None:
            x = self.project_out_dim(x)
        if not features_only:
            x = self.output_layer(x)
        return x, {"attn": [attn], "inner_states": inner_states}
    def output_layer(self, features):
        """Project features to the vocabulary size."""
        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            return self.output_projection(features)
        else:
            return features
    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is None:
            return self.max_target_positions
        return min(self.max_target_positions, self.embed_positions.max_positions)
    def fill_with_neg_inf(self, t):
        """FP16-compatible function that fills a tensor with -inf."""
        return t.float().fill_(float("-inf")).type_as(t)
    def buffered_future_mask(self, tensor):
        """Return a (T, T) additive causal mask with T = ``tensor.size(0)``.

        Entries above the diagonal are -inf and the rest are 0. The mask is
        cached in ``self._future_mask`` and rebuilt when it is empty, on
        another device, or too small. It is then cast to ``tensor``'s
        dtype/device.
        """
        dim = tensor.size(0)
        # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
        if (
            self._future_mask.size(0) == 0
            or (not self._future_mask.device == tensor.device)
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                self.fill_with_neg_inf(torch.zeros([dim, dim])), 1
            )
        self._future_mask = self._future_mask.to(tensor)
        return self._future_mask[:dim, :dim]
class FairseqIncrementalState(object):
    """Mixin that gives each instance a private namespace in a shared state dict.

    Every instance draws a random UUID. Its entries are stored under
    ``"<uuid>.<key>"``, so several modules can share one ``incremental_state``
    dict without key collisions.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_incremental_state()
    def init_incremental_state(self):
        """(Re)generate this instance's namespace id."""
        self._incremental_state_id = str(uuid.uuid4())
    def _get_full_incremental_state_key(self, key: str) -> str:
        return "{}.{}".format(self._incremental_state_id, key)
    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Return this instance's entry for ``key``, or None when absent."""
        if incremental_state is None:
            return None
        namespaced = self._get_full_incremental_state_key(key)
        if namespaced in incremental_state:
            return incremental_state[namespaced]
        return None
    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Store ``value`` under this instance's ``key`` (no-op when the dict is None)."""
        if incremental_state is None:
            return incremental_state
        incremental_state[self._get_full_incremental_state_key(key)] = value
        return incremental_state
def with_incremental_state(cls):
    """Class decorator: make ``FairseqIncrementalState`` the first base of ``cls``.

    Any existing occurrence of ``FairseqIncrementalState`` among the bases is
    dropped, so the decorator can be applied repeatedly. Returns ``cls``.
    """
    other_bases = tuple(
        base for base in cls.__bases__ if base != FairseqIncrementalState
    )
    cls.__bases__ = (FairseqIncrementalState,) + other_bases
    return cls
def eval_str_dict(x, type=dict):
    """Return ``x``, evaluating it first when it is a string.

    ``None`` and non-string values are returned unchanged. The ``type``
    argument is accepted but unused.

    SECURITY(review): strings are passed to ``eval``; never feed this
    untrusted input.
    """
    if isinstance(x, str):
        return eval(x)
    return x
def softmax(x, dim: int, onnx_trace: bool = False):
    """Softmax over ``dim``, always computed and returned in float32.

    For ONNX tracing the input is cast to float32 explicitly; otherwise the
    cast is requested through ``F.softmax``'s ``dtype`` argument.
    """
    return (
        F.softmax(x.float(), dim=dim)
        if onnx_trace
        else F.softmax(x, dim=dim, dtype=torch.float32)
    )
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Parameter
try:
from xformers.components.attention import build_attention
from xformers.components.attention.utils import maybe_merge_masks
_xformers_available = True
except ImportError:
_xformers_available = False
# TODO: move this into xformers?
# TODO: uint8 input type should just output a bool
def _mask_for_xformers(mask: Tensor, to_dtype: Optional[torch.dtype] = None):
    """
    Convert a PyTorch-style attention mask into the convention xFormers expects.

    PyTorch multihead attention accepts three mask types:
        - ByteTensor where non-zero means to mask
        - FloatTensor which is an additive mask
        - BoolTensor where True means to mask
    xFormers currently accepts boolean and additive masks. For boolean masks
    the meaning is inverted: True means *keep* the value.

    Args:
        mask: the input mask. Any floating dtype (float16/bfloat16/float32/
            float64) is treated as additive; other dtypes as "non-zero masks".
        to_dtype: dtype of the result. Defaults to ``mask.dtype``. A floating
            dtype yields an additive mask (0 keep / -inf drop); any other dtype
            yields a keep-mask (True/1 = keep).

    Returns:
        The converted mask.
    """
    # Treat every floating dtype as additive (previously bfloat16/float64
    # additive masks were misread as boolean ones).
    additive = mask.dtype.is_floating_point
    # If to_dtype is not specified, keep same dtype as mask.
    to_dtype = mask.dtype if to_dtype is None else to_dtype
    to_additive = to_dtype.is_floating_point
    if additive:
        if to_additive:
            return mask.to(to_dtype)
        # Negative additive entries are the masked positions.
        mask = mask < 0
    if to_additive:
        # return additive mask; masked_fill_ requires a bool mask (uint8 fails)
        new_mask = torch.zeros_like(mask, dtype=to_dtype)
        return new_mask.masked_fill_(mask.to(torch.bool), -float("inf"))
    # In xFormers True is value to keep rather than value to mask
    keep = ~mask.to(torch.bool)
    return keep.to(to_dtype)
def softmax(x, dim: int, onnx_trace: bool = False):
    """Softmax over ``dim``, always computed and returned in float32.

    For ONNX tracing the input is cast to float32 explicitly; otherwise the
    cast is requested through ``F.softmax``'s ``dtype`` argument.
    """
    if not onnx_trace:
        return F.softmax(x, dim=dim, dtype=torch.float32)
    return F.softmax(x.float(), dim=dim)
def eval_str_dict(x, type=dict):
    """Return ``x``, evaluating it first when it is a string.

    ``None`` and non-string values are returned unchanged. The ``type``
    argument is accepted but unused.

    SECURITY(review): strings are passed to ``eval``; never feed this
    untrusted input.
    """
    return eval(x) if isinstance(x, str) else x
@with_incremental_state
class MultiheadAttention(nn.Module):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(
    self,
    embed_dim,
    num_heads,
    kdim=None,
    vdim=None,
    dropout=0.0,
    bias=True,
    add_bias_kv=False,
    add_zero_attn=False,
    self_attention=False,
    encoder_decoder_attention=False,
    q_noise=0.0,
    qn_block_size=8,
    # TODO: pass in config rather than string.
    # config defined in xformers.components.attention.AttentionConfig
    xformers_att_config: Optional[str] = None,
    xformers_blocksparse_layout: Optional[
        torch.Tensor
    ] = None,  # This should be part of the config
    xformers_blocksparse_blocksize: Optional[
        int
    ] = 16,  # This should be part of the config
):
    """Build the q/k/v/out projections and optional xFormers backend.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        kdim, vdim: key/value input widths (default ``embed_dim``).
        dropout: attention dropout probability.
        bias: whether the projections have biases.
        add_bias_kv: add learned ``bias_k``/``bias_v`` of shape (1, 1, embed_dim).
        add_zero_attn: stored as ``self.add_zero_attn``.
        self_attention: requires q/k/v widths to match (asserted).
        encoder_decoder_attention: stored as a flag.
        q_noise, qn_block_size: forwarded to ``quant_noise`` for each projection.
        xformers_att_config: xFormers attention config (dict, or a string that
            is ``eval``-ed). When given, xFormers must be installed.
        xformers_blocksparse_layout, xformers_blocksparse_blocksize: if a
            layout is given, the xFormers attention is configured as "blocksparse".

    Raises:
        ImportError: if an xFormers config is given but xFormers is unavailable.
    """
    super().__init__()
    xformers_att_config = eval_str_dict(xformers_att_config)
    if xformers_att_config is not None:
        # Work on a private copy: keys are added below, and a dict passed in
        # by the caller (possibly shared across layers) must not be mutated.
        xformers_att_config = dict(xformers_att_config)
    self.use_xformers = xformers_att_config is not None
    if self.use_xformers and not _xformers_available:
        raise ImportError("\n\n  Please install xFormers.")
    self.embed_dim = embed_dim
    self.kdim = kdim if kdim is not None else embed_dim
    self.vdim = vdim if vdim is not None else embed_dim
    self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
    self.num_heads = num_heads
    self.dropout_module = FairseqDropout(
        dropout, module_name=self.__class__.__name__
    )
    self.head_dim = embed_dim // num_heads
    assert (
        self.head_dim * num_heads == self.embed_dim
    ), "embed_dim must be divisible by num_heads"
    self.scaling = self.head_dim**-0.5
    self.self_attention = self_attention
    self.encoder_decoder_attention = encoder_decoder_attention
    assert not self.self_attention or self.qkv_same_dim, (
        "Self-attention requires query, key and " "value to be of the same size"
    )
    self.k_proj = quant_noise(
        nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
    )
    self.v_proj = quant_noise(
        nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
    )
    self.q_proj = quant_noise(
        nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
    )
    self.out_proj = quant_noise(
        nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
    )
    if add_bias_kv:
        # Uninitialized here; filled by reset_parameters() below.
        self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
        self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
    else:
        self.bias_k = self.bias_v = None
    self.add_zero_attn = add_zero_attn
    self.beam_size = 1
    self.reset_parameters()
    if self.use_xformers:
        xformers_att_config["dropout"] = xformers_att_config.get("dropout", dropout)
        xformers_att_config["num_heads"] = xformers_att_config.get(
            "num_heads", num_heads
        )
        if xformers_blocksparse_layout is not None:
            # Could be part of a single config passed only once
            xformers_att_config["block_size"] = xformers_blocksparse_blocksize
            xformers_att_config["layout"] = xformers_blocksparse_layout
            xformers_att_config["name"] = "blocksparse"
        self.attention = build_attention(xformers_att_config)
    self.onnx_trace = False
    self.skip_embed_dim_check = False
def prepare_for_onnx_export_(self):
    """Switch this module to its ONNX-export code path (sets ``onnx_trace``)."""
    self.onnx_trace = True
def reset_parameters(self):
    """Re-initialize projection weights and the optional key/value biases.

    q/k/v weights get Xavier-uniform init. The gain is 1/sqrt(2) when the q/k/v
    widths match, because convergence was empirically observed to be much
    better with that scaling. The out projection gets plain Xavier-uniform
    with its bias zeroed, and ``bias_k``/``bias_v`` (when present) get
    Xavier-normal. Initialization order is k, v, q, out, bias_k, bias_v.
    """
    qkv_gain = 1 / math.sqrt(2) if self.qkv_same_dim else 1.0
    for proj in (self.k_proj, self.v_proj, self.q_proj):
        nn.init.xavier_uniform_(proj.weight, gain=qkv_gain)
    nn.init.xavier_uniform_(self.out_proj.weight)
    if self.out_proj.bias is not None:
        nn.init.constant_(self.out_proj.bias, 0.0)
    for extra_bias in (self.bias_k, self.bias_v):
        if extra_bias is not None:
            nn.init.xavier_normal_(extra_bias)
def _get_reserve_head_index(self, num_heads_to_keep: int):
    """Rank heads by the L1 magnitude of their q/k/v slices and keep the top ones.

    A head's score is the sum of absolute weight and bias values in its
    ``head_dim`` rows of k_proj, q_proj and v_proj (added in that order).

    Args:
        num_heads_to_keep: how many of the highest-scoring heads to return;
            must not exceed ``self.num_heads`` (IndexError otherwise).

    Returns:
        list of ``(start, end)`` row ranges, highest-scoring head first.
    """
    def _slice_l1(proj, lo, hi):
        weight_part = torch.sum(torch.abs(proj.weight[lo:hi])).tolist()
        bias_part = torch.sum(torch.abs(proj.bias[lo:hi])).tolist()
        return weight_part + bias_part

    spans = [
        (head * self.head_dim, (head + 1) * self.head_dim)
        for head in range(self.num_heads)
    ]
    k_scores = [_slice_l1(self.k_proj, lo, hi) for lo, hi in spans]
    q_scores = [_slice_l1(self.q_proj, lo, hi) for lo, hi in spans]
    v_scores = [_slice_l1(self.v_proj, lo, hi) for lo, hi in spans]
    totals = [ks + qs + vs for ks, qs, vs in zip(k_scores, q_scores, v_scores)]
    ranked = sorted(range(self.num_heads), key=lambda h: totals[h], reverse=True)
    return [spans[ranked[rank]] for rank in range(num_heads_to_keep)]
def _adaptive_prune_heads(self, reserve_head_index: List[Tuple[int, int]]):
    """Shrink the projections in place to keep only the given head row ranges.

    q/k/v weights and biases keep the listed rows. out_proj keeps the listed
    columns of its weight; its bias and ``in_features`` are left unchanged.
    Every replacement becomes a fresh, detached ``nn.Parameter`` that requires
    grad. Afterwards ``num_heads``, ``embed_dim`` and the q/k/v
    ``out_features`` reflect the kept heads.

    Args:
        reserve_head_index: ``(start, end)`` row ranges, e.g. from
            ``_get_reserve_head_index``.
    """
    def _gather(tensor, dim=0):
        if dim == 0:
            pieces = [tensor[lo:hi] for lo, hi in reserve_head_index]
        else:
            pieces = [tensor[:, lo:hi] for lo, hi in reserve_head_index]
        kept = torch.cat(pieces, dim=dim).detach()
        kept.requires_grad = True
        return torch.nn.Parameter(kept)

    for proj in (self.q_proj, self.k_proj, self.v_proj):
        proj.weight = _gather(proj.weight)
        proj.bias = _gather(proj.bias)
    self.out_proj.weight = _gather(self.out_proj.weight, dim=-1)

    self.num_heads = len(reserve_head_index)
    self.embed_dim = self.head_dim * self.num_heads
    for proj in (self.q_proj, self.k_proj, self.v_proj):
        proj.out_features = self.embed_dim
def _set_skip_embed_dim_check(self):
    """Disable the query embed-dim assertion in ``forward`` (used after pruning)."""
    self.skip_embed_dim_check = True
def _pad_masks(
    self,
    key_padding_mask: Optional[Tensor],
    attn_mask: Optional[Tensor],
) -> Tuple[Optional[Tensor], Optional[Tensor]]:
    """Append one zero (unmasked) column along the last dim of each given mask.

    Used after an extra key position has been appended, so the masks keep
    matching the key length. None masks pass through unchanged.

    Returns:
        ``(key_padding_mask, attn_mask)``.
    """
    if attn_mask is not None:
        extra_attn = attn_mask.new_zeros(attn_mask.size()[:-1] + torch.Size([1]))
        attn_mask = torch.cat([attn_mask, extra_attn], dim=-1)
    if key_padding_mask is not None:
        extra_kpm = key_padding_mask.new_zeros(
            key_padding_mask.size()[:-1] + torch.Size([1])
        )
        key_padding_mask = torch.cat([key_padding_mask, extra_kpm], dim=-1)
    return key_padding_mask, attn_mask
def _add_bias(
    self,
    k: Tensor,
    v: Tensor,
    key_padding_mask: Optional[Tensor],
    attn_mask: Optional[Tensor],
    bsz: int,
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
    """Append the learned ``bias_k``/``bias_v`` as one extra key/value position.

    The biases are repeated ``bsz`` times along dim 1 and concatenated along
    dim 0. The masks are then extended by one unmasked column via ``_pad_masks``.

    Returns:
        ``(k, v, key_padding_mask, attn_mask)``.
    """
    assert self.bias_k is not None
    assert self.bias_v is not None
    k_ext = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
    v_ext = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
    padded_kpm, padded_attn = self._pad_masks(
        key_padding_mask=key_padding_mask, attn_mask=attn_mask
    )
    return k_ext, v_ext, padded_kpm, padded_attn
def _append_zero_attn(
    self,
    k: Tensor,
    v: Tensor,
    key_padding_mask: Optional[Tensor],
    attn_mask: Optional[Tensor],
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
    """Append one all-zero position to ``k`` and ``v`` along dim -2.

    The zero block's shape is derived from ``k`` and used for both tensors.
    The masks are extended by one unmasked column via ``_pad_masks``.

    Returns:
        ``(k, v, key_padding_mask, attn_mask)``.
    """
    pad_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:]
    k_zeros = torch.zeros(pad_shape, dtype=k.dtype, device=k.device)
    v_zeros = torch.zeros(pad_shape, dtype=v.dtype, device=v.device)
    k_out = torch.cat([k, k_zeros], dim=-2)
    v_out = torch.cat([v, v_zeros], dim=-2)
    padded_kpm, padded_attn = self._pad_masks(
        key_padding_mask=key_padding_mask, attn_mask=attn_mask
    )
    return k_out, v_out, padded_kpm, padded_attn
def _xformers_attn_forward(
    self,
    query,
    key: Optional[Tensor],
    value: Optional[Tensor],
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    """Attention computed by the xFormers backend in ``self.attention``.

    Args:
        query: (tgt_len, bsz, embed_dim).
        key, value: key/value inputs (time-first). They are replaced by
            ``query`` for self-attention, and ``value`` is replaced by ``key``
            for encoder-decoder attention.
        key_padding_mask: (bsz, src_len) mask over keys.
        need_weights: ignored; attention weights are not returned.
        attn_mask: optional attention mask.

    Returns:
        ``(output, None)`` with output of shape (tgt_len, bsz, embed_dim).
    """
    tgt_len, bsz, embed_dim = query.size()
    if self.self_attention:
        key = query
        value = query
    elif self.encoder_decoder_attention:
        value = key
    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        # The padding mask covers keys, so compare with the key length
        # (previously compared with tgt_len, which is wrong for cross-attention).
        src_len = key.size(0) if key is not None else tgt_len
        assert key_padding_mask.size(1) == src_len
    q = self.q_proj(query)
    k = self.k_proj(key)
    v = self.v_proj(value)
    if self.bias_k is not None:
        assert self.bias_v is not None
        # _add_bias takes and returns (key_padding_mask, attn_mask) in this order.
        k, v, key_padding_mask, attn_mask = self._add_bias(
            k, v, key_padding_mask, attn_mask, bsz
        )
    def fold_heads(x):
        return (
            x.contiguous()
            .view(-1, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )
    def split_heads(x):
        return (
            x.contiguous()
            .view(-1, bsz, self.num_heads, self.head_dim)
            .transpose(0, 1)
            .transpose(1, 2)
        )
    massage = split_heads if self.attention.requires_head_dimension else fold_heads
    q = massage(q)
    if k is not None:
        k = massage(k)
    if v is not None:
        v = massage(v)
    if self.add_zero_attn:
        k, v, key_padding_mask, attn_mask = self._append_zero_attn(
            k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask
        )
    kwargs = {}
    if attn_mask is not None and self.attention.supports_attention_mask:
        attn_mask = _mask_for_xformers(attn_mask, to_dtype=q.dtype)
        kwargs["att_mask"] = attn_mask
    if key_padding_mask is not None:
        key_padding_mask = _mask_for_xformers(key_padding_mask, to_dtype=torch.bool)
        if not self.attention.requires_separate_masks:
            attn_mask = maybe_merge_masks(
                attn_mask,
                key_padding_mask,
                batch_size=bsz,
                src_len=k.size(-2),
                tgt_len=q.size(-2),
                num_heads=self.num_heads,
            )
            key_padding_mask = None
            kwargs["att_mask"] = attn_mask
        if self.attention.supports_key_padding_mask:
            kwargs["key_padding_mask"] = key_padding_mask
    y = self.attention(q, k, v, **kwargs)
    y = (
        y.view(bsz, self.num_heads, tgt_len, self.head_dim)
        .transpose(1, 2)
        .flatten(start_dim=2, end_dim=3)
        .transpose(0, 1)
    )
    assert list(y.size()) == [tgt_len, bsz, embed_dim]
    # Dropout not needed because already applied in attention.
    # It is applied to the attention weights before matmul with v.
    y = self.out_proj(y)
    # TODO: support returning attention weights if needed.
    return y, None
def forward(
    self,
    query,
    key: Optional[Tensor],
    value: Optional[Tensor],
    key_padding_mask: Optional[Tensor] = None,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    need_weights: bool = True,
    static_kv: bool = False,
    attn_mask: Optional[Tensor] = None,
    before_softmax: bool = False,
    need_head_weights: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
    """Input shape: Time x Batch x Channel

    Args:
        query (Tensor): ``(tgt_len, batch, embed_dim)``.
        key, value (Tensor, optional): ``(src_len, batch, *)``; both may be
            ``None`` when cached static keys/values are reused.
        key_padding_mask (ByteTensor, optional): mask to exclude
            keys that are pads, of shape `(batch, src_len)`, where
            padding elements are indicated by 1s.
        incremental_state (dict, optional): per-module cache used for
            incremental decoding (read/written under ``"attn_state"``).
        need_weights (bool, optional): return the attention weights,
            averaged over heads (default: True).
        static_kv (bool, optional): when a cached ``prev_key`` exists, reuse
            the cached keys/values instead of appending new ones
            (default: False).
        attn_mask (ByteTensor, optional): typically used to
            implement causal attention, where the mask prevents the
            attention from looking forward in time (default: None).
        before_softmax (bool, optional): return the raw attention
            weights and values before the attention softmax.
        need_head_weights (bool, optional): return the attention
            weights for each head. Implies *need_weights*. Default:
            return the average attention weights over all heads.

    Returns:
        ``(attn, attn_weights)`` where ``attn`` is ``(tgt_len, batch, embed_dim)``
        and ``attn_weights`` is ``None`` unless *need_weights*; it is
        ``(num_heads, batch, tgt_len, src_len)`` with *need_head_weights*, else
        ``(batch, tgt_len, src_len)``. With *before_softmax* the raw scores and
        ``v`` are returned instead. The fast paths (xformers /
        ``F.multi_head_attention_forward``) return their own results.
    """
    if need_head_weights:
        need_weights = True

    is_tpu = query.device.type == "xla"

    tgt_len, bsz, embed_dim = query.size()
    src_len = tgt_len
    if not self.skip_embed_dim_check:
        assert (
            embed_dim == self.embed_dim
        ), f"query dim {embed_dim} != {self.embed_dim}"
    assert list(query.size()) == [tgt_len, bsz, embed_dim]
    if key is not None:
        src_len, key_bsz, _ = key.size()
        if not torch.jit.is_scripting():
            assert value is not None
            # Compare as a tuple: the previous `assert src_len, key_bsz == ...`
            # only checked that src_len was truthy and never compared shapes.
            assert (src_len, key_bsz) == tuple(value.shape[:2]), (
                f"key/value leading dims differ: "
                f"{tuple(key.shape[:2])} vs {tuple(value.shape[:2])}"
            )

    if (
        not self.onnx_trace
        and not is_tpu  # don't use PyTorch version on TPUs
        and incremental_state is None
        and not static_kv
        # A workaround for quantization to work. Otherwise JIT compilation
        # treats bias in linear module as method.
        and not torch.jit.is_scripting()
        # The Multihead attention implemented in pytorch forces strong dimension check
        # for input embedding dimention and K,Q,V projection dimension.
        # Since pruning will break the dimension check and it is not easy to modify the pytorch API,
        # it is preferred to bypass the pytorch MHA when we need to skip embed_dim_check
        and not self.skip_embed_dim_check
    ):
        assert key is not None and value is not None

        if self.use_xformers:
            return self._xformers_attn_forward(
                query, key, value, key_padding_mask, need_weights, attn_mask
            )

        else:
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout_module.p,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training or self.dropout_module.apply_during_inference,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )

    if incremental_state is not None:
        saved_state = self._get_input_buffer(incremental_state)
        if saved_state is not None and "prev_key" in saved_state:
            # previous time steps are cached - no need to recompute
            # key and value if they are static
            if static_kv:
                assert self.encoder_decoder_attention and not self.self_attention
                key = value = None
    else:
        saved_state = None

    if self.self_attention:
        q = self.q_proj(query)
        k = self.k_proj(query)
        v = self.v_proj(query)
    elif self.encoder_decoder_attention:
        # encoder-decoder attention
        q = self.q_proj(query)
        if key is None:
            assert value is None
            k = v = None
        else:
            if self.beam_size > 1 and bsz == key.size(1):
                # key is [T, bsz*beam_size, C], reduce to [T, bsz, C]
                key = key.view(key.size(0), -1, self.beam_size, key.size(2))[
                    :, :, 0, :
                ]
                if key_padding_mask is not None:
                    key_padding_mask = key_padding_mask.view(
                        -1, self.beam_size, key_padding_mask.size(1)
                    )[:, 0, :]
            k = self.k_proj(key)
            v = self.v_proj(key)

    else:
        assert key is not None and value is not None
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
    q *= self.scaling

    if self.bias_k is not None:
        assert self.bias_v is not None
        k, v, attn_mask, key_padding_mask = self._add_bias(
            k, v, attn_mask, key_padding_mask, bsz
        )

    # Fold heads into the batch dimension: (bsz * num_heads, len, head_dim).
    q = (
        q.contiguous()
        .view(tgt_len, bsz * self.num_heads, self.head_dim)
        .transpose(0, 1)
    )
    kv_bsz = bsz  # need default value for scripting
    if k is not None:
        kv_bsz = k.size(1)
        k = (
            k.contiguous()
            .view(-1, kv_bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )
    if v is not None:
        v = (
            v.contiguous()
            .view(-1, kv_bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )

    if saved_state is not None:
        # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
        if "prev_key" in saved_state:
            _prev_key = saved_state["prev_key"]
            assert _prev_key is not None
            kv_bsz = _prev_key.size(0)
            prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                k = prev_key
            else:
                assert k is not None
                k = torch.cat([prev_key, k], dim=1)
            src_len = k.size(1)
        if "prev_value" in saved_state:
            _prev_value = saved_state["prev_value"]
            assert _prev_value is not None
            assert kv_bsz == _prev_value.size(0)
            prev_value = _prev_value.view(
                kv_bsz * self.num_heads, -1, self.head_dim
            )
            if static_kv:
                v = prev_value
            else:
                assert v is not None
                v = torch.cat([prev_value, v], dim=1)
        prev_key_padding_mask: Optional[Tensor] = None
        if "prev_key_padding_mask" in saved_state:
            prev_key_padding_mask = saved_state["prev_key_padding_mask"]
        assert k is not None and v is not None
        key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
            key_padding_mask=key_padding_mask,
            prev_key_padding_mask=prev_key_padding_mask,
            batch_size=kv_bsz,
            src_len=k.size(1),
            static_kv=static_kv,
        )

        saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim)
        saved_state["prev_value"] = v.view(
            kv_bsz, self.num_heads, -1, self.head_dim
        )
        saved_state["prev_key_padding_mask"] = key_padding_mask
        # In this branch incremental_state is never None
        assert incremental_state is not None
        incremental_state = self._set_input_buffer(incremental_state, saved_state)
    assert k is not None
    assert k.size(1) == src_len

    # This is part of a workaround to get around fork/join parallelism
    # not supporting Optional types.
    if key_padding_mask is not None and key_padding_mask.dim() == 0:
        key_padding_mask = None

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == kv_bsz
        assert key_padding_mask.size(1) == src_len

    if self.add_zero_attn:
        assert v is not None
        src_len += 1
        k, v, key_padding_mask, attn_mask = self._append_zero_attn(
            k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask
        )

    # When keys were reduced to one copy per sentence (beamable path above),
    # broadcast them across the beam with einsum instead of bmm.
    if self.encoder_decoder_attention and bsz != kv_bsz:
        attn_weights = torch.einsum(
            "bxhtd,bhsd->bxhts",
            q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]),
            k.view((kv_bsz, self.num_heads) + k.size()[1:]),
        )
        attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:])
    else:
        attn_weights = torch.bmm(q, k.transpose(1, 2))
    attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

    assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
        if self.onnx_trace:
            attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
        attn_weights += attn_mask

    if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        if not is_tpu:
            attn_weights = attn_weights.view(
                kv_bsz, -1, self.num_heads, tgt_len, src_len
            )
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .to(torch.bool),
                float("-inf"),
            )
        else:
            attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
            attn_weights = attn_weights.transpose(0, 2)
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    if before_softmax:
        return attn_weights, v

    attn_weights_float = softmax(
        attn_weights, dim=-1, onnx_trace=self.onnx_trace
    )
    attn_weights = attn_weights_float.type_as(attn_weights)
    attn_probs = self.dropout_module(attn_weights)

    assert v is not None
    if self.encoder_decoder_attention and bsz != kv_bsz:
        attn = torch.einsum(
            "bxhts,bhsd->bxhtd",
            attn_probs.view(
                (
                    kv_bsz,
                    -1,
                    self.num_heads,
                )
                + attn_probs.size()[1:]
            ),
            v.view(
                (
                    kv_bsz,
                    self.num_heads,
                )
                + v.size()[1:]
            ),
        )
        attn = attn.reshape((-1,) + attn.size()[-2:])
    else:
        attn = torch.bmm(attn_probs, v)
    assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
    if self.onnx_trace and attn.size(1) == 1:
        # when ONNX tracing a single decoder step (sequence length == 1)
        # the transpose is a no-op copy before view, thus unnecessary
        attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim)
    else:
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
    attn = self.out_proj(attn)
    attn_weights: Optional[Tensor] = None
    if need_weights:
        attn_weights = attn_weights_float.view(
            bsz, self.num_heads, tgt_len, src_len
        ).transpose(1, 0)
        if not need_head_weights:
            # average attention weights over heads
            attn_weights = attn_weights.mean(dim=0)

    return attn, attn_weights
@staticmethod
def _append_prev_key_padding_mask(
    key_padding_mask: Optional[Tensor],
    prev_key_padding_mask: Optional[Tensor],
    batch_size: int,
    src_len: int,
    static_kv: bool,
) -> Optional[Tensor]:
    """Merge the cached key-padding mask with the mask for the current step.

    Masks are ``(batch, seq_len)``. Non-static results are returned as float
    tensors. When one side is missing, the gap up to *src_len* is filled with
    zeros (i.e. "not padding"). The filler is appended after a cached mask and
    prepended before a fresh one.

    Returns:
        The merged mask, or ``None`` when neither mask is given.
    """
    have_prev = prev_key_padding_mask is not None
    have_cur = key_padding_mask is not None

    # Static keys: the cached mask already covers every key.
    if have_prev and static_kv:
        return prev_key_padding_mask

    if have_prev and have_cur:
        return torch.cat(
            [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
        )

    # During incremental decoding, as the padding token enters and leaves the
    # frame, there will be a time when only one of the two masks exists.
    if have_prev:
        gap = src_len - prev_key_padding_mask.size(1)
        if gap <= 0:
            return prev_key_padding_mask.float()
        trailing = torch.zeros(
            (batch_size, gap), device=prev_key_padding_mask.device
        )
        return torch.cat([prev_key_padding_mask.float(), trailing.float()], dim=1)

    if have_cur:
        gap = src_len - key_padding_mask.size(1)
        if gap <= 0:
            return key_padding_mask.float()
        leading = torch.zeros((batch_size, gap), device=key_padding_mask.device)
        return torch.cat([leading.float(), key_padding_mask.float()], dim=1)

    return None
@torch.jit.export
def reorder_incremental_state(
    self,
    incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
    new_order: Tensor,
):
    """Reorder buffered internal state (for incremental generation).

    Every cached tensor in this module's ``"attn_state"`` buffer is
    re-indexed along dim 0 with *new_order*. Encoder-decoder attention gets
    special handling because its cache may hold one row per sentence rather
    than one per beam hypothesis.

    Args:
        incremental_state: the incremental-decoding cache.
        new_order: 1-D index tensor selecting the new row order.

    Returns:
        The (possibly updated) *incremental_state*.
    """
    input_buffer = self._get_input_buffer(incremental_state)
    if input_buffer is not None:
        for k in input_buffer.keys():
            input_buffer_k = input_buffer[k]
            if input_buffer_k is not None:
                if self.encoder_decoder_attention:
                    # The beamable path in forward() reduces encoder keys to one
                    # copy per sentence, so the cache has new_order.size(0) / beam_size
                    # rows; in that case nothing needs reordering and we return
                    # early (remaining keys are left as-is, no _set_input_buffer).
                    if input_buffer_k.size(0) * self.beam_size == new_order.size(0):
                        return incremental_state
                    elif self.beam_size > 1:
                        # Take the first hypothesis of each beam group and map it
                        # back to a sentence index (presumably new_order holds
                        # sentence * beam_size + hyp indices -- TODO confirm).
                        input_buffer[k] = input_buffer_k.index_select(
                            0,
                            new_order.reshape(-1, self.beam_size)[:, 0]
                            // self.beam_size,
                        )
                    else:
                        input_buffer[k] = input_buffer_k.index_select(0, new_order)
                else:
                    input_buffer[k] = input_buffer_k.index_select(0, new_order)
        incremental_state = self._set_input_buffer(incremental_state, input_buffer)
    return incremental_state
def set_beam_size(self, beam_size):
    """Store *beam_size* for the beam-aware ("beamable") encoder-decoder attention.

    ``forward`` and ``reorder_incremental_state`` read ``self.beam_size`` to
    decide whether encoder keys can be shared across beam hypotheses.
    """
    setattr(self, "beam_size", beam_size)
def _get_input_buffer(
    self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
) -> Dict[str, Optional[Tensor]]:
    """Fetch this module's ``"attn_state"`` cache, or a new empty dict if absent."""
    cached = self.get_incremental_state(incremental_state, "attn_state")
    if cached is None:
        # Typed empty dict so TorchScript can infer the value type.
        fresh: Dict[str, Optional[Tensor]] = {}
        return fresh
    return cached
def _set_input_buffer(
    self,
    incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
    buffer: Dict[str, Optional[Tensor]],
):
    """Store *buffer* as this module's ``"attn_state"`` cache entry.

    Returns whatever ``self.set_incremental_state`` returns; callers in this
    file rebind ``incremental_state`` to that result.
    """
    slot = "attn_state"
    return self.set_incremental_state(incremental_state, slot, buffer)
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
    """Return *attn_weights* unchanged (same object).

    This base implementation does no sparsification; presumably it exists as
    an override point for subclasses, which receive the target/source lengths
    and batch size to reshape as needed.
    """
    untouched = attn_weights
    return untouched
def upgrade_state_dict_named(self, state_dict, name):
    """Split legacy fused ``in_proj_*`` tensors into separate q/k/v entries, in place.

    Any key ending in ``<name>.in_proj_weight`` is treated as three projections
    stacked along dim 0. It is replaced by ``<name>.q_proj.weight``,
    ``<name>.k_proj.weight`` and ``<name>.v_proj.weight``. If
    ``<name>.in_proj_bias`` exists, it is split the same way into ``*.bias``
    entries. An empty *name* means no prefix.
    """
    prefix = "" if name == "" else name + "."
    fused_bias_key = prefix + "in_proj_bias"
    projections = ("q_proj", "k_proj", "v_proj")

    additions = {}
    obsolete = []
    for key in state_dict.keys():
        if not key.endswith(prefix + "in_proj_weight"):
            continue
        fused_weight = state_dict[key]
        # Each projection owns one third of the stacked rows.
        third = int(fused_weight.shape[0] / 3)
        spans = ((0, third), (third, 2 * third), (2 * third, None))
        for proj, (lo, hi) in zip(projections, spans):
            additions[prefix + proj + ".weight"] = fused_weight[lo:hi]
        obsolete.append(key)

        if fused_bias_key in state_dict.keys():
            fused_bias = state_dict[fused_bias_key]
            for proj, (lo, hi) in zip(projections, spans):
                additions[prefix + proj + ".bias"] = fused_bias[lo:hi]
            obsolete.append(fused_bias_key)

    for key in obsolete:
        del state_dict[key]
    for new_key, tensor in additions.items():
        state_dict[new_key] = tensor
@dataclass
class QuantNoiseConfig:
    """Quantization-noise settings: ``pq``, ``pq_block_size`` and ``scalar``.

    NOTE(review): presumably ``pq`` / ``pq_block_size`` are the ``p`` /
    ``block_size`` arguments of ``quant_noise`` defined at the top of this
    file -- TODO confirm against the encoder/decoder builders.
    """
    _name: str = "transformer"
    pq: float = 0.0
    pq_block_size: int = 8
    scalar: float = 0.0

    def to_dict(self) -> dict:
        """Return the fields as a plain ``dict`` (via ``dataclasses.asdict``)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data):
        """Build an instance from a mapping of field name to value (``cls(**data)``)."""
        return cls(**data)
@dataclass
class EncDecBaseConfig:
    """Hyper-parameters for one transformer stack (used for the encoder).

    NOTE: ``padding_idx`` and ``vocab_size`` at the bottom carry no type
    annotation, so ``@dataclass`` treats them as plain class attributes, not
    fields. They are not ``__init__`` parameters and are not included in
    ``dataclasses.asdict``. ``DecoderConfig`` overrides ``vocab_size`` the
    same way.
    """
    _name: str = "transformer"
    embed_path: Optional[str] = None
    embed_dim: int = 768
    ffn_embed_dim: int = 3072
    layers: int = 12
    attention_heads: int = 12
    normalize_before: bool = False
    learned_pos: bool = False
    layerdrop: float = 0.0
    layers_to_keep: Optional[list[int]] = None
    xformers_att_config: Optional[dict] = None
    # default_factory gives every instance its own QuantNoiseConfig object.
    quant_noise: QuantNoiseConfig = field(default_factory=QuantNoiseConfig)
    # Unannotated class attributes; annotating them would turn them into
    # dataclass fields and change the generated __init__.
    padding_idx= 1
    vocab_size = 64001
@dataclass
class DecoderConfig(EncDecBaseConfig):
    """Decoder-stack hyper-parameters: the base fields plus input/output dims.

    ``vocab_size`` is re-bound as an unannotated class attribute (not a
    dataclass field), shadowing ``EncDecBaseConfig.vocab_size``.
    NOTE(review): 528 is presumably the size of the target/label vocabulary
    -- TODO confirm.
    """
    input_dim: int = 768
    output_dim: int = 768
    vocab_size = 528
@dataclass
class TransformerConfig:
    """Top-level encoder-decoder configuration (fairseq ``TransformerConfig`` layout).

    ``encoder`` / ``decoder`` hold the per-stack settings. Sub-config fields
    use ``default_factory`` so each instance gets its own objects.
    """
    _name: str = "transformer"
    activation_fn: str = "relu"
    dropout: float = 0.1
    attention_dropout: float = 0.1
    activation_dropout: float = 0.0
    adaptive_input: bool = False
    encoder: EncDecBaseConfig = field(default_factory=EncDecBaseConfig)
    max_source_positions: int = 1024
    decoder: DecoderConfig = field(default_factory=DecoderConfig)
    max_target_positions: int = 1024
    share_decoder_input_output_embed: bool = True
    share_all_embeddings: bool = False
    no_token_positional_embeddings: bool = False
    adaptive_softmax_cutoff: Optional[list[int]] = None
    adaptive_softmax_dropout: float = 0.0
    adaptive_softmax_factor: int = 4
    layernorm_embedding: bool = False
    tie_adaptive_weights: bool = False
    tie_adaptive_proj: bool = False
    no_scale_embedding: bool = False
    checkpoint_activations: bool = False
    offload_activations: bool = False
    no_cross_attention: bool = False
    cross_self_attention: bool = False
    quant_noise: QuantNoiseConfig = field(default_factory=QuantNoiseConfig)
    min_params_to_wrap: int = 100_000_000
    char_inputs: bool = False
    relu_dropout: float = 0.0
    base_layers: int = 0
    base_sublayers: int = 1
    base_shuffle: int = 1
    export: bool = False
    no_decoder_final_norm: bool = False
# Module-level default architecture config. The embeddings/dictionaries below
# and AfroLidForSequenceClassification (self.cfg) are built from it.
main_config = TransformerConfig()
class TokenEmbedding(nn.Module):
    """Thin ``nn.Module`` wrapper around :class:`torch.nn.Embedding`.

    The constructor arguments are mirrored as the plain attributes
    ``vocab_size``, ``embedding_dim`` and ``padding_idx``. The lookup itself
    is delegated to ``self.embedding``.
    """

    def __init__(self, vocab_size, embed_dim, padding_idx):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=padding_idx,
        )
        self.vocab_size, self.embedding_dim, self.padding_idx = (
            vocab_size,
            embed_dim,
            padding_idx,
        )

    def forward(self, input_tokens):
        """Embed integer token ids; output has a trailing ``embedding_dim`` axis."""
        embedded = self.embedding(input_tokens)
        return embedded
def initialize_embed_tokens(cfg, model='encoder'):
    """
    Build the token-embedding layer for the encoder or the decoder.

    Args:
        cfg: Configuration object with ``encoder`` and ``decoder`` sub-configs,
            each exposing ``vocab_size``, ``embed_dim`` and ``padding_idx``.
        model: ``'encoder'`` (default) or ``'decoder'``; selects which
            sub-config the vocabulary size, embedding dim and padding index
            are read from.

    Returns:
        TokenEmbedding: wrapper around a freshly initialised ``nn.Embedding``.

    Raises:
        ValueError: if *model* is neither ``'encoder'`` nor ``'decoder'``.
    """
    if model == 'encoder':
        sub_cfg = cfg.encoder
    elif model == 'decoder':
        sub_cfg = cfg.decoder
    else:
        # Previously any other value silently produced a decoder embedding.
        raise ValueError(f"model must be 'encoder' or 'decoder', got {model!r}")
    # Read every dimension from the selected stack. The decoder embedding used
    # to take embed_dim/padding_idx from the encoder sub-config.
    return TokenEmbedding(sub_cfg.vocab_size, sub_cfg.embed_dim, sub_cfg.padding_idx)
class EncoderDecoderModel(nn.Module):
    """Standalone fairseq-style encoder-decoder model.

    Args:
        cfg: a ``TransformerConfig``-like object. ``cfg.encoder`` and
            ``cfg.decoder`` supply the vocabulary size, embedding dim and
            padding index for the token embeddings.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # Build embeddings per instance from *cfg*. The module-level
        # encoder_embedding/decoder_embedding objects are built from main_config.
        # Reusing them made every instance share (and overwrite) one set of
        # embedding parameters, and ignored the sizes in *cfg*.
        encoder_embed = initialize_embed_tokens(cfg, 'encoder')
        decoder_embed = initialize_embed_tokens(cfg, 'decoder')
        # Placeholder dictionaries sized to each vocabulary (same contents as
        # the module-level ones). Presumably only their length is used by the
        # encoder/decoder -- TODO confirm.
        enc_dict = [9] * cfg.encoder.vocab_size
        dec_dict = [9] * cfg.decoder.vocab_size
        self.encoder = TransformerEncoderBase(cfg, enc_dict, encoder_embed.embedding)
        self.decoder = TransformerDecoderBase(cfg, dec_dict, decoder_embed.embedding)
        self.supports_align_args = True
        self._is_generation_fast = False

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Perform a forward pass.

        Args:
            src_tokens (LongTensor): Source tokens `(batch, src_len)`
            src_lengths (LongTensor): Source lengths `(batch)`
            prev_output_tokens (LongTensor): Previous decoder outputs `(batch, tgt_len)`

        Returns:
            Tuple: decoder output and additional info
        """
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths,
                                   **kwargs)
        decoder_out = self.decoder(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        return decoder_out

    def forward_decoder(self, prev_output_tokens, **kwargs):
        """Run only the decoder (extra kwargs are forwarded unchanged)."""
        return self.decoder(prev_output_tokens, **kwargs)

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return (self.encoder.max_positions(), self.decoder.max_positions())

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()
# Module-level token embeddings built from main_config. The encoder one is
# created first, so the two draw their random initialisation in this order.
# NOTE(review): these are single shared instances; any model that reuses them
# shares the same nn.Embedding parameters with every other such model.
encoder_embedding = initialize_embed_tokens(main_config)
decoder_embedding = initialize_embed_tokens(main_config, 'decoder')
# Placeholder "dictionaries": lists of length vocab_size filled with 9.
# Presumably only their length matters to the encoder/decoder -- TODO confirm.
enc_dictionary = [9]* main_config.encoder.vocab_size
dec_dictionary = [9] * main_config.decoder.vocab_size
class AfroLidForSequenceClassification(PreTrainedModel):
    """Hugging Face wrapper around the fairseq-style AfroLID encoder-decoder.

    The network architecture comes from the module-level ``main_config`` (a
    ``TransformerConfig``). *config* is the ``AfroLidConfig`` handed to
    ``PreTrainedModel``.
    """

    config_class = AfroLidConfig
    base_model_prefix = "transformer"

    def __init__(self, config):
        super().__init__(config)
        self.cfg = main_config
        # Build embeddings per instance. Reusing the module-level
        # encoder_embedding/decoder_embedding objects made every instance share
        # one set of embedding parameters, so loading weights into (or
        # fine-tuning) one model silently changed all others.
        encoder_embed = initialize_embed_tokens(self.cfg, 'encoder')
        decoder_embed = initialize_embed_tokens(self.cfg, 'decoder')
        # Placeholder dictionaries sized to each vocabulary (same contents as
        # the module-level ones). Presumably only their length is used by the
        # encoder/decoder -- TODO confirm.
        enc_dict = [9] * self.cfg.encoder.vocab_size
        dec_dict = [9] * self.cfg.decoder.vocab_size
        self.encoder = TransformerEncoderBase(self.cfg, enc_dict, encoder_embed.embedding)
        self.decoder = TransformerDecoderBase(self.cfg, dec_dict, decoder_embed.embedding)
        self.supports_align_args = True
        self._is_generation_fast = False

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Perform a forward pass.

        Args:
            src_tokens (LongTensor): Source tokens `(batch, src_len)`
            src_lengths (LongTensor): Source lengths `(batch)`
            prev_output_tokens (LongTensor): Previous decoder outputs `(batch, tgt_len)`

        Returns:
            Tuple: decoder output and additional info
        """
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        decoder_out = self.decoder(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        return decoder_out

    def forward_decoder(self, prev_output_tokens, **kwargs):
        """Run only the decoder (extra kwargs are forwarded unchanged)."""
        return self.decoder(prev_output_tokens, **kwargs)

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return (self.encoder.max_positions(), self.decoder.max_positions())

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()
config = AfroLidConfig()

# Cache for the lazily-built default model (see __getattr__ below).
_afrolid_model = None


def __getattr__(name):
    """Build ``afrolid_model`` lazily on first attribute access (PEP 562).

    The default model used to be instantiated at import time. That allocated
    a full encoder-decoder every time this module was imported, including
    when Hugging Face loads it only to fetch the classes. ``module.afrolid_model``
    and ``from module import afrolid_model`` keep working; the instance is
    created once and then reused.

    Raises:
        AttributeError: for any other missing module attribute.
    """
    global _afrolid_model
    if name == "afrolid_model":
        if _afrolid_model is None:
            _afrolid_model = AfroLidForSequenceClassification(config)
        return _afrolid_model
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


# Make the model discoverable through the transformers Auto* factories.
AutoConfig.register("afrolid", AfroLidConfig)
AutoModel.register(AfroLidConfig, AfroLidForSequenceClassification)
AutoModelForSequenceClassification.register(
    AfroLidConfig, AfroLidForSequenceClassification)