lgcharpe committed
Commit 9ec854e · verified · 1 Parent(s): e3a9de8

Uploading patch

Files changed (1): modeling_gpt_bert.py (+67 -3)
modeling_gpt_bert.py CHANGED
@@ -17,6 +17,58 @@ from transformers.modeling_outputs import (
 from typing import Optional, Union
 
 
+# From https://github.com/epfml/DenseFormer
+class InPlaceSetSlice(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, full_tensor, last_slice, x_idx, x_val):
+        full_tensor[x_idx] = x_val
+        ctx.x_idx = x_idx
+        ret = torch.Tensor().to(full_tensor.device)
+        ret.set_(full_tensor[:x_idx + 1])
+        return ret
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        if ctx.x_idx == 0:
+            return None, None, None, grad_out[ctx.x_idx]
+        else:
+            return None, grad_out[:ctx.x_idx], None, grad_out[ctx.x_idx]
+
+
+def apply_inplace_set(x_acc, x_idx, x_val):
+    full_tensor, last_slice = x_acc
+    new_slice = InPlaceSetSlice.apply(full_tensor, last_slice, x_idx, x_val)
+    return full_tensor, new_slice
+
+
+class DWAModules(torch.nn.Module):
+    def __init__(self, hidden_size, n_blocks):
+        super().__init__()
+        self.n_blocks = n_blocks
+        self.alphas = nn.ParameterList([nn.Parameter(torch.zeros(i + 2)) for i in range(n_blocks)])
+        self.accumulator = None
+        self._init_weights()
+
+    def _init_weights(self):
+        for module in self.alphas:
+            module.data.zero_()
+            module.data[-1] = 1.0
+
+    def init_accumulator(self, x):
+        self.accumulator = (torch.zeros((self.n_blocks + 1, *x.shape), device=x.device, dtype=x.dtype), None)
+        self.accumulator = apply_inplace_set(self.accumulator, 0, x)
+
+    def forward(self, x, block_idx):
+        assert self.accumulator is not None, "`init_accumulator(x)` needs to be called first"
+        self.accumulator = apply_inplace_set(
+            self.accumulator,
+            block_idx + 1,
+            x
+        )
+        x = torch.tensordot(self.alphas[block_idx], self.accumulator[1], dims=1)
+        return x
+
+
 class Layer(nn.Module):
 
     def __init__(self: Layer, config: ModelConfig, layer_idx: int = 0):
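
Editorial note, not part of the patch: a minimal sketch of what the DenseFormer helpers above do. apply_inplace_set writes each block output into a preallocated accumulator tensor while the custom autograd Function routes gradients back to every stored slice, and DWAModules.forward combines all outputs stored so far with learned per-block weights via torch.tensordot. With the default initialization (alphas all zero except the last entry), the module simply returns the newest block output. The snippet assumes the DWAModules class from the hunk above is defined in the current module; tensor shapes are illustrative.

import torch

dwa = DWAModules(hidden_size=8, n_blocks=2)   # hidden_size is not used by DWAModules itself
x0 = torch.randn(4, 8)                        # e.g. one (sequence, hidden) slice
dwa.init_accumulator(x0)                      # slot 0 holds the initial embeddings

x1 = torch.randn(4, 8)                        # output of the first sub-block
out = dwa(x1, block_idx=0)

# Default alphas for block 0 are [0, 1], so the output is exactly the newest block output.
assert torch.allclose(out, x1)

# The tensordot in DWAModules.forward is a weighted sum over all outputs stored so far
# (accumulator slots 0 .. block_idx + 1), written out explicitly here.
stacked = torch.stack([x0, x1])                            # shape (2, 4, 8)
manual = torch.tensordot(dwa.alphas[0], stacked, dims=1)
assert torch.allclose(out, manual)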
 
@@ -284,7 +336,14 @@ class GPTBERT(GPTBERTPreTrainedModel):
         self.hidden_size = config.hidden_size
 
         self.embedding = Embedding(config)
-        self.layers = nn.ModuleList([Layer(config) for _ in range(config.num_layers)])
+        self.attention_layers = nn.ModuleList([Attention(config) for _ in range(config.num_layers)])
+        self.mlp_layers = nn.ModuleList([FeedForward(config) for _ in range(config.num_layers)])
+        self.dwa_modules = DWAModules(config.hidden_size, config.num_hidden_layers * 2)
+
+        for i, layer in enumerate(self.mlp_layers):
+            layer.mlp[1].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
+            layer.mlp[-2].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
+
         self.is_causal = is_causal
 
     def get_input_embeddings(self):
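
Editorial note, not part of the patch: the loop above shrinks the two projection matrices inside each FeedForward block by a depth-dependent factor sqrt(1 / (2 * (1 + i))), so deeper MLPs start with smaller output weights. The mlp[1] / mlp[-2] indices assume this repository's FeedForward layout, which is not shown in this diff. A quick way to see the factors for the first few layers:

import math

# Scale applied to layer i's MLP projection weights in the constructor above.
factors = [math.sqrt(1.0 / (2.0 * (1 + i))) for i in range(4)]
print(factors)  # ≈ [0.707, 0.500, 0.408, 0.354]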
 
@@ -316,8 +375,13 @@ class GPTBERT(GPTBERTPreTrainedModel):
         static_embeddings, relative_embeddings = self.embedding(input_ids.t())
         contextualized_embeddings = [static_embeddings]
         attention_probs = []
-        for layer in self.layers:
-            layer_embeddings, layer_attention_probs = layer(contextualized_embeddings[-1], attention_mask, relative_embeddings)
+        self.dwa_modules.init_accumulator(static_embeddings)
+        for i, (attention_layer, mlp_layer) in enumerate(zip(self.attention_layers, self.mlp_layers)):
+            attention, layer_attention_probs = attention_layer(contextualized_embeddings[-1], attention_mask, relative_embeddings)
+            layer_embeddings = contextualized_embeddings[-1] + attention
+            layer_embeddings = self.dwa_modules(layer_embeddings, block_idx=i * 2)
+            layer_embeddings += mlp_layer(layer_embeddings)
+            layer_embeddings = self.dwa_modules(layer_embeddings, block_idx=i * 2 + 1)
             contextualized_embeddings.append(layer_embeddings)
             attention_probs.append(layer_attention_probs)
         contextualized_embeddings = [emb.transpose(0, 1) for emb in contextualized_embeddings]
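
Editorial note, not part of the patch: a runnable sketch of why the new forward loop is behavior-preserving at initialization. With the default alphas, every DWAModules call returns the newest block output, so the DWA-wrapped loop reproduces the plain residual stream exactly. Toy matrix multiplications stand in for the repo's Attention and FeedForward modules, attention probabilities are omitted, and the DWAModules class from the first hunk is assumed to be defined in the current module.

import torch

torch.manual_seed(0)
n_layers, seq, hidden = 2, 4, 8
attn_w = [torch.randn(hidden, hidden) * 0.1 for _ in range(n_layers)]   # toy "attention" weights
mlp_w = [torch.randn(hidden, hidden) * 0.1 for _ in range(n_layers)]    # toy "MLP" weights
x = torch.randn(seq, hidden)

# Plain residual stream: schematically what the removed `self.layers` loop computed.
h_plain = x
for i in range(n_layers):
    h_plain = h_plain + h_plain @ attn_w[i]
    h_plain = h_plain + h_plain @ mlp_w[i]

# DWA-wrapped stream, mirroring the new forward loop above
# (two DWA slots per layer: one after attention, one after the MLP).
dwa = DWAModules(hidden, n_blocks=n_layers * 2)
dwa.init_accumulator(x)
h_dwa = x
for i in range(n_layers):
    h_dwa = h_dwa + h_dwa @ attn_w[i]
    h_dwa = dwa(h_dwa, block_idx=i * 2)
    h_dwa = h_dwa + h_dwa @ mlp_w[i]
    h_dwa = dwa(h_dwa, block_idx=i * 2 + 1)

assert torch.allclose(h_plain, h_dwa)   # identical at the default initialization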