import math

import numpy as np
import torch
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer

# from cmib.model.positional_encoding import PositionalEmbedding


class SinPositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding added to the input embeddings."""

    def __init__(self, d_model, dropout=0.1, max_len=100):
        super(SinPositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute the (max_len, d_model) table of sine/cosine positions.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Reshape to (max_len, 1, d_model) so it broadcasts over the batch dimension.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # not used in the final model
        # x: (seq_len, batch, d_model)
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)
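
# A minimal usage sketch for SinPositionalEncoding (assumption: inputs are laid
# out as (seq_len, batch, d_model), which is what the pe[:x.shape[0]] indexing
# above expects; the sizes below are illustrative only):
#
#     pos_enc = SinPositionalEncoding(d_model=64, dropout=0.0, max_len=100)
#     x = torch.zeros(30, 2, 64)      # (seq_len, batch, d_model)
#     y = pos_enc(x)                  # same shape, sinusoidal positions added
#     assert y.shape == (30, 2, 64)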


class MultiHeadedAttention(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout=0.1,
                 pre_lnorm=True, bias=False):
        """
        Multi-headed attention with relative positional encoding and
        memory mechanism. Note: the relative positional encoding and memory
        paths are commented out in this version.

        Args:
            n_head (int): Number of heads.
            d_model (int): Input dimension.
            d_head (int): Head dimension.
            dropout (float, optional): Dropout value. Defaults to 0.1.
            pre_lnorm (bool, optional):
                Apply layer norm before rest of calculation. Defaults to True.
                In original Transformer paper (pre_lnorm=False):
                    LayerNorm(x + Sublayer(x))
                In tensor2tensor implementation (pre_lnorm=True):
                    x + Sublayer(LayerNorm(x))
            bias (bool, optional):
                Add bias to q, k, v and output projections. Defaults to False.
        """
        super(MultiHeadedAttention, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout
        self.pre_lnorm = pre_lnorm
        self.bias = bias

        # Note: scores are scaled by 1/sqrt(d_model) here rather than the more
        # common 1/sqrt(d_head).
        self.atten_scale = 1 / math.sqrt(self.d_model)

        self.q_linear = nn.Linear(d_model, n_head * d_head, bias=bias)
        self.k_linear = nn.Linear(d_model, n_head * d_head, bias=bias)
        self.v_linear = nn.Linear(d_model, n_head * d_head, bias=bias)
        self.out_linear = nn.Linear(n_head * d_head, d_model, bias=bias)

        self.dropout_layer = nn.Dropout(dropout)
        self.atten_dropout_layer = nn.Dropout(dropout)

        self.layer_norm = nn.LayerNorm(d_model)
    def forward(self, hidden, memory=None, mask=None,
                extra_atten_score=None):
        """
        Args:
            hidden (Tensor): Input embedding or hidden state of previous layer.
                Shape: (batch, seq, dim)
            memory (Tensor): Memory tensor of previous layer. Currently unused.
                Shape: (batch, mem_len, dim)
            mask (BoolTensor, optional): Attention mask.
                Set an item to True if you do NOT want to keep the
                corresponding attention score, otherwise False. Defaults to None.
                Shape: (seq, seq+mem_len).
        """
        combined = hidden
        # if memory is None:
        #     combined = hidden
        #     mem_len = 0
        # else:
        #     combined = torch.cat([memory, hidden], dim=1)
        #     mem_len = memory.shape[1]

        if self.pre_lnorm:
            hidden = self.layer_norm(hidden)
            combined = self.layer_norm(combined)

        # shape: (batch, q/k/v_len, dim)
        q = self.q_linear(hidden)
        k = self.k_linear(combined)
        v = self.v_linear(combined)

        # reshape to (batch, q/k/v_len, n_head, d_head)
        q = q.reshape(q.shape[0], q.shape[1], self.n_head, self.d_head)
        k = k.reshape(k.shape[0], k.shape[1], self.n_head, self.d_head)
        v = v.reshape(v.shape[0], v.shape[1], self.n_head, self.d_head)

        # transpose to (batch, n_head, q/k/v_len, d_head)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # add n_head dimension for relative positional embedding lookup table
        # (batch, n_head, k/v_len*2-1, d_head)
        # pos_emb = pos_emb[:, None]

        # (batch, n_head, q_len, k_len)
        atten_score = torch.matmul(q, k.transpose(-1, -2))
        # qpos = torch.matmul(q, pos_emb.transpose(-1, -2))
        # DEBUG
        # ones = torch.zeros(q.shape)
        # ones[:, :, :, 0] = 1.0
        # qpos = torch.matmul(ones, pos_emb.transpose(-1, -2))
        # atten_score = atten_score + self.skew(qpos, mem_len)
        atten_score = atten_score * self.atten_scale

        # if extra_atten_score is not None:
        #     atten_score = atten_score + extra_atten_score

        if mask is not None:
            # apply attention mask
            atten_score = atten_score.masked_fill(mask, float("-inf"))
        atten_score = atten_score.softmax(dim=-1)
        atten_score = self.atten_dropout_layer(atten_score)

        # (batch, n_head, q_len, d_head)
        atten_vec = torch.matmul(atten_score, v)
        # (batch, q_len, n_head*d_head)
        atten_vec = atten_vec.transpose(1, 2).flatten(start_dim=-2)

        # linear projection
        output = self.dropout_layer(self.out_linear(atten_vec))

        if self.pre_lnorm:
            return hidden + output
        else:
            return self.layer_norm(hidden + output)
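
# A minimal usage sketch for MultiHeadedAttention (assumption: batch-first
# input of shape (batch, seq, d_model), matching the reshape logic in
# forward(); the causal mask and sizes are illustrative only):
#
#     mha = MultiHeadedAttention(n_head=4, d_model=64, d_head=16)
#     x = torch.randn(2, 30, 64)                                   # (batch, seq, dim)
#     causal = torch.triu(torch.ones(30, 30, dtype=torch.bool), diagonal=1)
#     y = mha(x, mask=causal)                                      # (2, 30, 64)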


class FeedForward(nn.Module):
    def __init__(self, d_model, d_inner, dropout=0.1, pre_lnorm=True):
        """
        Positionwise feed-forward network.

        Args:
            d_model (int): Dimension of the input and output.
            d_inner (int): Dimension of the middle layer (bottleneck).
            dropout (float, optional): Dropout value. Defaults to 0.1.
            pre_lnorm (bool, optional):
                Apply layer norm before rest of calculation. Defaults to True.
                In original Transformer paper (pre_lnorm=False):
                    LayerNorm(x + Sublayer(x))
                In tensor2tensor implementation (pre_lnorm=True):
                    x + Sublayer(LayerNorm(x))
        """
        super(FeedForward, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout
        self.pre_lnorm = pre_lnorm

        self.layer_norm = nn.LayerNorm(d_model)
        self.network = nn.Sequential(
            nn.Linear(d_model, d_inner),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        if self.pre_lnorm:
            return x + self.network(self.layer_norm(x))
        else:
            return self.layer_norm(x + self.network(x))
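
# A minimal sketch of the two normalization variants of FeedForward
# (illustrative sizes only):
#
#     ff_pre  = FeedForward(d_model=64, d_inner=128, pre_lnorm=True)   # x + FF(LN(x))
#     ff_post = FeedForward(d_model=64, d_inner=128, pre_lnorm=False)  # LN(x + FF(x))
#     y = ff_pre(torch.randn(2, 30, 64))                               # (2, 30, 64)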


class TransformerModel(nn.Module):
    def __init__(
        self,
        seq_len: int,
        input_dim: int,
        d_model: int,
        nhead: int,
        d_hid: int,
        nlayers: int,
        dropout: float = 0.5,
        out_dim=91,
        masked_attention_stage=False,
    ):
        super().__init__()
        self.model_type = "Transformer"
        self.seq_len = seq_len
        self.d_model = d_model
        self.nhead = nhead
        self.d_hid = d_hid
        self.nlayers = nlayers
        self.pos_embedding = SinPositionalEncoding(d_model=d_model, dropout=0.1, max_len=seq_len)

        if masked_attention_stage:
            # An extra channel is appended to the input for the data (visibility) mask.
            self.input_layer = nn.Linear(input_dim + 1, d_model)
            # visible to invisible attention
            self.att_layers = nn.ModuleList()
            self.pff_layers = nn.ModuleList()
            self.pre_lnorm = True
            self.layer_norm = nn.LayerNorm(d_model)
            for i in range(self.nlayers):
                self.att_layers.append(
                    MultiHeadedAttention(
                        self.nhead, self.d_model,
                        self.d_model // self.nhead, dropout=dropout,
                        pre_lnorm=True,
                        bias=False
                    )
                )
                self.pff_layers.append(
                    FeedForward(
                        self.d_model, d_hid,
                        dropout=dropout,
                        pre_lnorm=True
                    )
                )
        else:
            self.att_layers = None
            self.input_layer = nn.Linear(input_dim, d_model)

        encoder_layers = TransformerEncoderLayer(
            d_model, nhead, d_hid, dropout, activation="gelu"
        )
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.decoder = nn.Linear(d_model, out_dim)
        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)
    def forward(self, src: Tensor, src_mask: Tensor, data_mask=None, atten_mask=None) -> Tensor:
        """
        Args:
            src: Tensor, shape [seq_len, batch_size, input_dim]
            src_mask: Tensor, shape [seq_len, seq_len] (accepted but not
                passed to the encoder in this implementation)
            data_mask: optional mask concatenated to src as an extra channel
                when the masked attention stage is enabled
            atten_mask: optional boolean attention mask for the masked
                attention layers
        Returns:
            output Tensor of shape [seq_len, batch_size, out_dim]
        """
        if data_mask is not None:
            src = torch.cat([src, data_mask.expand(*src.shape[:-1], data_mask.shape[-1])], dim=-1)
        src = self.input_layer(src)
        output = self.pos_embedding(src)
        # output = src
        if self.att_layers:
            assert atten_mask is not None
            # The masked attention layers expect batch-first input.
            output = output.permute(1, 0, 2)
            for i in range(self.nlayers):
                output = self.att_layers[i](output, mask=atten_mask)
                output = self.pff_layers[i](output)
            if self.pre_lnorm:
                output = self.layer_norm(output)
            output = output.permute(1, 0, 2)
        output = self.transformer_encoder(output)
        output = self.decoder(output)
        return output
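

if __name__ == "__main__":
    # A hedged smoke test, not part of the original module: it assumes d_model
    # is divisible by nhead and uses the (seq_len, batch, input_dim) layout
    # documented in TransformerModel.forward(). All sizes below are
    # illustrative only.
    seq_len, batch, input_dim = 30, 2, 91
    model = TransformerModel(
        seq_len=seq_len,
        input_dim=input_dim,
        d_model=64,
        nhead=4,
        d_hid=128,
        nlayers=2,
        dropout=0.1,
        out_dim=91,
        masked_attention_stage=False,
    )
    src = torch.randn(seq_len, batch, input_dim)
    out = model(src, src_mask=None)
    print(out.shape)  # expected: torch.Size([30, 2, 91])
    # With masked_attention_stage=True, forward() additionally requires
    # data_mask (an extra channel broadcast onto src) and a boolean atten_mask.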