|
import math

import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel
|
|
|
class OctagonConfig(PretrainedConfig):
    """Configuration holding the hyper-parameters of the Octagon model.

    ``pad_token_id`` and any extra keyword arguments are forwarded to
    ``PretrainedConfig``; every other argument is stored as an attribute of
    the same name.

    Args:
        vocab_size: Size of the word-embedding table.
        hidden_size: Width of the hidden states.
        num_hidden_layers: Number of encoder layers.
        num_attention_heads: Number of attention heads per layer.
        intermediate_size: Width of the feed-forward projection.
        hidden_act: Name of the feed-forward activation (stored as given).
        hidden_dropout_prob: Dropout probability for hidden states.
        attention_probs_dropout_prob: Dropout probability for attention
            weights.
        max_position_embeddings: Size of the position-embedding table.
        type_vocab_size: Size of the token-type embedding table.
        initializer_range: Stored as given.
        layer_norm_eps: Epsilon passed to the ``LayerNorm`` modules.
        pad_token_id: Padding token id, forwarded to ``PretrainedConfig``.
        position_embedding_type: Stored as given.
        classifier_dropout: Dropout for the classification head; ``None``
            makes the head fall back to ``hidden_dropout_prob``.
        num_labels: Number of classification labels. ``None`` (the default)
            leaves the label setup to ``PretrainedConfig``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "octagon"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=8,
        num_attention_heads=8,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        classifier_dropout=None,
        num_labels=None,
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.classifier_dropout = classifier_dropout

        # NOTE(review): per the Transformers implementation, ``num_labels`` is
        # a property whose setter rebuilds ``id2label``/``label2id`` whenever
        # their size differs from the assigned value. Unconditionally
        # assigning a default of 2 therefore wiped an ``id2label`` mapping of
        # any other size passed through ``kwargs`` (e.g. when reloading a
        # saved config). Only assign when the caller asked for a value; the
        # base class already defaults to two labels otherwise.
        if num_labels is not None:
            self.num_labels = num_labels
|
|
|
class OctagonEmbeddings(nn.Module):
    """Sum word, position and token-type embeddings, then LayerNorm and dropout.

    Args:
        config: Object providing ``vocab_size``, ``hidden_size``,
            ``max_position_embeddings``, ``type_vocab_size``,
            ``layer_norm_eps`` and ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Row vector 0..max_position_embeddings-1, sliced in ``forward`` when
        # the caller does not supply ``position_ids``.
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None):
        """Embed ``input_ids``.

        Args:
            input_ids: Integer tensor of token ids; dimension 1 is taken as
                the sequence length. Required.
            token_type_ids: Optional tensor shaped like ``input_ids``;
                defaults to all zeros.
            position_ids: Optional position indices; defaults to
                ``0..seq_length-1``.

        Returns:
            The summed embeddings after LayerNorm and dropout.

        Raises:
            ValueError: If ``input_ids`` is ``None``, or if default positions
                are requested for a sequence longer than the position table.
        """
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        seq_length = input_ids.size(1)

        if position_ids is None:
            max_positions = self.position_ids.size(1)
            # Slicing past the buffer would silently truncate and only fail
            # later with a shape-mismatch error; report the real cause here.
            if seq_length > max_positions:
                raise ValueError(
                    f"Sequence length {seq_length} exceeds max_position_embeddings ({max_positions})"
                )
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        word_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = word_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
|
|
|
class OctagonSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no attention mask).

    Args:
        config: Object providing ``hidden_size``, ``num_attention_heads`` and
            ``attention_probs_dropout_prob``.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of
            ``num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        # A non-divisible hidden size would be floored per head, making the
        # output narrower than ``hidden_size``; reject it up front.
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of "
                f"attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dimension into heads and move the head axis to position 1.

        A rank-3 input ``(d0, d1, all_head_size)`` becomes
        ``(d0, num_heads, d1, head_size)``.
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Attend every position of ``hidden_states`` to every other position.

        Args:
            hidden_states: Rank-3 tensor whose last dimension is
                ``hidden_size`` (presumably ``(batch, seq_len, hidden_size)``).

        Returns:
            Tensor with the same leading dimensions and last dimension
            ``all_head_size``.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Scale by sqrt(head_size) before the softmax.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # Move heads back next to the per-head features and merge them.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer
|
|
|
class OctagonSelfOutput(nn.Module):
    """Project the attention output, apply dropout, add the residual, normalise.

    Args:
        config: Object providing ``hidden_size``, ``layer_norm_eps`` and
            ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
|
|
class OctagonAttention(nn.Module):
    """Self-attention followed by its residual output block."""

    def __init__(self, config):
        super().__init__()
        self.self = OctagonSelfAttention(config)
        self.output = OctagonSelfOutput(config)

    def forward(self, hidden_states):
        """Run self-attention, then combine the result with ``hidden_states`` as the residual."""
        return self.output(self.self(hidden_states), hidden_states)
|
|
|
class OctagonIntermediate(nn.Module):
    """Feed-forward expansion: linear projection followed by an activation.

    The activation is taken from ``config.hidden_act``, which may be one of
    the names in ``_ACTIVATIONS`` or a callable. A config without
    ``hidden_act`` gets GELU.

    Args:
        config: Object providing ``hidden_size`` and ``intermediate_size``,
            and optionally ``hidden_act``.

    Raises:
        ValueError: If ``hidden_act`` is a string that is not supported.
    """

    # Activation name -> module class.
    _ACTIVATIONS = {
        "gelu": nn.GELU,
        "relu": nn.ReLU,
        "silu": nn.SiLU,
        "swish": nn.SiLU,
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
    }

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

        hidden_act = getattr(config, "hidden_act", "gelu")
        if isinstance(hidden_act, str):
            try:
                self.intermediate_act_fn = self._ACTIVATIONS[hidden_act]()
            except KeyError:
                # Fail loudly rather than silently substituting a different
                # non-linearity from the one the config asks for.
                raise ValueError(
                    f"Unsupported hidden_act {hidden_act!r}; expected one of {sorted(self._ACTIVATIONS)}"
                ) from None
        else:
            # Already a callable (e.g. an ``nn.Module`` or a function).
            self.intermediate_act_fn = hidden_act

    def forward(self, hidden_states):
        """Return ``activation(dense(hidden_states))``."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
|
|
|
class OctagonOutput(nn.Module):
    """Project the feed-forward activations back to ``hidden_size`` with a residual.

    Args:
        config: Object providing ``intermediate_size``, ``hidden_size``,
            ``layer_norm_eps`` and ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(config.intermediate_size, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
|
|
class OctagonLayer(nn.Module):
    """One encoder layer: attention block followed by the feed-forward block."""

    def __init__(self, config):
        super().__init__()
        self.attention = OctagonAttention(config)
        self.intermediate = OctagonIntermediate(config)
        self.output = OctagonOutput(config)

    def forward(self, hidden_states):
        """Apply attention, then the feed-forward block with the attention output as residual."""
        attended = self.attention(hidden_states)
        return self.output(self.intermediate(attended), attended)
|
|
|
class OctagonEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``OctagonLayer`` modules applied in order."""

    def __init__(self, config):
        super().__init__()
        self.layer = nn.ModuleList(OctagonLayer(config) for _ in range(config.num_hidden_layers))

    def forward(self, hidden_states):
        """Feed ``hidden_states`` through every layer and return the final output."""
        states = hidden_states
        for block in self.layer:
            states = block(states)
        return states
|
|
|
class OctagonModel(PreTrainedModel):
    """Bare Octagon encoder: embeddings, encoder stack and a tanh pooler.

    Args:
        config: An ``OctagonConfig`` (or compatible) instance.
    """

    config_class = OctagonConfig
    # Attribute name under which task heads in this file hold the bare model
    # (``self.octagon``); Transformers uses it to map checkpoint keys between
    # the bare model and its heads.
    base_model_prefix = "octagon"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embeddings = OctagonEmbeddings(config)
        self.encoder = OctagonEncoder(config)
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
        self.tanh = nn.Tanh()

        self.post_init()

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None):
        """Encode ``input_ids``.

        Args:
            input_ids: Token ids. Required.
            token_type_ids: Optional token-type ids, passed to the embeddings.
            position_ids: Optional position ids, passed to the embeddings.

        Returns:
            A tuple ``(encoder_outputs, pooled_output)``. ``pooled_output`` is
            ``tanh(pooler(encoder_outputs[:, 0]))``, i.e. computed from index
            0 along dimension 1 of the encoder output.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        embedding_output = self.embeddings(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )

        encoder_outputs = self.encoder(embedding_output)
        pooled_output = self.pooler(encoder_outputs[:, 0])
        pooled_output = self.tanh(pooled_output)

        return encoder_outputs, pooled_output
|
|
|
class OctagonForSequenceClassification(PreTrainedModel):
    """Octagon encoder with a dropout + linear head on the pooled output.

    Args:
        config: An ``OctagonConfig`` (or compatible) instance. The head size
            is ``config.num_labels``; the head dropout is
            ``config.classifier_dropout``, falling back to
            ``config.hidden_dropout_prob`` when that is ``None``.
    """

    config_class = OctagonConfig
    # Name of the attribute holding the bare model; lets Transformers map
    # bare-model checkpoint keys onto ``self.octagon``.
    base_model_prefix = "octagon"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.octagon = OctagonModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, labels=None):
        """Classify (or regress on) the pooled representation of ``input_ids``.

        Args:
            input_ids: Token ids, forwarded to the encoder.
            token_type_ids: Optional token-type ids, forwarded to the encoder.
            position_ids: Optional position ids, forwarded to the encoder.
            labels: Optional targets. When given, a loss is computed:
                mean-squared error if ``num_labels == 1``, cross-entropy
                otherwise.

        Returns:
            A dict with keys ``"loss"`` (``None`` when ``labels`` is not
            given) and ``"logits"``.
        """
        _, pooled_output = self.octagon(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # Cross-entropy over a single class is identically zero, so a
                # one-output head is treated as regression.
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1).to(logits.dtype))
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return {"loss": loss, "logits": logits}