# minBERT / bert.py
# Author: GlowCheese
# Commit 9756d99: "First model version"
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from base_bert import BertPreTrainedModel
from utils import *
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input hidden states are projected into per-head query/key/value
    tensors. For every head the module computes
    softmax(Q K^T / sqrt(d_k) + mask) V, then concatenates the heads back
    into one vector per token.
    """

    def __init__(self, config):
        super().__init__()
        # The hidden size must split evenly across heads. Otherwise
        # all_head_size (= heads * (hidden_size // heads)) would be smaller
        # than hidden_size, and the attention output would not have the
        # hidden width expected by the rest of the model.
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_attention_heads ({config.num_attention_heads})"
            )
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Linear projections for query, key and value. The creation order is
        # kept as-is so parameter initialisation and state_dict layout do
        # not change.
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        # Dropout is applied to the normalized attention probabilities,
        # following the original Transformer implementation. This is a bit
        # unusual, but it was empirically observed to improve performance.
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transform(self, x, linear_layer):
        """Project ``x`` with ``linear_layer`` and split the result into heads.

        x: [bs, seq_len, hidden_size]
        returns: [bs, num_attention_heads, seq_len, attention_head_size]
        """
        bs, seq_len = x.shape[:2]
        proj = linear_layer(x)
        # Split the projected hidden dimension into
        # (num_attention_heads, attention_head_size).
        proj = proj.view(bs, seq_len, self.num_attention_heads, self.attention_head_size)
        # Move the head axis in front of the sequence axis.
        return proj.transpose(1, 2)

    def attention(self, key, query, value, attention_mask):
        """Compute scaled dot-product attention for every head.

        key, query, value: [batch_size, num_attention_heads, seq_len, attention_head_size]
        attention_mask: additive mask that broadcasts against the score
            tensor [batch_size, num_attention_heads, seq_len, seq_len].
            Expected shape is [batch_size, 1, 1, seq_len], holding 0 for
            tokens to attend to and a large negative value for padding.
        returns: [batch_size, seq_len, num_attention_heads * attention_head_size]
        """
        d_k = query.size(-1)  # attention_head_size
        # scores: [batch_size, num_attention_heads, seq_len, seq_len]
        scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(d_k)
        # Masked positions get a large negative score, so softmax gives them
        # (near-)zero weight.
        scores = scores + attention_mask
        probs = F.softmax(scores, dim=-1)
        probs = self.dropout(probs)
        # context: [batch_size, num_attention_heads, seq_len, attention_head_size]
        context = torch.matmul(probs, value)
        # Merge the heads back: [batch_size, seq_len, all_head_size].
        context = context.transpose(1, 2).contiguous()
        return context.view(context.size(0), context.size(1), -1)

    def forward(self, hidden_states, attention_mask):
        """Apply multi-head self-attention.

        hidden_states: [bs, seq_len, hidden_state]
        attention_mask: [bs, 1, 1, seq_len]
        output: [bs, seq_len, hidden_state]
        """
        # Each *_layer is [bs, num_attention_heads, seq_len, attention_head_size].
        key_layer = self.transform(hidden_states, self.key)
        value_layer = self.transform(hidden_states, self.value)
        query_layer = self.transform(hidden_states, self.query)
        return self.attention(key_layer, query_layer, value_layer, attention_mask)
class BertLayer(nn.Module):
    """One Transformer encoder layer.

    The layer has two sub-layers, each followed by a residual "add & norm":
      1. multi-head self-attention;
      2. a position-wise feed-forward network (dense -> GELU -> dense).
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        eps = config.layer_norm_eps
        p_drop = config.hidden_dropout_prob

        # Sub-layer 1: multi-head self-attention, then its add & norm
        # (output projection, dropout, LayerNorm).
        self.self_attention = BertSelfAttention(config)
        self.attention_dense = nn.Linear(hidden, hidden)
        self.attention_layer_norm = nn.LayerNorm(hidden, eps=eps)
        self.attention_dropout = nn.Dropout(p_drop)

        # Sub-layer 2: feed-forward expansion with GELU, then its add & norm
        # (projection back to hidden size, dropout, LayerNorm).
        self.interm_dense = nn.Linear(hidden, config.intermediate_size)
        self.interm_af = F.gelu
        self.out_dense = nn.Linear(config.intermediate_size, hidden)
        self.out_layer_norm = nn.LayerNorm(hidden, eps=eps)
        self.out_dropout = nn.Dropout(p_drop)

    def add_norm(self, input, output, dense_layer, dropout, ln_layer):
        """Residual add & norm: ``ln_layer(input + dropout(dense_layer(output)))``.

        input: the sub-layer's input, used as the residual branch.
        output: the sub-layer's raw output; it is projected by
            ``dense_layer`` and passed through ``dropout`` before the
            residual sum is normalized by ``ln_layer``.
        """
        return ln_layer(input + dropout(dense_layer(output)))

    def forward(self, hidden_states, attention_mask):
        """Run the attention and feed-forward sub-layers.

        hidden_states: [bs, seq_len, hidden_size]
        attention_mask: extended additive mask passed through to self-attention
        """
        attended = self.self_attention(hidden_states, attention_mask)
        attended = self.add_norm(
            hidden_states,
            attended,
            self.attention_dense,
            self.attention_dropout,
            self.attention_layer_norm,
        )
        expanded = self.interm_af(self.interm_dense(attended))
        return self.add_norm(
            attended,
            expanded,
            self.out_dense,
            self.out_dropout,
            self.out_layer_norm,
        )
class BertModel(BertPreTrainedModel):
    """
    The BERT model returns the final embeddings for each token in a sentence.
    The model consists of:
    1. Embedding layers (used in self.embed).
    2. A stack of n BERT layers (used in self.encode).
    3. A linear transformation layer for the [CLS] token (used in self.forward).
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Embedding layers.
        self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.pos_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.tk_type_embedding = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.embed_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.embed_dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids has shape [1, max_position_embeddings] and is constant,
        # so it is registered as a buffer rather than a parameter.
        position_ids = torch.arange(config.max_position_embeddings).unsqueeze(0)
        self.register_buffer('position_ids', position_ids)

        # BERT encoder.
        self.bert_layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

        # [CLS] token transformations (pooler).
        self.pooler_dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.pooler_af = nn.Tanh()

        self.init_weights()

    def embed(self, input_ids, token_type_ids=None):
        """Build input embeddings: word + position + token type, then LayerNorm and dropout.

        input_ids: [batch_size, seq_len]
        token_type_ids: optional tensor with the same shape as input_ids. When
            None, every token gets type 0 (the single-segment case).
        returns: [batch_size, seq_len, hidden_size]
        raises: ValueError if seq_len exceeds max_position_embeddings.
        """
        input_shape = input_ids.size()
        seq_length = input_shape[1]
        max_length = self.position_ids.size(1)
        # Without this check, an over-long input fails later with an opaque
        # broadcasting error when the embeddings are summed.
        if seq_length > max_length:
            raise ValueError(
                f"Input sequence length ({seq_length}) exceeds "
                f"max_position_embeddings ({max_length})"
            )

        inputs_embeds = self.word_embedding(input_ids)
        pos_ids = self.position_ids[:, :seq_length]
        pos_embeds = self.pos_embedding(pos_ids)

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)
        tk_type_embeds = self.tk_type_embedding(token_type_ids)

        embeddings = inputs_embeds + pos_embeds + tk_type_embeds
        embeddings = self.embed_layer_norm(embeddings)
        embeddings = self.embed_dropout(embeddings)
        return embeddings

    def encode(self, hidden_states, attention_mask):
        """Run the embeddings through the stack of BertLayers.

        hidden_states: the output from the embedding layer [batch_size, seq_len, hidden_size]
        attention_mask: [batch_size, seq_len]
        """
        # get_extended_attention_mask (from utils) is expected to return an
        # additive mask of size [batch_size, 1, 1, seq_len]: 0 for non-padding
        # tokens and a large negative number for padding tokens.
        extended_attention_mask: torch.Tensor = get_extended_attention_mask(attention_mask, self.dtype)

        # Each layer consumes the previous layer's output.
        for layer_module in self.bert_layers:
            hidden_states = layer_module(hidden_states, extended_attention_mask)
        return hidden_states

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Encode a batch of token ids.

        input_ids: [batch_size, seq_len], seq_len is the max length of the batch
        attention_mask: same size as input_ids, 1 represents non-padding tokens, 0 represents padding tokens
        token_type_ids: optional segment ids with the same size as input_ids; defaults to all zeros
        returns: dict with 'last_hidden_state' [batch_size, seq_len, hidden_size]
            and 'pooler_output' [batch_size, hidden_size] (tanh(dense(first token)))
        """
        embedding_output = self.embed(input_ids=input_ids, token_type_ids=token_type_ids)
        sequence_output = self.encode(embedding_output, attention_mask=attention_mask)

        # Pool the hidden state of the first ([CLS]) token.
        first_tk = sequence_output[:, 0]
        first_tk = self.pooler_dense(first_tk)
        first_tk = self.pooler_af(first_tk)
        return {'last_hidden_state': sequence_output, 'pooler_output': first_tk}