# Octa / modeling_octagon.py
# Georg4000 - "Create modeling_octagon.py" (commit e1386e9, verified)
import math

import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel
class OctagonConfig(PretrainedConfig):
    """Configuration for the Octagon encoder and its task heads.

    Stores the hyperparameters as attributes and forwards ``pad_token_id``
    plus any extra keyword arguments to ``PretrainedConfig``.

    Args:
        vocab_size: Number of rows in the word-embedding table.
        hidden_size: Width of the embeddings and of every encoder layer.
        num_hidden_layers: Number of stacked encoder layers.
        num_attention_heads: Number of attention heads per layer.
        intermediate_size: Width of the feed-forward expansion in each layer.
        hidden_act: Name of the feed-forward activation function.
        hidden_dropout_prob: Dropout probability for embeddings and the
            dense projections.
        attention_probs_dropout_prob: Dropout probability applied to the
            attention probabilities.
        max_position_embeddings: Number of rows in the position-embedding
            table.
        type_vocab_size: Number of rows in the token-type-embedding table.
        initializer_range: Stored on the config.
        layer_norm_eps: Epsilon passed to every ``nn.LayerNorm``.
        pad_token_id: Forwarded to ``PretrainedConfig``.
        position_embedding_type: Stored on the config.
        classifier_dropout: Dropout probability for the classification head;
            the head falls back to ``hidden_dropout_prob`` when this is
            ``None``.
        num_labels: Number of output labels. Ignored when an ``id2label``
            mapping is supplied in ``kwargs`` (the mapping then defines the
            label count).
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "octagon"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=8,  # Octagon has 8 sides!
        num_attention_heads=8,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        classifier_dropout=None,
        num_labels=2,
        **kwargs
    ):
        # Saved configs carry ``id2label`` but not ``num_labels``.  Assigning
        # ``self.num_labels`` after the base initialiser would regenerate the
        # label maps from the default (2) and silently discard a reloaded
        # mapping of a different size, so ``num_labels`` is only handed to the
        # base class when no explicit mapping was provided.
        if kwargs.get("id2label") is None:
            kwargs["num_labels"] = num_labels
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.classifier_dropout = classifier_dropout
class OctagonEmbeddings(nn.Module):
    """Word + position + token-type embeddings, followed by LayerNorm and dropout."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        max_positions = config.max_position_embeddings

        # Sub-modules are created in this fixed order so parameter names and
        # random initialisation stay the same.
        self.word_embeddings = nn.Embedding(config.vocab_size, width)
        self.position_embeddings = nn.Embedding(max_positions, width)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Buffer of shape (1, max_position_embeddings) holding 0..max-1; it is
        # sliced in ``forward`` when the caller gives no position ids.
        default_positions = torch.arange(max_positions).expand((1, -1))
        self.register_buffer("position_ids", default_positions)

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None):
        """Embed ``input_ids`` and return the normalised, dropped-out sum.

        Args:
            input_ids: Token ids; dimension 1 is used as the sequence length.
            token_type_ids: Optional segment ids; zeros shaped like
                ``input_ids`` are used when omitted.
            position_ids: Optional position ids; the leading slice of the
                ``position_ids`` buffer is used when omitted.

        Returns:
            The embedding tensor after LayerNorm and dropout.
        """
        num_tokens = input_ids.size(1)

        if position_ids is None:
            position_ids = self.position_ids[:, :num_tokens]
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        summed = (
            self.word_embeddings(input_ids)
            + self.position_embeddings(position_ids)
            + self.token_type_embeddings(token_type_ids)
        )
        return self.dropout(self.LayerNorm(summed))
class OctagonSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    No attention mask is applied: every position attends to every other
    position of the same sequence.

    Raises:
        ValueError: If ``config.hidden_size`` is not a multiple of
            ``config.num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        # A non-divisible size would silently shrink the attention width and
        # only fail later with an opaque shape error, so reject it up front.
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be a multiple of "
                f"num_attention_heads ({config.num_attention_heads})"
            )
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dimension into heads and move the head axis forward.

        Reshapes ``(batch, seq, all_head_size)`` to
        ``(batch, heads, seq, head_size)``.
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(new_x_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Apply self-attention to ``hidden_states``.

        Args:
            hidden_states: Rank-3 tensor ``(batch, seq, hidden_size)``.

        Returns:
            Tensor ``(batch, seq, all_head_size)`` with the per-head context
            vectors concatenated along the last dimension.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot product; ``math`` is imported at module level.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # Back to (batch, seq, heads, head_size), then merge the head axes.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        return context_layer.view(merged_shape)
class OctagonSelfOutput(nn.Module):
    """Post-attention projection: dense, dropout, then residual add + LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states`` and normalise its sum with ``input_tensor``.

        Args:
            hidden_states: Tensor to project (last dimension ``hidden_size``).
            input_tensor: Residual added before the LayerNorm.

        Returns:
            ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``.
        """
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class OctagonAttention(nn.Module):
    """Self-attention followed by its residual output projection."""

    def __init__(self, config):
        super().__init__()
        self.self = OctagonSelfAttention(config)
        self.output = OctagonSelfOutput(config)

    def forward(self, hidden_states):
        """Attend over ``hidden_states`` and merge the result with the input.

        The raw attention context is passed to ``self.output`` together with
        the original ``hidden_states``, which serves as the residual.
        """
        return self.output(self.self(hidden_states), hidden_states)
class OctagonIntermediate(nn.Module):
    """Feed-forward expansion: dense projection followed by an activation.

    The activation is taken from ``config.hidden_act``.  It may be one of the
    names in ``_ACTIVATIONS`` or a callable / ``nn.Module`` used as-is.  When
    the config has no ``hidden_act`` attribute, GELU is used.

    Raises:
        ValueError: If ``config.hidden_act`` is a string that is not a
            supported activation name.
    """

    # Supported ``hidden_act`` names mapped to the module class to instantiate.
    _ACTIVATIONS = {
        "gelu": nn.GELU,
        "relu": nn.ReLU,
        "silu": nn.SiLU,
        "swish": nn.SiLU,
        "tanh": nn.Tanh,
    }

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

        hidden_act = getattr(config, "hidden_act", "gelu")
        if isinstance(hidden_act, str):
            try:
                activation_cls = self._ACTIVATIONS[hidden_act]
            except KeyError:
                supported = ", ".join(sorted(self._ACTIVATIONS))
                raise ValueError(
                    f"Unsupported hidden_act {hidden_act!r}; expected one of: {supported}"
                ) from None
            self.intermediate_act_fn = activation_cls()
        else:
            # Already a callable (function or module): use it directly.
            self.intermediate_act_fn = hidden_act

    def forward(self, hidden_states):
        """Return ``activation(dense(hidden_states))``."""
        return self.intermediate_act_fn(self.dense(hidden_states))
class OctagonOutput(nn.Module):
    """Feed-forward contraction with dropout, residual add and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(
            in_features=config.intermediate_size,
            out_features=config.hidden_size,
        )
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(p=config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project back to ``hidden_size`` and normalise the residual sum.

        Args:
            hidden_states: Tensor whose last dimension is
                ``intermediate_size``.
            input_tensor: Residual (last dimension ``hidden_size``) added
                before the LayerNorm.

        Returns:
            ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``.
        """
        contracted = self.dense(hidden_states)
        contracted = self.dropout(contracted)
        residual_sum = contracted + input_tensor
        return self.LayerNorm(residual_sum)
class OctagonLayer(nn.Module):
    """One encoder layer: attention block, then the feed-forward block."""

    def __init__(self, config):
        super().__init__()
        self.attention = OctagonAttention(config)
        self.intermediate = OctagonIntermediate(config)
        self.output = OctagonOutput(config)

    def forward(self, hidden_states):
        """Run attention and the feed-forward network over ``hidden_states``.

        The attention output feeds ``self.intermediate`` and is also passed to
        ``self.output`` as the residual input.
        """
        attended = self.attention(hidden_states)
        expanded = self.intermediate(attended)
        return self.output(expanded, attended)
class OctagonEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` OctagonLayer blocks applied in order."""

    def __init__(self, config):
        super().__init__()
        stack = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            stack.append(OctagonLayer(config))
        self.layer = stack

    def forward(self, hidden_states):
        """Feed ``hidden_states`` through every layer and return the last output."""
        current = hidden_states
        for block in self.layer:
            current = block(current)
        return current
class OctagonModel(PreTrainedModel):
    """Bare Octagon encoder: embeddings, layer stack and a tanh pooler."""

    config_class = OctagonConfig
    # OctagonForSequenceClassification stores its encoder as ``self.octagon``;
    # declaring the same prefix here lets checkpoint keys be mapped between
    # the bare encoder and head models when loading.
    base_model_prefix = "octagon"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embeddings = OctagonEmbeddings(config)
        self.encoder = OctagonEncoder(config)
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
        self.tanh = nn.Tanh()
        self.post_init()

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None):
        """Encode ``input_ids``.

        Args:
            input_ids: Token ids; required.
            token_type_ids: Optional segment ids forwarded to the embeddings.
            position_ids: Optional position ids forwarded to the embeddings.

        Returns:
            A tuple ``(sequence_output, pooled_output)`` where
            ``sequence_output`` is the encoder output and ``pooled_output`` is
            ``tanh(pooler(sequence_output[:, 0]))``.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        embedding_output = self.embeddings(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )
        encoder_outputs = self.encoder(embedding_output)

        # Pool by projecting the first position of every sequence.
        pooled_output = self.tanh(self.pooler(encoder_outputs[:, 0]))
        return encoder_outputs, pooled_output
class OctagonForSequenceClassification(PreTrainedModel):
    """Octagon encoder with a dropout + linear head on the pooled output."""

    config_class = OctagonConfig
    # Name of the attribute holding the encoder, so checkpoints saved from the
    # bare encoder can be matched to ``octagon.*`` parameters when loading.
    base_model_prefix = "octagon"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.octagon = OctagonModel(config)

        classifier_dropout = config.classifier_dropout
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, labels=None):
        """Classify (or regress on) the pooled representation of each sequence.

        Args:
            input_ids: Token ids; required by the encoder.
            token_type_ids: Optional segment ids.
            position_ids: Optional position ids.
            labels: Optional targets.  Class indices when ``num_labels > 1``;
                real-valued targets when ``num_labels == 1``.

        Returns:
            A dict with ``"loss"`` (``None`` when ``labels`` is omitted) and
            ``"logits"``.
        """
        _, pooled_output = self.octagon(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )
        logits = self.classifier(self.dropout(pooled_output))

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # Cross-entropy over a single class is always zero, so a
                # one-output head is trained as a regressor instead.
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1).to(logits.dtype))
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return {"loss": loss, "logits": logits}