Georg4000 committed on
Commit e1386e9 · verified · 1 Parent(s): 1040c7d

Create modeling_octagon.py

Files changed (1)
  1. modeling_octagon.py +241 -0
modeling_octagon.py ADDED
@@ -0,0 +1,241 @@
import math

import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel


class OctagonConfig(PretrainedConfig):
    model_type = "octagon"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=8,  # Octagon has 8 sides!
        num_attention_heads=8,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        classifier_dropout=None,
        num_labels=2,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.classifier_dropout = classifier_dropout
        self.num_labels = num_labels


class OctagonEmbeddings(nn.Module):
    """Sum of word, position, and token-type embeddings, followed by LayerNorm and dropout."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None):
        seq_length = input_ids.size(1)

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        word_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = word_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class OctagonSelfAttention(nn.Module):
    """Standard multi-head scaled dot-product self-attention."""

    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # (batch, seq, hidden) -> (batch, heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer


class OctagonSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class OctagonAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self = OctagonSelfAttention(config)
        self.output = OctagonSelfOutput(config)

    def forward(self, hidden_states):
        self_outputs = self.self(hidden_states)
        attention_output = self.output(self_outputs, hidden_states)
        return attention_output


class OctagonIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = nn.GELU()

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class OctagonOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class OctagonLayer(nn.Module):
    """One transformer block: self-attention and a feed-forward sublayer, each with a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.attention = OctagonAttention(config)
        self.intermediate = OctagonIntermediate(config)
        self.output = OctagonOutput(config)

    def forward(self, hidden_states):
        attention_output = self.attention(hidden_states)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class OctagonEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer = nn.ModuleList([OctagonLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states):
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states)
        return hidden_states


class OctagonModel(PreTrainedModel):
    config_class = OctagonConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embeddings = OctagonEmbeddings(config)
        self.encoder = OctagonEncoder(config)
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
        self.tanh = nn.Tanh()

        self.post_init()

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None):
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        embedding_output = self.embeddings(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )

        encoder_outputs = self.encoder(embedding_output)
        # Pool by taking the first token's hidden state, as in BERT.
        pooled_output = self.pooler(encoder_outputs[:, 0])
        pooled_output = self.tanh(pooled_output)

        return encoder_outputs, pooled_output


class OctagonForSequenceClassification(PreTrainedModel):
    config_class = OctagonConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.octagon = OctagonModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, labels=None):
        _, pooled_output = self.octagon(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return {"loss": loss, "logits": logits}