Uploading patch

modeling_gpt_bert.py (changed: +67 -3)
@@ -17,6 +17,58 @@ from transformers.modeling_outputs import (
 from typing import Optional, Union
 
 
+# From https://github.com/epfml/DenseFormer
+class InPlaceSetSlice(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, full_tensor, last_slice, x_idx, x_val):
+        full_tensor[x_idx] = x_val
+        ctx.x_idx = x_idx
+        ret = torch.Tensor().to(full_tensor.device)
+        ret.set_(full_tensor[:x_idx + 1])
+        return ret
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        if ctx.x_idx == 0:
+            return None, None, None, grad_out[ctx.x_idx]
+        else:
+            return None, grad_out[:ctx.x_idx], None, grad_out[ctx.x_idx]
+
+
+def apply_inplace_set(x_acc, x_idx, x_val):
+    full_tensor, last_slice = x_acc
+    new_slice = InPlaceSetSlice.apply(full_tensor, last_slice, x_idx, x_val)
+    return full_tensor, new_slice
+
+
+class DWAModules(torch.nn.Module):
+    def __init__(self, hidden_size, n_blocks):
+        super().__init__()
+        self.n_blocks = n_blocks
+        self.alphas = nn.ParameterList([nn.Parameter(torch.zeros(i + 2)) for i in range(n_blocks)])
+        self.accumulator = None
+        self._init_weights()
+
+    def _init_weights(self):
+        for module in self.alphas:
+            module.data.zero_()
+            module.data[-1] = 1.0
+
+    def init_accumulator(self, x):
+        self.accumulator = (torch.zeros((self.n_blocks + 1, *x.shape), device=x.device, dtype=x.dtype), None)
+        self.accumulator = apply_inplace_set(self.accumulator, 0, x)
+
+    def forward(self, x, block_idx):
+        assert self.accumulator is not None, "`init_accumulator(x)` needs to be called first"
+        self.accumulator = apply_inplace_set(
+            self.accumulator,
+            block_idx + 1,
+            x
+        )
+        x = torch.tensordot(self.alphas[block_idx], self.accumulator[1], dims=1)
+        return x
+
+
 class Layer(nn.Module):
 
     def __init__(self: Layer, config: ModelConfig, layer_idx: int = 0):
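The hunk above vendors the depth-weighted-average (DWA) machinery from DenseFormer (https://github.com/epfml/DenseFormer): InPlaceSetSlice writes each block's output into a pre-allocated buffer while its custom backward still routes gradients to every earlier entry, apply_inplace_set threads that buffer through the forward pass, and DWAModules holds one learnable mixing vector per block. Below is a minimal usage sketch, assuming the three definitions above are pasted into the same script; the sizes and the toy block update are invented for illustration and are not part of the patch.

import torch
import torch.nn as nn  # DWAModules needs nn.Parameter / nn.ParameterList in scope

# Hypothetical sizes, chosen only for this demo.
hidden_size, n_blocks, seq_len, batch = 16, 4, 5, 2

dwa = DWAModules(hidden_size, n_blocks)
x0 = torch.randn(seq_len, batch, hidden_size, requires_grad=True)

# Allocates the (n_blocks + 1, seq_len, batch, hidden_size) buffer once and
# stores the initial embeddings at index 0.
dwa.init_accumulator(x0)

x = x0
for block_idx in range(n_blocks):
    x = x + 0.1 * torch.tanh(x)   # stand-in for an attention / MLP sub-block
    x = dwa(x, block_idx)         # write x into the buffer, return the learned mix

# Gradients reach the initial input through InPlaceSetSlice.backward, even
# though the buffer itself was filled in place.
x.sum().backward()
print(x0.grad.shape)              # torch.Size([5, 2, 16])

The point of the custom autograd function is that the accumulator tensor is allocated once up front instead of being re-stacked at every block, while backward still hands grad_out[:x_idx] to all earlier entries.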
@@ -284,7 +336,14 @@ class GPTBERT(GPTBERTPreTrainedModel):
         self.hidden_size = config.hidden_size
 
         self.embedding = Embedding(config)
-        self.
+        self.attention_layers = nn.ModuleList([Attention(config) for _ in range(config.num_layers)])
+        self.mlp_layers = nn.ModuleList([FeedForward(config) for _ in range(config.num_layers)])
+        self.dwa_modules = DWAModules(config.hidden_size, config.num_hidden_layers * 2)
+
+        for i, layer in enumerate(self.mlp_layers):
+            layer.mlp[1].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
+            layer.mlp[-2].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
+
         self.is_causal = is_causal
 
     def get_input_embeddings(self):
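In the constructor, the former single layer stack (the removed, truncated `self.` line above) is split into separate attention and feed-forward module lists sized by config.num_layers, plus one DWAModules instance with config.num_hidden_layers * 2 mixing steps, i.e. two DWA steps per transformer layer (one after attention, one after the MLP). The loop then rescales two linear weights inside each feed-forward block (mlp[1] and mlp[-2], presumably the up- and down-projections of its mlp stack) by sqrt(1 / (2 * (i + 1))), so deeper layers start out contributing less to the residual stream, in the spirit of GPT-2's depth-scaled initialization. A small sketch of that schedule only; the 12-layer depth is a made-up example, not read from the real config.

import math

num_layers = 12  # hypothetical depth, not taken from the config

# Multiplier applied to layer i's mlp[1] / mlp[-2] weights in the patch.
factors = [math.sqrt(1.0 / (2.0 * (1 + i))) for i in range(num_layers)]
print([round(f, 3) for f in factors[:4]])   # [0.707, 0.5, 0.408, 0.354]

One thing to note: the module lists are sized with config.num_layers while the DWA module count uses config.num_hidden_layers; if those two config fields could ever differ, the block indices used in forward would fall out of range, so they are presumably aliases for the same value.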
@@ -316,8 +375,13 @@ class GPTBERT(GPTBERTPreTrainedModel):
         static_embeddings, relative_embeddings = self.embedding(input_ids.t())
         contextualized_embeddings = [static_embeddings]
         attention_probs = []
-
-
+        self.dwa_modules.init_accumulator(static_embeddings)
+        for i, (attention_layer, mlp_layer) in enumerate(zip(self.attention_layers, self.mlp_layers)):
+            attention, layer_attention_probs = attention_layer(contextualized_embeddings[-1], attention_mask, relative_embeddings)
+            layer_embeddings = contextualized_embeddings[-1] + attention
+            layer_embeddings = self.dwa_modules(layer_embeddings, block_idx=i * 2)
+            layer_embeddings += mlp_layer(layer_embeddings)
+            layer_embeddings = self.dwa_modules(layer_embeddings, block_idx=i * 2 + 1)
             contextualized_embeddings.append(layer_embeddings)
             attention_probs.append(layer_attention_probs)
         contextualized_embeddings = [emb.transpose(0, 1) for emb in contextualized_embeddings]
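The patched forward pass seeds the accumulator with the static embeddings and then runs the attention and MLP sub-blocks explicitly, replacing the hidden state after each residual addition with a learned average over everything written to the accumulator so far (block indices i * 2 and i * 2 + 1). Because _init_weights sets every alphas[k] to [0, ..., 0, 1], each DWA step starts out as the identity and simply returns the newest hidden state; only training moves the coefficients toward actually mixing earlier layers back in. A quick check of that initialization property, with made-up shapes:

import torch

alphas = torch.zeros(4)
alphas[-1] = 1.0                          # the DWAModules._init_weights pattern
states = torch.randn(4, 5, 2, 16)         # (entries written so far, seq, batch, hidden)

mixed = torch.tensordot(alphas, states, dims=1)
print(torch.allclose(mixed, states[-1]))  # True: the newest state passes through unchanged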