lgcharpe committed · Commit c28c53e (verified) · 1 parent: e8a5f8b

Uploading the modeling file

__init__.py ADDED
File without changes
config.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "architectures": [
+     "GPTBERTForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_gpt_bert.ModelConfig",
+     "AutoModel": "modeling_gpt_bert.GPTBERT",
+     "AutoModelForCausalLM": "modeling_gpt_bert.GPTBERTForCausalLM",
+     "AutoModelForMaskedLM": "modeling_gpt_bert.GPTBERTForMaskedLM"
+   },
+   "attention_probs_dropout_prob": 0.1,
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "intermediate_size": 2560,
+   "max_position_embeddings": 512,
+   "position_bucket_size": 32,
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "vocab_size": 8192,
+   "layer_norm_eps": 1.0e-5
+ }
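
The `auto_map` block is what routes the generic `Auto*` classes to the custom code in this repository. As a rough usage sketch (the repository id is a placeholder, weights are not part of this commit, and `trust_remote_code=True` is required because the classes live in the repo rather than in `transformers`):

from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForMaskedLM

repo_id = "user/gpt-bert"  # placeholder; substitute the actual Hub repository

# AutoConfig resolves to configuration_gpt_bert.ModelConfig via the auto_map above.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)

# The two task heads resolve to the classes defined in modeling_gpt_bert.py.
causal_model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
masked_model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)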
configuration_gpt_bert.py ADDED
@@ -0,0 +1,54 @@
+ from __future__ import annotations
+
+ import json
+ import pathlib
+ import copy
+
+ from typing import Any
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class ModelConfig(PretrainedConfig):
+
+     def __init__(self: ModelConfig, config_file: pathlib.Path | str | None = None, **kwargs):
+         """Configuration for GPT-BERT. The defaults correspond to the base-sized model;
+         values read from `config_file` or passed as keyword arguments override them."""
+         # Defaults (base model).
+         self.attention_probs_dropout_prob: float = 0.1
+         self.hidden_dropout_prob = 0.1
+         self.hidden_size = 768
+         self.intermediate_size = 2560
+         self.max_position_embeddings = 512
+         self.position_bucket_size = 32
+         self.num_attention_heads = 12
+         self.num_hidden_layers = 12
+         self.vocab_size = 8192
+         self.layer_norm_eps = 1e-5
+
+         if config_file is not None:
+             if isinstance(config_file, str):
+                 config_file = pathlib.Path(config_file)
+
+             config: dict[str, Any] = json.load(config_file.open("r"))
+
+             for key, value in config.items():
+                 setattr(self, key, value)
+
+         # Keyword arguments (e.g. the values loaded by `from_pretrained`) override the defaults.
+         super().__init__(**kwargs)
+
+     def __repr__(self) -> str:
+         return str(self.to_json_string())
+
+     def to_dict(self) -> dict[str, Any]:
+         """Serializes this instance to a Python dictionary."""
+         output: dict[str, Any] = copy.deepcopy(self.__dict__)
+         return output
+
+     def to_json_string(self) -> str:
+         """Serializes this instance to a JSON string."""
+         return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
+
+     def to_json_file(self, json_file_path: pathlib.Path | str) -> None:
+         """Saves this instance to a JSON file."""
+         if isinstance(json_file_path, str):
+             json_file_path = pathlib.Path(json_file_path)
+         with json_file_path.open("w", encoding="utf-8") as writer:
+             writer.write(self.to_json_string())
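
For reference, a small sketch of using the configuration class on its own, assuming the file is importable from the working directory; the output file name is arbitrary:

import pathlib

from configuration_gpt_bert import ModelConfig

config = ModelConfig()  # base-model defaults
print(config.hidden_size, config.num_hidden_layers)  # 768 12

# Round-trip through the JSON helpers defined above.
config.to_json_file("example_config.json")
reloaded = ModelConfig(config_file=pathlib.Path("example_config.json"))
assert reloaded.hidden_size == config.hidden_size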
modeling_gpt_bert.py ADDED
@@ -0,0 +1,550 @@
+ from __future__ import annotations
+
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import _softmax_backward_data
+ from .configuration_gpt_bert import ModelConfig
+
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.modeling_outputs import (
+     BaseModelOutput,
+     CausalLMOutput,
+     MaskedLMOutput
+ )
+
+ from typing import Optional, Union
+
+
+ class Layer(nn.Module):
+
+     def __init__(self: Layer, config: ModelConfig, layer_idx: int = 0):
+         super().__init__()
+         self.attention = Attention(config)
+         self.mlp = FeedForward(config)
+
+         # Depth-dependent scaling of the feed-forward projections (deeper layers start smaller).
+         self.mlp.mlp[1].weight.data *= math.sqrt(1.0 / (2.0 * (1 + layer_idx)))
+         self.mlp.mlp[-2].weight.data *= math.sqrt(1.0 / (2.0 * (1 + layer_idx)))
+
+     def forward(self: Layer, x: torch.Tensor, attention_mask: torch.Tensor, relative_embedding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         attention: torch.Tensor
+         attention_probs: torch.Tensor
+         attention, attention_probs = self.attention(x, attention_mask, relative_embedding)
+         # Out-of-place residual additions: in-place `+=` would modify a tensor that the
+         # pre-attention layer norm saved for its backward pass.
+         x = x + attention
+         x = x + self.mlp(x)
+
+         return x, attention_probs
+
+
+ class MaskClassifier(nn.Module):
+
+     def __init__(self: MaskClassifier, config: ModelConfig, subword_embedding: nn.Parameter):
+         super().__init__()
+         self.nonlinearity = nn.Sequential(
+             nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False),
+             nn.Linear(config.hidden_size, config.hidden_size),
+             nn.GELU(),
+             nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False),
+             nn.Dropout(config.hidden_dropout_prob),
+             nn.Linear(subword_embedding.size(1), subword_embedding.size(0))
+         )
+         self.initialize(config.hidden_size, subword_embedding)
+
+     def initialize(self: MaskClassifier, hidden_size: int, embedding: nn.Parameter):
+         std: float = math.sqrt(2.0 / (5.0 * hidden_size))
+         nn.init.trunc_normal_(self.nonlinearity[1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+         # Tie the output projection to the input word embedding.
+         self.nonlinearity[-1].weight = embedding
+         self.nonlinearity[1].bias.data.zero_()
+         self.nonlinearity[-1].bias.data.zero_()
+
+     def forward(self: MaskClassifier, x: torch.Tensor, masked_lm_labels: torch.Tensor | None = None) -> torch.Tensor:
+         if masked_lm_labels is not None:
+             # Only score the positions that carry a label (label != -100).
+             x = torch.index_select(x.flatten(0, 1), 0, torch.nonzero(masked_lm_labels.flatten() != -100).squeeze())
+         x = self.nonlinearity(x)
+
+         return x
+
+
+ class GeGLU(nn.Module):
+     def forward(self: GeGLU, x: torch.Tensor) -> torch.Tensor:
+         gate: torch.Tensor
+         x, gate = x.chunk(2, dim=-1)
+         x = x * F.gelu(gate, approximate='tanh')
+         return x
+
+
+ class FeedForward(nn.Module):
+     def __init__(self: FeedForward, config: ModelConfig) -> None:
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False),
+             nn.Linear(config.hidden_size, 2*config.intermediate_size, bias=False),
+             GeGLU(),
+             nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps, elementwise_affine=False),
+             nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
+             nn.Dropout(config.hidden_dropout_prob)
+         )
+         self.initialize(config.hidden_size)
+
+     def initialize(self: FeedForward, hidden_size: int) -> None:
+         std: float = math.sqrt(2.0 / (5.0 * hidden_size))
+         nn.init.trunc_normal_(self.mlp[1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+         nn.init.trunc_normal_(self.mlp[-2].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+
+     def forward(self: FeedForward, x: torch.Tensor) -> torch.Tensor:
+         return self.mlp(x)
+
+
+ class MaskedSoftmax(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, x: torch.Tensor, mask: torch.Tensor, dim: int) -> torch.Tensor:
+         ctx.dim = dim
+         x.masked_fill_(mask, float('-inf'))
+         x = torch.softmax(x, ctx.dim)
+         x.masked_fill_(mask, 0.0)
+         ctx.save_for_backward(x)
+         return x
+
+     @staticmethod
+     def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]:
+         output: torch.Tensor
+         output, = ctx.saved_tensors
+         input_grad: torch.Tensor = _softmax_backward_data(grad_output, output, ctx.dim, output.dtype)
+         return input_grad, None, None
+
+
+ class Attention(nn.Module):
+     def __init__(self: Attention, config: ModelConfig) -> None:
+         super().__init__()
+
+         self.config: ModelConfig = config
+
+         if config.hidden_size % config.num_attention_heads != 0:
+             raise ValueError(f"The hidden size {config.hidden_size} is not a multiple of the number of attention heads {config.num_attention_heads}")
+
+         self.hidden_size: int = config.hidden_size
+         self.num_heads: int = config.num_attention_heads
+         self.head_size: int = config.hidden_size // config.num_attention_heads
+
+         self.in_proj_qk = nn.Linear(config.hidden_size, 2*config.hidden_size, bias=True)
+         self.in_proj_vg = nn.Linear(config.hidden_size, 2*config.hidden_size, bias=True)
+         self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
+
+         self.pre_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False)
+         self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False)
+
+         # Log-bucketed relative positions, precomputed up to the maximum position and stored as a buffer.
+         position_indices: torch.Tensor = torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(1) \
+             - torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0)
+         position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, config.max_position_embeddings)
+         position_indices = config.position_bucket_size - 1 + position_indices
+         self.register_buffer("position_indices", position_indices, persistent=True)
+
+         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+         self.scale: float = 1.0 / math.sqrt(3 * self.head_size)
+         self.initialize()
+
+     def make_log_bucket_position(self: Attention, relative_pos: torch.Tensor, bucket_size: int, max_position: int) -> torch.Tensor:
+         sign: torch.Tensor = torch.sign(relative_pos)
+         mid: int = bucket_size // 2
+         abs_pos: torch.Tensor = torch.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, torch.abs(relative_pos).clamp(max=max_position - 1))
+         log_pos: torch.Tensor = torch.ceil(torch.log(abs_pos / mid) / math.log((max_position-1) / mid) * (mid - 1)).int() + mid
+         bucket_pos: torch.Tensor = torch.where(abs_pos <= mid, relative_pos, log_pos * sign).long()
+         return bucket_pos
+
+     def initialize(self: Attention) -> None:
+         std: float = math.sqrt(2.0 / (5.0 * self.hidden_size))
+         nn.init.trunc_normal_(self.in_proj_qk.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+         nn.init.trunc_normal_(self.in_proj_vg.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+         nn.init.trunc_normal_(self.out_proj.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+         self.in_proj_qk.bias.data.zero_()
+         self.in_proj_vg.bias.data.zero_()
+         self.out_proj.bias.data.zero_()
+
+     def _create_position_tensors(self: Attention, relative_embedding: torch.Tensor, query_len: int, key_len: int) -> tuple[torch.Tensor, torch.Tensor]:
+         pos = self.in_proj_qk(self.dropout(relative_embedding))  # shape: [2T-1, 2D]
+         pos = F.embedding(self.position_indices[:query_len, :key_len], pos)  # shape: [T, T, 2D]
+         query_pos, key_pos = pos.chunk(2, dim=-1)
+         query_pos = query_pos.view(query_len, key_len, self.num_heads, self.head_size)
+         key_pos = key_pos.view(query_len, key_len, self.num_heads, self.head_size)
+
+         return query_pos, key_pos
+
+     def attention_operation(self: Attention, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor, query_pos: torch.Tensor, key_pos: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         key_len: int
+         batch_size: int
+         key_len, batch_size, _ = key.size()
+         query_len: int
+         query_len, _, _ = query.size()
+
+         query = query.reshape(query_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
+         key = key.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
+         value = value.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
+
+         # Content-to-content attention scores.
+         attention_probs: torch.Tensor = torch.bmm(query, key.transpose(1, 2) * self.scale)
+
+         # Add content-to-position and position-to-content scores.
+         query = query.view(batch_size, self.num_heads, query_len, self.head_size)
+         key = key.view(batch_size, self.num_heads, key_len, self.head_size)
+         attention_probs = attention_probs.view(batch_size, self.num_heads, query_len, key_len)
+         attention_probs.add_(torch.einsum("bhqd,qkhd->bhqk", query, key_pos * self.scale))
+         attention_probs.add_(torch.einsum("bhkd,qkhd->bhqk", key * self.scale, query_pos))
+
+         attention_probs = MaskedSoftmax.apply(attention_probs, attention_mask, -1)
+
+         attention_probs = self.dropout(attention_probs)
+         attention_output: torch.Tensor = torch.bmm(attention_probs.flatten(0, 1), value)  # shape: [B*H, Q, D]
+         attention_output = attention_output.transpose(0, 1).reshape(query_len, batch_size, self.hidden_size)  # shape: [Q, B, H*D]
+
+         return attention_output, attention_probs
+
+     def forward(self: Attention, hidden_states: torch.Tensor, attention_mask: torch.Tensor, relative_embedding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         key_len: int
+         batch_size: int
+         key_len, batch_size, _ = hidden_states.size()
+         query_len: int = key_len
+
+         # Recompute the relative-position buckets if the input is longer than anything seen so far.
+         if self.position_indices.size(0) < query_len:
+             position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
+                 - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
+             position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, self.config.max_position_embeddings)
+             position_indices = self.config.position_bucket_size - 1 + position_indices
+             self.register_buffer("position_indices", position_indices.to(hidden_states.device), persistent=True)
+
+         hidden_states = self.pre_layer_norm(hidden_states)
+         query, key = self.in_proj_qk(hidden_states).chunk(2, dim=2)  # shape: [T, B, D]
+         value, gate = self.in_proj_vg(hidden_states).chunk(2, dim=2)  # shape: [T, B, D]
+         gate = F.gelu(gate)
+
+         query_pos: torch.Tensor
+         key_pos: torch.Tensor
+         query_pos, key_pos = self._create_position_tensors(relative_embedding, query_len, key_len)
+
+         attention_output: torch.Tensor
+         attention_probs: torch.Tensor
+         attention_output, attention_probs = self.attention_operation(query, key, value, attention_mask, query_pos, key_pos)
+         attention_output = attention_output * gate
+         attention_output = self.post_layer_norm(attention_output)
+         attention_output = self.out_proj(attention_output)
+         attention_output = self.dropout(attention_output)
+
+         return attention_output, attention_probs
+
+
+ class Embedding(nn.Module):
+     def __init__(self: Embedding, config: ModelConfig):
+         super().__init__()
+         self.hidden_size: int = config.hidden_size
+
+         self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.word_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+         self.relative_embedding = nn.Parameter(torch.empty(2 * config.position_bucket_size - 1, config.hidden_size))
+         self.relative_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+         self.initialize()
+
+     def initialize(self: Embedding):
+         std: float = math.sqrt(2.0 / (5.0 * self.hidden_size))
+         nn.init.trunc_normal_(self.relative_embedding, mean=0.0, std=std, a=-2*std, b=2*std)
+         nn.init.trunc_normal_(self.word_embedding.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+
+     def forward(self: Embedding, input_ids: torch.Tensor):
+         word_embedding: torch.Tensor = self.dropout(self.word_layer_norm(self.word_embedding(input_ids)))
+         relative_embeddings: torch.Tensor = self.relative_layer_norm(self.relative_embedding)
+         return word_embedding, relative_embeddings
+
+
+ class GPTBERTPreTrainedModel(PreTrainedModel):
+     config_class = ModelConfig
+     supports_gradient_checkpointing = False
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         raise NotImplementedError("Gradient checkpointing is not supported by this model")
+
+     def _init_weights(self, module):
+         std = math.sqrt(2.0 / (5.0 * self.config.hidden_size))
+
+         if isinstance(module, nn.Linear):
+             nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std, a=-2*std, b=2*std)
+         elif isinstance(module, nn.LayerNorm):
+             # Most layer norms in this model have no affine parameters.
+             if module.bias is not None:
+                 module.bias.data.zero_()
+             if module.weight is not None:
+                 module.weight.data.fill_(1.0)
+
+
+ class GPTBERT(GPTBERTPreTrainedModel):
+
+     def __init__(self, config: ModelConfig, is_causal: bool, **kwargs):
+         super().__init__(config, **kwargs)
+         self.config = config
+         self.hidden_size = config.hidden_size
+
+         self.embedding = Embedding(config)
+         self.layers = nn.ModuleList([Layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
+         self.is_causal = is_causal
+
+     def get_input_embeddings(self):
+         return self.embedding.word_embedding
+
+     def set_input_embeddings(self, value):
+         self.embedding.word_embedding = value
+
+     def get_contextualized_embeddings(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
+         """Runs the transformer stack and returns the last hidden state, all intermediate
+         hidden states and the per-layer attention probabilities."""
+         input_shape = input_ids.size()
+
+         batch_size, seq_length = input_shape
+
+         # Internally, True marks positions that must NOT be attended to.
+         if attention_mask is None:
+             attention_mask = input_ids.new_zeros(input_shape, dtype=torch.bool)
+         else:
+             # The incoming mask follows the usual convention (1 = attend, 0 = padding), so invert it.
+             attention_mask = ~attention_mask.bool()
+
+         if attention_mask.dim() == 2:
+             attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+         elif attention_mask.dim() == 3:
+             attention_mask = attention_mask.unsqueeze(1)
+
+         if self.is_causal:
+             attention_mask = attention_mask | input_ids.new_ones((seq_length, seq_length), dtype=torch.bool).triu(1).unsqueeze(0).unsqueeze(0)
+
+         static_embeddings, relative_embeddings = self.embedding(input_ids.t())
+         contextualized_embeddings = [static_embeddings]
+         attention_probs = []
+         for layer in self.layers:
+             layer_embeddings, layer_attention_probs = layer(contextualized_embeddings[-1], attention_mask, relative_embeddings)
+             contextualized_embeddings.append(layer_embeddings)
+             attention_probs.append(layer_attention_probs)
+         contextualized_embeddings = [emb.transpose(0, 1) for emb in contextualized_embeddings]
+         last_layer = contextualized_embeddings[-1]
+         return last_layer, contextualized_embeddings, attention_probs
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs
+     ) -> Union[tuple, BaseModelOutput]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
+
+         if not return_dict:
+             return (
+                 sequence_output,
+                 *([contextualized_embeddings] if output_hidden_states else []),
+                 *([attention_probs] if output_attentions else [])
+             )
+
+         return BaseModelOutput(
+             last_hidden_state=sequence_output,
+             hidden_states=contextualized_embeddings if output_hidden_states else None,
+             attentions=attention_probs if output_attentions else None
+         )
+
+
+ class GPTBERTForCausalLM(GPTBERTPreTrainedModel):
+     _keys_to_ignore_on_load_unexpected = ["head"]
+
+     def __init__(self, config, **kwargs):
+         super().__init__(config, **kwargs)
+         self.model = GPTBERT(config, is_causal=True, **kwargs)
+         self.vocab_size = config.vocab_size
+         self.lm_head = MaskClassifier(config, self.model.embedding.word_embedding.weight)
+         self.hidden_size = config.hidden_size
+
+     def get_output_embeddings(self):
+         return self.lm_head.nonlinearity[-1].weight
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head.nonlinearity[-1].weight = new_embeddings
+
+     def get_input_embeddings(self):
+         return self.model.embedding.word_embedding
+
+     def set_input_embeddings(self, value):
+         self.model.embedding.word_embedding = value
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     def can_generate(self):
+         return True
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         labels: Optional[torch.LongTensor] = None,
+         **kwargs
+     ) -> Union[tuple, CausalLMOutput]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         sequence_output, contextualized_embeddings, attention_probs = self.model.get_contextualized_embeddings(input_ids, attention_mask)
+         subword_prediction = self.lm_head(sequence_output)
+
+         loss = None
+         if labels is not None:
+             # Labels are expected to be aligned with the logits (i.e. already shifted by the caller);
+             # positions labeled -100 are ignored.
+             loss = F.cross_entropy(subword_prediction.flatten(0, 1), labels.flatten(), ignore_index=-100)
+
+         if not return_dict:
+             output = (
+                 subword_prediction,
+                 *([contextualized_embeddings] if output_hidden_states else []),
+                 *([attention_probs] if output_attentions else [])
+             )
+             return ((loss,) + output) if loss is not None else output
+
+         return CausalLMOutput(
+             loss=loss,
+             logits=subword_prediction,
+             hidden_states=contextualized_embeddings if output_hidden_states else None,
+             attentions=attention_probs if output_attentions else None
+         )
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids: torch.Tensor,
+         past_key_values: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         use_cache: bool = True,
+         num_logits_to_keep: Optional[int] = None,
+         **kwargs,
+     ):
+         # If we have a cache: slice `input_ids` through `cache_position` to keep only the unprocessed tokens.
+         # Exception 1: when passing inputs_embeds, input_ids may be missing entries.
+         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here.
+         if past_key_values is not None:
+             if inputs_embeds is not None:  # Exception 1
+                 input_ids = input_ids[:, -cache_position.shape[0]:]
+             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no-op, is Exception 2)
+                 input_ids = input_ids[:, cache_position]
+
+         if attention_mask is not None and position_ids is None:
+             # Create position_ids on the fly for batch generation.
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -input_ids.shape[1]:]
+
+                 # This `clone` call is needed to avoid recapturing CUDA graphs with `torch.compile`'s
+                 # `mode="reduce-overhead"`, as otherwise `position_ids` would have varying strides during
+                 # decoding. `.contiguous()` alone is not sufficient: with batch size 1, `position_ids` is
+                 # already contiguous but its stride still varies, which retriggers a capture.
+                 position_ids = position_ids.clone(memory_format=torch.contiguous_format)
+
+         # If `inputs_embeds` are passed, we only want to use them in the first generation step.
+         if inputs_embeds is not None and cache_position[0] == 0:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
+
+         if num_logits_to_keep is not None:
+             model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
+         model_inputs.update(
+             {
+                 "position_ids": position_ids,
+                 "cache_position": cache_position,
+                 "past_key_values": past_key_values,
+                 "use_cache": use_cache,
+                 "attention_mask": attention_mask,
+             }
+         )
+         return model_inputs
+
+
+ class GPTBERTForMaskedLM(GPTBERTPreTrainedModel):
+     _keys_to_ignore_on_load_unexpected = ["head"]
+
+     def __init__(self, config, **kwargs):
+         super().__init__(config, **kwargs)
+         self.model = GPTBERT(config, is_causal=False, **kwargs)
+         self.vocab_size = config.vocab_size
+         self.lm_head = MaskClassifier(config, self.model.embedding.word_embedding.weight)
+         self.hidden_size = config.hidden_size
+
+     def get_output_embeddings(self):
+         return self.lm_head.nonlinearity[-1].weight
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head.nonlinearity[-1].weight = new_embeddings
+
+     def get_input_embeddings(self):
+         return self.model.embedding.word_embedding
+
+     def set_input_embeddings(self, value):
+         self.model.embedding.word_embedding = value
+
+     def set_encoder(self, encoder):
+         self.model = encoder
+
+     def get_encoder(self):
+         return self.model
+
+     def can_generate(self):
+         return True
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         labels: Optional[torch.LongTensor] = None,
+         **kwargs
+     ) -> Union[tuple, MaskedLMOutput]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         sequence_output, contextualized_embeddings, attention_probs = self.model.get_contextualized_embeddings(input_ids, attention_mask)
+         subword_prediction = self.lm_head(sequence_output)
+
+         loss = None
+         if labels is not None:
+             # Positions labeled -100 (non-masked tokens) are ignored.
+             loss = F.cross_entropy(subword_prediction.flatten(0, 1), labels.flatten(), ignore_index=-100)
+
+         if not return_dict:
+             output = (
+                 subword_prediction,
+                 *([contextualized_embeddings] if output_hidden_states else []),
+                 *([attention_probs] if output_attentions else [])
+             )
+             return ((loss,) + output) if loss is not None else output
+
+         return MaskedLMOutput(
+             loss=loss,
+             logits=subword_prediction,
+             hidden_states=contextualized_embeddings if output_hidden_states else None,
+             attentions=attention_probs if output_attentions else None
+         )
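
A minimal smoke-test sketch that instantiates a randomly initialized model directly from these files and runs a forward pass. The package name `gpt_bert` is only an assumed name for a local clone of this repository (the repo ships an `__init__.py`, so the relative import in `modeling_gpt_bert.py` works when it is imported as a package):

import torch

from gpt_bert.configuration_gpt_bert import ModelConfig
from gpt_bert.modeling_gpt_bert import GPTBERTForMaskedLM

config = ModelConfig()  # base-sized hyperparameters, random weights
model = GPTBERTForMaskedLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 16))  # [batch_size, seq_len]
with torch.no_grad():
    outputs = model(input_ids)

print(outputs.logits.shape)  # torch.Size([2, 16, 8192])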
special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": "<mask>"}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "tokenizer_class": "PreTrainedTokenizerFast",
+   "bos_token": "<s>",
+   "eos_token": "</s>",
+   "unk_token": "<unk>",
+   "sep_token": "</s>",
+   "pad_token": "<pad>",
+   "cls_token": "<s>",
+   "mask_token": "<mask>"
+ }
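
The tokenizer is a plain `PreTrainedTokenizerFast`, so it needs no custom code; a short sketch, where the path is a placeholder for a local clone or the published repository:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/gpt-bert")  # placeholder path

encoded = tokenizer("Hello world!", return_tensors="pt")
print(encoded["input_ids"].shape)
print(tokenizer.mask_token, tokenizer.mask_token_id)  # "<mask>" and its id in the 8192-token vocabulary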