Create modeling_octagon.py
Browse files- modeling_octagon.py +241 -0
modeling_octagon.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig, PreTrainedModel
|
2 |
+
from torch import nn
|
3 |
+
import torch
|
4 |
+
|
5 |
+
class OctagonConfig(PretrainedConfig):
|
6 |
+
model_type = "octagon"
|
7 |
+
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
vocab_size=30522,
|
11 |
+
hidden_size=768,
|
12 |
+
num_hidden_layers=8, # Octagon has 8 sides!
|
13 |
+
num_attention_heads=8,
|
14 |
+
intermediate_size=3072,
|
15 |
+
hidden_act="gelu",
|
16 |
+
hidden_dropout_prob=0.1,
|
17 |
+
attention_probs_dropout_prob=0.1,
|
18 |
+
max_position_embeddings=512,
|
19 |
+
type_vocab_size=2,
|
20 |
+
initializer_range=0.02,
|
21 |
+
layer_norm_eps=1e-12,
|
22 |
+
pad_token_id=0,
|
23 |
+
position_embedding_type="absolute",
|
24 |
+
classifier_dropout=None,
|
25 |
+
num_labels=2,
|
26 |
+
**kwargs
|
27 |
+
):
|
28 |
+
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
29 |
+
self.vocab_size = vocab_size
|
30 |
+
self.hidden_size = hidden_size
|
31 |
+
self.num_hidden_layers = num_hidden_layers
|
32 |
+
self.num_attention_heads = num_attention_heads
|
33 |
+
self.intermediate_size = intermediate_size
|
34 |
+
self.hidden_act = hidden_act
|
35 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
36 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
37 |
+
self.max_position_embeddings = max_position_embeddings
|
38 |
+
self.type_vocab_size = type_vocab_size
|
39 |
+
self.initializer_range = initializer_range
|
40 |
+
self.layer_norm_eps = layer_norm_eps
|
41 |
+
self.position_embedding_type = position_embedding_type
|
42 |
+
self.classifier_dropout = classifier_dropout
|
43 |
+
self.num_labels = num_labels
|
44 |
+
|
45 |
+
class OctagonEmbeddings(nn.Module):
|
46 |
+
def __init__(self, config):
|
47 |
+
super().__init__()
|
48 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
|
49 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
50 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
51 |
+
|
52 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
53 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
54 |
+
|
55 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
56 |
+
|
57 |
+
def forward(self, input_ids=None, token_type_ids=None, position_ids=None):
|
58 |
+
seq_length = input_ids.size(1)
|
59 |
+
|
60 |
+
if position_ids is None:
|
61 |
+
position_ids = self.position_ids[:, :seq_length]
|
62 |
+
|
63 |
+
if token_type_ids is None:
|
64 |
+
token_type_ids = torch.zeros_like(input_ids)
|
65 |
+
|
66 |
+
word_embeddings = self.word_embeddings(input_ids)
|
67 |
+
position_embeddings = self.position_embeddings(position_ids)
|
68 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
69 |
+
|
70 |
+
embeddings = word_embeddings + position_embeddings + token_type_embeddings
|
71 |
+
embeddings = self.LayerNorm(embeddings)
|
72 |
+
embeddings = self.dropout(embeddings)
|
73 |
+
return embeddings
|
74 |
+
|
75 |
+
class OctagonSelfAttention(nn.Module):
|
76 |
+
def __init__(self, config):
|
77 |
+
super().__init__()
|
78 |
+
self.num_attention_heads = config.num_attention_heads
|
79 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
80 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
81 |
+
|
82 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
83 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
84 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
85 |
+
|
86 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
87 |
+
|
88 |
+
def transpose_for_scores(self, x):
|
89 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
90 |
+
x = x.view(*new_x_shape)
|
91 |
+
return x.permute(0, 2, 1, 3)
|
92 |
+
|
93 |
+
def forward(self, hidden_states):
|
94 |
+
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
95 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
96 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
97 |
+
|
98 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
99 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
100 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
101 |
+
attention_probs = self.dropout(attention_probs)
|
102 |
+
|
103 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
104 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
105 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
106 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
107 |
+
return context_layer
|
108 |
+
|
109 |
+
class OctagonSelfOutput(nn.Module):
|
110 |
+
def __init__(self, config):
|
111 |
+
super().__init__()
|
112 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
113 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
114 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
115 |
+
|
116 |
+
def forward(self, hidden_states, input_tensor):
|
117 |
+
hidden_states = self.dense(hidden_states)
|
118 |
+
hidden_states = self.dropout(hidden_states)
|
119 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
120 |
+
return hidden_states
|
121 |
+
|
122 |
+
class OctagonAttention(nn.Module):
|
123 |
+
def __init__(self, config):
|
124 |
+
super().__init__()
|
125 |
+
self.self = OctagonSelfAttention(config)
|
126 |
+
self.output = OctagonSelfOutput(config)
|
127 |
+
|
128 |
+
def forward(self, hidden_states):
|
129 |
+
self_outputs = self.self(hidden_states)
|
130 |
+
attention_output = self.output(self_outputs, hidden_states)
|
131 |
+
return attention_output
|
132 |
+
|
133 |
+
class OctagonIntermediate(nn.Module):
|
134 |
+
def __init__(self, config):
|
135 |
+
super().__init__()
|
136 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
137 |
+
self.intermediate_act_fn = nn.GELU()
|
138 |
+
|
139 |
+
def forward(self, hidden_states):
|
140 |
+
hidden_states = self.dense(hidden_states)
|
141 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
142 |
+
return hidden_states
|
143 |
+
|
144 |
+
class OctagonOutput(nn.Module):
|
145 |
+
def __init__(self, config):
|
146 |
+
super().__init__()
|
147 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
148 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
149 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
150 |
+
|
151 |
+
def forward(self, hidden_states, input_tensor):
|
152 |
+
hidden_states = self.dense(hidden_states)
|
153 |
+
hidden_states = self.dropout(hidden_states)
|
154 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
155 |
+
return hidden_states
|
156 |
+
|
157 |
+
class OctagonLayer(nn.Module):
|
158 |
+
def __init__(self, config):
|
159 |
+
super().__init__()
|
160 |
+
self.attention = OctagonAttention(config)
|
161 |
+
self.intermediate = OctagonIntermediate(config)
|
162 |
+
self.output = OctagonOutput(config)
|
163 |
+
|
164 |
+
def forward(self, hidden_states):
|
165 |
+
attention_output = self.attention(hidden_states)
|
166 |
+
intermediate_output = self.intermediate(attention_output)
|
167 |
+
layer_output = self.output(intermediate_output, attention_output)
|
168 |
+
return layer_output
|
169 |
+
|
170 |
+
class OctagonEncoder(nn.Module):
|
171 |
+
def __init__(self, config):
|
172 |
+
super().__init__()
|
173 |
+
self.layer = nn.ModuleList([OctagonLayer(config) for _ in range(config.num_hidden_layers)])
|
174 |
+
|
175 |
+
def forward(self, hidden_states):
|
176 |
+
for layer_module in self.layer:
|
177 |
+
hidden_states = layer_module(hidden_states)
|
178 |
+
return hidden_states
|
179 |
+
|
180 |
+
class OctagonModel(PreTrainedModel):
|
181 |
+
config_class = OctagonConfig
|
182 |
+
|
183 |
+
def __init__(self, config):
|
184 |
+
super().__init__(config)
|
185 |
+
self.config = config
|
186 |
+
self.embeddings = OctagonEmbeddings(config)
|
187 |
+
self.encoder = OctagonEncoder(config)
|
188 |
+
self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
|
189 |
+
self.tanh = nn.Tanh()
|
190 |
+
|
191 |
+
self.post_init()
|
192 |
+
|
193 |
+
def forward(self, input_ids=None, token_type_ids=None, position_ids=None):
|
194 |
+
if input_ids is not None:
|
195 |
+
input_shape = input_ids.size()
|
196 |
+
else:
|
197 |
+
raise ValueError("You have to specify input_ids")
|
198 |
+
|
199 |
+
embedding_output = self.embeddings(
|
200 |
+
input_ids=input_ids,
|
201 |
+
token_type_ids=token_type_ids,
|
202 |
+
position_ids=position_ids
|
203 |
+
)
|
204 |
+
|
205 |
+
encoder_outputs = self.encoder(embedding_output)
|
206 |
+
pooled_output = self.pooler(encoder_outputs[:, 0])
|
207 |
+
pooled_output = self.tanh(pooled_output)
|
208 |
+
|
209 |
+
return encoder_outputs, pooled_output
|
210 |
+
|
211 |
+
class OctagonForSequenceClassification(PreTrainedModel):
|
212 |
+
config_class = OctagonConfig
|
213 |
+
|
214 |
+
def __init__(self, config):
|
215 |
+
super().__init__(config)
|
216 |
+
self.num_labels = config.num_labels
|
217 |
+
self.octagon = OctagonModel(config)
|
218 |
+
classifier_dropout = (
|
219 |
+
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
220 |
+
)
|
221 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
222 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
223 |
+
|
224 |
+
self.post_init()
|
225 |
+
|
226 |
+
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, labels=None):
|
227 |
+
_, pooled_output = self.octagon(
|
228 |
+
input_ids=input_ids,
|
229 |
+
token_type_ids=token_type_ids,
|
230 |
+
position_ids=position_ids
|
231 |
+
)
|
232 |
+
|
233 |
+
pooled_output = self.dropout(pooled_output)
|
234 |
+
logits = self.classifier(pooled_output)
|
235 |
+
|
236 |
+
loss = None
|
237 |
+
if labels is not None:
|
238 |
+
loss_fct = nn.CrossEntropyLoss()
|
239 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
240 |
+
|
241 |
+
return {"loss": loss, "logits": logits}
|