|
|
|
|
|
|
|
|
|
""" |
|
Implementations of the following modules are borrowed from the ml-cvnets repo:
|
https://github.com/apple/ml-cvnets/blob/main/cvnets/layers/multi_head_attention.py |
|
https://github.com/apple/ml-cvnets/blob/main/cvnets/text_encoders/transformer.py |
|
|
|
Please see ACKNOWLEDGEMENTS for license details. |
|
""" |
|
|
|
import warnings
from typing import List, Optional, Union
|
|
|
import torch |
|
from torch import Size, Tensor, nn |
|
from torch.nn import functional as F |
|
from torchvision.ops import StochasticDepth |
|
|
|
|
|
class LayerNormFP32(nn.LayerNorm): |
|
""" |
|
    Applies `Layer Normalization <https://arxiv.org/abs/1607.06450>`_ over an input tensor in FP32 precision.
|
""" |
|
|
|
def __init__( |
|
self, |
|
normalized_shape: Union[int, List[int], Size], |
|
eps: Optional[float] = 1e-5, |
|
elementwise_affine: Optional[bool] = True, |
|
*args, |
|
**kwargs, |
|
): |
|
super().__init__( |
|
normalized_shape=normalized_shape, |
|
eps=eps, |
|
elementwise_affine=elementwise_affine, |
|
*args, |
|
**kwargs, |
|
) |
|
|
|
def forward(self, x: Tensor) -> Tensor: |
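        # Run LayerNorm in FP32 for numerical stability, then cast the result back to the input dtype.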
|
|
|
|
|
inp_dtype = x.dtype |
|
return super().forward(x.to(torch.float32)).to(inp_dtype) |
|
|
|
|
|
def get_normalization_layer(norm_type, num_features): |
|
if norm_type == "layer_norm": |
|
return nn.LayerNorm(num_features) |
|
elif norm_type == "layer_norm_fp32": |
|
return LayerNormFP32(num_features) |
|
else: |
|
raise NotImplementedError(f"Option: {norm_type} not supported.") |
|
|
|
|
|
class PositionalEmbedding(nn.Module): |
|
def __init__( |
|
self, |
|
num_embeddings: int, |
|
embedding_dim: int, |
|
padding_idx: Optional[int] = None, |
|
is_learnable: Optional[bool] = False, |
|
interpolation_mode: Optional[str] = "bilinear", |
|
*args, |
|
**kwargs, |
|
): |
|
super().__init__() |
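        # Note: only the learnable variant is wired up here; the `is_learnable`
        # flag is currently unused.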
|
|
|
module = LearnablePositionalEmbedding |
|
|
|
self.pos_embed = module( |
|
num_embeddings=num_embeddings, |
|
embedding_dim=embedding_dim, |
|
padding_idx=padding_idx, |
|
interpolation_mode=interpolation_mode, |
|
*args, |
|
**kwargs, |
|
) |
|
|
|
def forward(self, seq_len: int, *args, **kwargs) -> Tensor: |
|
return self.pos_embed(seq_len, *args, **kwargs) |
|
|
|
def __repr__(self): |
|
return self.pos_embed.__repr__() |
|
|
|
|
|
class LearnablePositionalEmbedding(nn.Module): |
|
"""Learnable Positional embedding""" |
|
|
|
def __init__( |
|
self, |
|
num_embeddings: int, |
|
embedding_dim: int, |
|
padding_idx: Optional[int] = None, |
|
interpolation_mode: Optional[str] = "bilinear", |
|
*args, |
|
**kwargs, |
|
): |
|
super().__init__() |
|
self.pos_embed = nn.Parameter(torch.empty(1, 1, num_embeddings, embedding_dim)) |
|
self.embedding_dim = embedding_dim |
|
self.num_embeddings = num_embeddings |
|
self.padding_idx = padding_idx |
|
self.interpolation_mode = interpolation_mode |
|
|
|
self.reset_parameters() |
|
|
|
def reset_parameters(self) -> None: |
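        # Truncated-normal initialization scaled by embedding_dim ** -0.5; the padding position, if any, is zeroed.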
|
nn.init.trunc_normal_(self.pos_embed, mean=0, std=self.embedding_dim**-0.5) |
|
if self.padding_idx is not None: |
|
with torch.no_grad(): |
|
self.pos_embed[:, :, self.padding_idx, ...] = 0.0 |
|
|
|
def forward(self, seq_len: int, *args, **kwargs) -> Tensor: |
|
|
|
pos_embed = self.pos_embed |
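        # Keep the embedding at padding_idx zeroed (in-place update, no gradient tracking).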
|
if self.padding_idx is not None: |
|
with torch.no_grad(): |
|
pos_embed[:, :, self.padding_idx, ...] = 0.0 |
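        # If the requested sequence length differs from the trained table size,
        # interpolate along the sequence dimension.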
|
|
|
if seq_len != self.num_embeddings: |
|
pos_embed = F.interpolate( |
|
pos_embed, |
|
size=(seq_len, self.embedding_dim), |
|
mode=self.interpolation_mode, |
|
) |
|
|
|
|
|
return pos_embed.reshape(1, seq_len, self.embedding_dim) |
|
|
|
def __repr__(self): |
|
return "{}(num_embeddings={}, embedding_dim={}, padding_idx={})".format( |
|
self.__class__.__name__, |
|
self.num_embeddings, |
|
self.embedding_dim, |
|
self.padding_idx, |
|
) |
|
|
|
|
|
class MultiHeadAttention(nn.Module): |
|
""" |
|
    This layer applies multi-head self- or cross-attention as described in the
    `Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper.
|
|
|
Args: |
|
embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, S, C_{in})` |
|
num_heads (int): Number of heads in multi-head attention |
|
attn_dropout (Optional[float]): Attention dropout. Default: 0.0 |
|
bias (Optional[bool]): Use bias or not. Default: ``True`` |
|
|
|
Shape: |
|
- Input: |
|
- Query tensor (x_q) :math:`(N, S, C_{in})` where :math:`N` is batch size, :math:`S` is number of source tokens, |
|
and :math:`C_{in}` is input embedding dim |
|
- Optional Key-Value tensor (x_kv) :math:`(N, T, C_{in})` where :math:`T` is number of target tokens |
|
- Output: same shape as the input |
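
    Example (illustrative usage; the sizes below are arbitrary)::

        >>> mha = MultiHeadAttention(embed_dim=512, num_heads=8)
        >>> x = torch.randn(2, 77, 512)  # (N, S, C_in)
        >>> out = mha(x)                 # (N, S, C_in)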
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
embed_dim: int, |
|
num_heads: int, |
|
attn_dropout: Optional[float] = 0.0, |
|
bias: Optional[bool] = True, |
|
output_dim: Optional[int] = None, |
|
*args, |
|
**kwargs, |
|
) -> None: |
|
if output_dim is None: |
|
output_dim = embed_dim |
|
super().__init__() |
|
        if embed_dim % num_heads != 0:
            raise ValueError(
                "Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
                    self.__class__.__name__, embed_dim, num_heads
                )
            )
|
|
|
self.qkv_proj = nn.Linear( |
|
in_features=embed_dim, out_features=3 * embed_dim, bias=bias |
|
) |
|
|
|
self.attn_dropout = nn.Dropout(p=attn_dropout) |
|
self.out_proj = nn.Linear( |
|
in_features=embed_dim, out_features=output_dim, bias=bias |
|
) |
|
|
|
self.head_dim = embed_dim // num_heads |
|
self.scaling = self.head_dim**-0.5 |
|
self.softmax = nn.Softmax(dim=-1) |
|
self.num_heads = num_heads |
|
self.embed_dim = embed_dim |
|
self.use_separate_proj_weight = embed_dim != output_dim |
|
|
|
def __repr__(self): |
|
return "{}(head_dim={}, num_heads={}, attn_dropout={})".format( |
|
self.__class__.__name__, self.head_dim, self.num_heads, self.attn_dropout.p |
|
) |
|
|
|
def _forward_impl( |
|
self, |
|
x_q: Tensor, |
|
x_kv: Optional[Tensor] = None, |
|
key_padding_mask: Optional[Tensor] = None, |
|
attn_mask: Optional[Tensor] = None, |
|
) -> Tensor: |
|
|
|
b_sz, S_len, in_channels = x_q.shape |
|
|
|
if x_kv is None: |
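            # Self-attention path: a single fused projection yields Q, K, and V.
            # [N, S, C] --> [N, S, 3C] --> [N, S, 3, h, c] --> [N, h, 3, S, c]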
|
|
|
|
|
qkv = self.qkv_proj(x_q).reshape(b_sz, S_len, 3, self.num_heads, -1) |
|
|
|
qkv = qkv.transpose(1, 3).contiguous() |
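            # Split into query, key, and value, each of shape [N, h, S, c].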
|
|
|
|
|
query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] |
|
else: |
|
T_len = x_kv.shape[1] |
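            # Cross-attention path: the query comes from x_q via the first `embed_dim`
            # rows of the fused QKV weight; key and value come from x_kv via the rest.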
|
|
|
|
|
|
|
query = F.linear( |
|
x_q, |
|
weight=self.qkv_proj.weight[: self.embed_dim, ...], |
|
bias=self.qkv_proj.bias[: self.embed_dim] |
|
if self.qkv_proj.bias is not None |
|
else None, |
|
) |
|
|
|
query = ( |
|
query.reshape(b_sz, S_len, self.num_heads, self.head_dim) |
|
.transpose(1, 2) |
|
.contiguous() |
|
) |
|
|
|
|
|
kv = F.linear( |
|
x_kv, |
|
weight=self.qkv_proj.weight[self.embed_dim :, ...], |
|
bias=self.qkv_proj.bias[self.embed_dim :] |
|
if self.qkv_proj.bias is not None |
|
else None, |
|
) |
|
|
|
kv = kv.reshape(b_sz, T_len, 2, self.num_heads, self.head_dim) |
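            # [N, T, 2, h, c] --> [N, h, 2, T, c]; key and value are each [N, h, T, c].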
|
|
|
kv = kv.transpose(1, 3).contiguous() |
|
key, value = kv[:, :, 0], kv[:, :, 1] |
|
|
|
query = query * self.scaling |
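        # Queries are pre-scaled by head_dim ** -0.5; attention scores are
        # [N, h, S, c] x [N, h, c, T] --> [N, h, S, T].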
|
|
|
|
|
key = key.transpose(-1, -2) |
|
|
|
|
|
|
|
attn = torch.matmul(query, key) |
|
|
|
batch_size, num_heads, num_src_tokens, num_tgt_tokens = attn.shape |
|
if attn_mask is not None: |
|
|
|
assert list(attn_mask.shape) == [ |
|
batch_size, |
|
num_src_tokens, |
|
num_tgt_tokens, |
|
], "Shape of attention mask should be [{}, {}, {}]. Got: {}".format( |
|
batch_size, num_src_tokens, num_tgt_tokens, attn_mask.shape |
|
) |
|
|
|
attn_mask = attn_mask.unsqueeze(1) |
|
attn = attn + attn_mask |
|
|
|
if key_padding_mask is not None: |
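            # Mask out padded key positions: [N, T] --> [N, 1, 1, T], broadcast over heads and queries.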
|
|
|
|
|
            assert key_padding_mask.dim() == 2 and list(key_padding_mask.shape) == [
                batch_size,
                num_tgt_tokens,
            ], "key_padding_mask should be 2-dimensional with shape [{}, {}]. Got: {}".format(
                batch_size, num_tgt_tokens, key_padding_mask.shape
            )
|
attn = attn.masked_fill( |
|
key_padding_mask.unsqueeze(1) |
|
.unsqueeze(2) |
|
.to(torch.bool), |
|
float("-inf"), |
|
) |
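        # Softmax in FP32 for numerical stability, then cast back to the original dtype.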
|
|
|
attn_dtype = attn.dtype |
|
attn_as_float = self.softmax(attn.float()) |
|
attn = attn_as_float.to(attn_dtype) |
|
attn = self.attn_dropout(attn) |
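        # Weighted sum over values: [N, h, S, T] x [N, h, T, c] --> [N, h, S, c]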
|
|
|
|
|
|
|
out = torch.matmul(attn, value) |
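        # Merge heads: [N, h, S, c] --> [N, S, h, c] --> [N, S, C], then apply the output projection.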
|
|
|
|
|
out = out.transpose(1, 2).reshape(b_sz, S_len, -1) |
|
out = self.out_proj(out) |
|
|
|
return out |
|
|
|
def forward( |
|
self, |
|
x_q: Tensor, |
|
x_kv: Optional[Tensor] = None, |
|
key_padding_mask: Optional[Tensor] = None, |
|
attn_mask: Optional[Tensor] = None, |
|
*args, |
|
**kwargs, |
|
) -> Tensor: |
|
|
|
return self._forward_impl( |
|
x_q=x_q, |
|
x_kv=x_kv, |
|
key_padding_mask=key_padding_mask, |
|
attn_mask=attn_mask, |
|
) |
|
|
|
|
|
class TransformerEncoder(nn.Module): |
|
""" |
|
    This class defines a pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_ block.
|
Args: |
|
embed_dim: :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`. |
|
ffn_latent_dim: Inner dimension of the FFN. |
|
num_heads: Number of heads in multi-head attention. Default: 8. |
|
attn_dropout: Dropout rate for attention in multi-head attention. Default: 0.0 |
|
dropout: Dropout rate. Default: 0.0. |
|
ffn_dropout: Dropout between FFN layers. Default: 0.0. |
|
transformer_norm_layer: Normalization layer. Default: layer_norm. |
|
        stochastic_dropout: Stochastic depth (drop path) rate. Default: 0.0.
|
|
|
Shape: |
|
- Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches, |
|
and :math:`C_{in}` is input embedding dim |
|
- Output: same shape as the input |
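
    Example (illustrative usage; the sizes below are arbitrary)::

        >>> block = TransformerEncoder(embed_dim=512, ffn_latent_dim=2048)
        >>> x = torch.randn(2, 196, 512)  # (N, P, C_in)
        >>> out = block(x)                # (N, P, C_in)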
|
""" |
|
|
|
def __init__( |
|
self, |
|
embed_dim: int, |
|
ffn_latent_dim: int, |
|
num_heads: Optional[int] = 8, |
|
attn_dropout: Optional[float] = 0.0, |
|
dropout: Optional[float] = 0.0, |
|
ffn_dropout: Optional[float] = 0.0, |
|
transformer_norm_layer: Optional[str] = "layer_norm", |
|
stochastic_dropout: Optional[float] = 0.0, |
|
*args, |
|
**kwargs, |
|
) -> None: |
|
|
|
super().__init__() |
|
|
|
|
|
attn_unit = MultiHeadAttention( |
|
embed_dim, |
|
num_heads, |
|
attn_dropout=attn_dropout, |
|
bias=True, |
|
) |
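        # Pre-norm attention block: LayerNorm -> MultiHeadAttention -> Dropout.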
|
|
|
self.pre_norm_mha = nn.Sequential( |
|
get_normalization_layer( |
|
norm_type=transformer_norm_layer, num_features=embed_dim |
|
), |
|
attn_unit, |
|
nn.Dropout(p=dropout), |
|
) |
|
|
|
act_name = nn.GELU() |
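        # Pre-norm feed-forward block: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.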
|
self.pre_norm_ffn = nn.Sequential( |
|
get_normalization_layer( |
|
norm_type=transformer_norm_layer, num_features=embed_dim |
|
), |
|
nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True), |
|
act_name, |
|
nn.Dropout(p=ffn_dropout), |
|
nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True), |
|
nn.Dropout(p=dropout), |
|
) |
|
|
|
self.drop_path = nn.Identity() |
|
if stochastic_dropout > 0.0: |
|
if dropout > 0.0: |
|
                warnings.warn(
                    "Stochastic dropout and dropout are mutually exclusive. "
                    "Use either of them, but not both. "
                    "Got: {} and {}".format(stochastic_dropout, dropout)
                )
|
self.drop_path = StochasticDepth(p=stochastic_dropout, mode="row") |
|
|
|
self.embed_dim = embed_dim |
|
self.ffn_dim = ffn_latent_dim |
|
self.ffn_dropout = ffn_dropout |
|
self.stochastic_dropout = stochastic_dropout |
|
self.std_dropout = dropout |
|
self.attn_fn_name = attn_unit.__class__.__name__ |
|
self.act_fn_name = act_name.__class__.__name__ |
|
self.norm_type = transformer_norm_layer |
|
|
|
def __repr__(self) -> str: |
|
return "{}(embed_dim={}, ffn_dim={}, dropout={}, ffn_dropout={}, stochastic_dropout={}, attn_fn={}, act_fn={}, norm_fn={})".format( |
|
self.__class__.__name__, |
|
self.embed_dim, |
|
self.ffn_dim, |
|
self.std_dropout, |
|
self.ffn_dropout, |
|
self.stochastic_dropout, |
|
self.attn_fn_name, |
|
self.act_fn_name, |
|
self.norm_type, |
|
) |
|
|
|
def forward( |
|
self, |
|
x: Tensor, |
|
x_prev: Optional[Tensor] = None, |
|
key_padding_mask: Optional[Tensor] = None, |
|
attn_mask: Optional[Tensor] = None, |
|
*args, |
|
**kwargs, |
|
) -> Tensor: |
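        # Multi-head attention sub-block with a pre-norm residual connection:
        # x = x + DropPath(Dropout(MHA(LN(x)))).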
|
|
|
|
|
res = x |
|
x = self.pre_norm_mha[0](x) |
|
x = self.pre_norm_mha[1]( |
|
x_q=x, |
|
x_kv=x_prev, |
|
key_padding_mask=key_padding_mask, |
|
attn_mask=attn_mask, |
|
*args, |
|
**kwargs, |
|
) |
|
|
|
x = self.drop_path(self.pre_norm_mha[2](x)) |
|
x = x + res |
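        # Feed-forward sub-block with a pre-norm residual connection.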
|
|
|
|
|
x = x + self.drop_path(self.pre_norm_ffn(x)) |
|
return x |
|
|