"""Octagon: a compact BERT-style Transformer encoder with a classification head.

The module defines the configuration class, the encoder building blocks
(embeddings, self-attention, feed-forward), the bare ``OctagonModel`` and
``OctagonForSequenceClassification``.
"""

import math
from typing import Optional

import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN


class OctagonConfig(PretrainedConfig):
    """Hyper-parameters for the Octagon encoder.

    All arguments are stored as same-named attributes. ``pad_token_id`` and any
    extra ``kwargs`` are forwarded to ``PretrainedConfig``.

    Only absolute position embeddings are implemented by the model code in this
    file; ``position_embedding_type`` is stored but not otherwise consulted.
    """

    model_type = "octagon"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=8,  # Octagon has 8 sides!
        num_attention_heads=8,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        classifier_dropout=None,
        num_labels=2,
        **kwargs
    ):
        # Remember whether the caller supplied an explicit label mapping before
        # the base class consumes it from kwargs.
        has_explicit_labels = kwargs.get("id2label") is not None
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.classifier_dropout = classifier_dropout
        # NOTE(review): assigning ``num_labels`` goes through the
        # ``PretrainedConfig`` property setter, which (in the releases checked)
        # regenerates ``id2label`` when its length differs. Doing that
        # unconditionally with the default of 2 would discard an ``id2label``
        # mapping passed in kwargs (e.g. when reloading a saved 3-class
        # config), so an explicit mapping takes precedence here.
        if not has_explicit_labels:
            self.num_labels = num_labels


def _to_additive_attention_mask(attention_mask, dtype):
    """Turn a ``(batch, seq)`` keep-mask into an additive attention bias.

    Args:
        attention_mask: Tensor with 1 for positions to attend to and 0 for
            positions to ignore (e.g. padding).
        dtype: Floating dtype of the attention scores the bias is added to.

    Returns:
        Tensor of shape ``(batch, 1, 1, seq)`` holding 0 for kept positions and
        the most negative finite value of ``dtype`` for ignored ones, so that
        softmax assigns them (near-)zero probability.
    """
    expanded = attention_mask[:, None, None, :].to(dtype=dtype)
    return (1.0 - expanded) * torch.finfo(dtype).min


class OctagonEmbeddings(nn.Module):
    """Sum of word, position and token-type embeddings, then LayerNorm + dropout."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Pre-computed [[0, 1, ..., max_position_embeddings - 1]]; sliced per call.
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
        )

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None):
        """Embed ``input_ids``.

        Args:
            input_ids: Token ids; dimension 1 is taken as the sequence length.
            token_type_ids: Optional segment ids; defaults to all zeros.
            position_ids: Optional position ids; defaults to ``0..seq_len-1``.

        Returns:
            The normalised, dropout-regularised embedding tensor.
        """
        seq_length = input_ids.size(1)
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        embeddings = (
            self.word_embeddings(input_ids)
            + self.position_embeddings(position_ids)
            + self.token_type_embeddings(token_type_ids)
        )
        embeddings = self.LayerNorm(embeddings)
        return self.dropout(embeddings)


class OctagonSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention."""

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            # Fail early: otherwise the head size is silently truncated and the
            # reshape in ``forward`` breaks with an opaque shape error.
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number "
                f"of attention heads ({config.num_attention_heads})"
            )
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Reshape ``(batch, seq, all_head)`` to ``(batch, heads, seq, head_size)``."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None):
        """Apply self-attention.

        Args:
            hidden_states: Tensor of shape ``(batch, seq, hidden_size)``.
            attention_mask: Optional additive bias broadcastable to
                ``(batch, heads, seq, seq)``; added to the raw scores before
                softmax.

        Returns:
            Context tensor of shape ``(batch, seq, all_head_size)``.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # Back to (batch, seq, heads, head_size) and merge the head dimensions.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        return context_layer.view(*new_context_layer_shape)


class OctagonSelfOutput(nn.Module):
    """Projection, dropout and residual LayerNorm applied after self-attention."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return self.LayerNorm(hidden_states + input_tensor)


class OctagonAttention(nn.Module):
    """Self-attention followed by its residual output block."""

    def __init__(self, config):
        super().__init__()
        self.self = OctagonSelfAttention(config)
        self.output = OctagonSelfOutput(config)

    def forward(self, hidden_states, attention_mask=None):
        """Run attention over ``hidden_states`` with an optional additive mask."""
        self_outputs = self.self(hidden_states, attention_mask)
        return self.output(self_outputs, hidden_states)


class OctagonIntermediate(nn.Module):
    """Feed-forward expansion from ``hidden_size`` to ``intermediate_size``."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # Honour the configured activation rather than hard-coding GELU.
        # A string is looked up in ACT2FN; anything else is used as given.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        """Return ``activation(dense(hidden_states))``."""
        hidden_states = self.dense(hidden_states)
        return self.intermediate_act_fn(hidden_states)


class OctagonOutput(nn.Module):
    """Feed-forward contraction with dropout and residual LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return self.LayerNorm(hidden_states + input_tensor)


class OctagonLayer(nn.Module):
    """One encoder layer: attention block followed by feed-forward block."""

    def __init__(self, config):
        super().__init__()
        self.attention = OctagonAttention(config)
        self.intermediate = OctagonIntermediate(config)
        self.output = OctagonOutput(config)

    def forward(self, hidden_states, attention_mask=None):
        """Apply the layer to ``hidden_states`` with an optional additive mask."""
        attention_output = self.attention(hidden_states, attention_mask)
        intermediate_output = self.intermediate(attention_output)
        return self.output(intermediate_output, attention_output)


class OctagonEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``OctagonLayer`` modules."""

    def __init__(self, config):
        super().__init__()
        self.layer = nn.ModuleList(
            [OctagonLayer(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(self, hidden_states, attention_mask=None):
        """Feed ``hidden_states`` through every layer in order."""
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask)
        return hidden_states


class OctagonPreTrainedModel(PreTrainedModel):
    """Shared base: binds the config class and defines weight initialisation."""

    config_class = OctagonConfig

    def _init_weights(self, module):
        """Initialise ``module`` using ``config.initializer_range``.

        Linear and embedding weights are drawn from N(0, initializer_range),
        biases are zeroed, and LayerNorm is reset to identity.

        NOTE(review): ``post_init()`` is expected to apply this hook to every
        sub-module, and some transformers releases raise NotImplementedError
        when it is missing — confirm against the pinned version.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class OctagonModel(OctagonPreTrainedModel):
    """Bare Octagon encoder with a tanh pooler over the first token."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embeddings = OctagonEmbeddings(config)
        self.encoder = OctagonEncoder(config)
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
        self.tanh = nn.Tanh()
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Encode a batch of token ids.

        Args:
            input_ids: Token ids of shape ``(batch, seq)``. Required.
            token_type_ids: Optional segment ids.
            position_ids: Optional position ids.
            attention_mask: Optional ``(batch, seq)`` mask with 1 for tokens to
                attend to and 0 for tokens to ignore. When omitted, every
                position is attended to.

        Returns:
            Tuple ``(sequence_output, pooled_output)`` where ``pooled_output``
            is ``tanh(pooler(sequence_output[:, 0]))``.

        Raises:
            ValueError: If ``input_ids`` is None.
        """
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        embedding_output = self.embeddings(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )

        additive_mask = None
        if attention_mask is not None:
            additive_mask = _to_additive_attention_mask(
                attention_mask, embedding_output.dtype
            )

        encoder_outputs = self.encoder(embedding_output, additive_mask)
        pooled_output = self.tanh(self.pooler(encoder_outputs[:, 0]))
        return encoder_outputs, pooled_output


class OctagonForSequenceClassification(OctagonPreTrainedModel):
    """Octagon encoder with a dropout + linear classification head."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.octagon = OctagonModel(config)
        # Fall back to the general hidden dropout when no head-specific rate is set.
        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Classify each sequence in the batch.

        Args:
            input_ids: Token ids of shape ``(batch, seq)``.
            token_type_ids: Optional segment ids.
            position_ids: Optional position ids.
            labels: Optional class indices; when given, a cross-entropy loss is
                computed against the logits.
            attention_mask: Optional ``(batch, seq)`` mask with 1 for tokens to
                attend to and 0 for tokens to ignore.

        Returns:
            Dict with keys ``"loss"`` (None when ``labels`` is None) and
            ``"logits"``.
        """
        _, pooled_output = self.octagon(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
        )
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return {"loss": loss, "logits": logits}