asharsha30
commited on
Upload Train_GPT2_diff.txt
Browse files- Train_GPT2_diff.txt +298 -0
Train_GPT2_diff.txt
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
35a36
|
2 |
+
> from hellaswag import render_example, iterate_examples
|
3 |
+
36a38,39
|
4 |
+
> import torch._dynamo
|
5 |
+
> torch._dynamo.config.suppress_errors = True
|
6 |
+
48c51,54
|
7 |
+
< class CausalSelfAttention(nn.Module):
|
8 |
+
---
|
9 |
+
> class NewGELU(nn.Module):
|
10 |
+
> """Careful there are a few versions of GeLU, this one is the exact one used by OpenAI"""
|
11 |
+
> def forward(self, input):
|
12 |
+
> return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
|
13 |
+
49a56,79
|
14 |
+
> # Rotary Position Embedding
|
15 |
+
> def apply_rotary_pos_emb(q, k, sin, cos):
|
16 |
+
> q_embed = (q * cos) + rotate_half(q) * sin
|
17 |
+
> k_embed = (k * cos) + rotate_half(k) * sin
|
18 |
+
> return q_embed, k_embed
|
19 |
+
>
|
20 |
+
> def rotate_half(x):
|
21 |
+
> x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
|
22 |
+
> return torch.cat((-x2, x1), dim=-1)
|
23 |
+
>
|
24 |
+
> class RotaryEmbedding(nn.Module):
|
25 |
+
> def __init__(self, dim):
|
26 |
+
> super().__init__()
|
27 |
+
> inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
28 |
+
> self.register_buffer("inv_freq", inv_freq)
|
29 |
+
>
|
30 |
+
> def forward(self, seq_len, device):
|
31 |
+
> t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
|
32 |
+
> freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
|
33 |
+
> emb = torch.cat((freqs, freqs), dim=-1)
|
34 |
+
> sin, cos = emb.sin(), emb.cos()
|
35 |
+
> return sin, cos
|
36 |
+
>
|
37 |
+
> class CausalSelfAttention(nn.Module):
|
38 |
+
53,58d82
|
39 |
+
< # key, query, value projections for all heads, but in a batch
|
40 |
+
< self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
|
41 |
+
< # output projection
|
42 |
+
< self.c_proj = nn.Linear(config.n_embd, config.n_embd)
|
43 |
+
< self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1
|
44 |
+
< # regularization
|
45 |
+
61c85,92
|
46 |
+
< # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
|
47 |
+
---
|
48 |
+
> self.grouped_heads = config.grouped_heads
|
49 |
+
> self.head_dim = config.n_embd // config.n_head
|
50 |
+
>
|
51 |
+
> self.c_attn = nn.Linear(config.n_embd, (2 * config.n_head + self.grouped_heads) * self.head_dim)
|
52 |
+
> self.c_proj = nn.Linear(config.n_embd, config.n_embd)
|
53 |
+
>
|
54 |
+
> self.rotary_embedding = RotaryEmbedding(self.head_dim)
|
55 |
+
>
|
56 |
+
63c94
|
57 |
+
< .view(1, 1, config.block_size, config.block_size))
|
58 |
+
---
|
59 |
+
> .view(1, 1, config.block_size, config.block_size))
|
60 |
+
66,67c97
|
61 |
+
< B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
62 |
+
< # calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
63 |
+
---
|
64 |
+
> B, T, C = x.size()
|
65 |
+
69,84c99,120
|
66 |
+
< q, k, v = qkv.split(self.n_embd, dim=2)
|
67 |
+
< k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
68 |
+
< q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
69 |
+
< v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
70 |
+
< if FLASH:
|
71 |
+
< # flashattention
|
72 |
+
< y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
|
73 |
+
< else:
|
74 |
+
< # manual implementation of attention
|
75 |
+
< # this materializes the large (T,T) matrix for all the queries and keys
|
76 |
+
< att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
77 |
+
< att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
|
78 |
+
< att = F.softmax(att, dim=-1)
|
79 |
+
< y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
80 |
+
< y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
81 |
+
< # output projection
|
82 |
+
---
|
83 |
+
> q, k, v = torch.split(qkv, [self.grouped_heads * self.head_dim, self.n_head * self.head_dim, self.n_head * self.head_dim], dim=2)
|
84 |
+
>
|
85 |
+
> # Reshape for multi-head attention
|
86 |
+
> q = q.view(B, T, self.grouped_heads, self.head_dim).transpose(1, 2)
|
87 |
+
> k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
|
88 |
+
> v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
|
89 |
+
>
|
90 |
+
> # Apply RoPE
|
91 |
+
> sin, cos = self.rotary_embedding(T, x.device)
|
92 |
+
> q, k = apply_rotary_pos_emb(q, k, sin, cos)
|
93 |
+
>
|
94 |
+
> # Expand q to match the number of key/value heads
|
95 |
+
> q = q.repeat_interleave(self.n_head // self.grouped_heads, dim=1)
|
96 |
+
>
|
97 |
+
> # Attention computation
|
98 |
+
> att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
|
99 |
+
> att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
|
100 |
+
> att = F.softmax(att, dim=-1)
|
101 |
+
> y = att @ v
|
102 |
+
>
|
103 |
+
> # Reshape output
|
104 |
+
> y = y.transpose(1, 2).contiguous().view(B, T, C)
|
105 |
+
87c123,124
|
106 |
+
<
|
107 |
+
---
|
108 |
+
>
|
109 |
+
>
|
110 |
+
89d125
|
111 |
+
<
|
112 |
+
92,95c128,130
|
113 |
+
< self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
|
114 |
+
< self.gelu = NewGELU()
|
115 |
+
< self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
|
116 |
+
< self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1
|
117 |
+
---
|
118 |
+
> self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
|
119 |
+
> self.gelu = NewGELU() # Using GeLU activation
|
120 |
+
> self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
|
121 |
+
99c134
|
122 |
+
< x = self.gelu(x)
|
123 |
+
---
|
124 |
+
> x = self.gelu(x) # GeLU activation
|
125 |
+
104d138
|
126 |
+
<
|
127 |
+
117,119d150
|
128 |
+
< # -----------------------------------------------------------------------------
|
129 |
+
< # The main GPT-2 model
|
130 |
+
<
|
131 |
+
122c153
|
132 |
+
< block_size: int = 1024
|
133 |
+
---
|
134 |
+
> block_size: int = 2048
|
135 |
+
124,126c155,158
|
136 |
+
< n_layer: int = 12
|
137 |
+
< n_head: int = 12
|
138 |
+
< n_embd: int = 768
|
139 |
+
---
|
140 |
+
> n_layer: int = 16
|
141 |
+
> n_head: int = 16
|
142 |
+
> grouped_heads: int = 4 # Number of grouped heads for GQA
|
143 |
+
> n_embd: int = 1024
|
144 |
+
129d160
|
145 |
+
<
|
146 |
+
135,136c166
|
147 |
+
< wte = nn.Embedding(config.vocab_size, config.n_embd),
|
148 |
+
< wpe = nn.Embedding(config.block_size, config.n_embd),
|
149 |
+
---
|
150 |
+
> wte = nn.Embedding(config.vocab_size, config.n_embd), # Token embedding only, no wpe
|
151 |
+
140,142d169
|
152 |
+
< self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
153 |
+
< self.lm_head.LLMC_SKIP_INIT = 1 # don't init this one, we will tie weights
|
154 |
+
< self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
|
155 |
+
144,160c171,173
|
156 |
+
< # init all weights, use a torch rng object to be very careful
|
157 |
+
< self.init_rng = torch.Generator()
|
158 |
+
< self.init_rng.manual_seed(42)
|
159 |
+
< self.apply(self._init_weights)
|
160 |
+
<
|
161 |
+
< def _init_weights(self, module):
|
162 |
+
< if isinstance(module, nn.Linear):
|
163 |
+
< # apply special scaled init to the residual projections, per GPT-2 paper
|
164 |
+
< std = 0.02 if not hasattr(module, 'LLMC_RESIDUAL_SCALE_FLAG') else 0.02/math.sqrt(2 * self.config.n_layer)
|
165 |
+
< # we want to skip initializing lm_head, which shares parameters with wte
|
166 |
+
< # and wte was already initialized down below during the Embedding init
|
167 |
+
< if not hasattr(module, 'LLMC_SKIP_INIT'):
|
168 |
+
< torch.nn.init.normal_(module.weight, mean=0.0, std=std, generator=self.init_rng)
|
169 |
+
< if module.bias is not None:
|
170 |
+
< torch.nn.init.zeros_(module.bias)
|
171 |
+
< elif isinstance(module, nn.Embedding):
|
172 |
+
< torch.nn.init.normal_(module.weight, mean=0.0, std=0.02, generator=self.init_rng)
|
173 |
+
---
|
174 |
+
> self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
175 |
+
> self.lm_head.LLMC_SKIP_INIT = 1 # Weight tying
|
176 |
+
> self.transformer.wte.weight = self.lm_head.weight
|
177 |
+
166d178
|
178 |
+
< pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
|
179 |
+
168,171c180,182
|
180 |
+
< # forward the GPT model itself
|
181 |
+
< tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
|
182 |
+
< pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
|
183 |
+
< x = tok_emb + pos_emb
|
184 |
+
---
|
185 |
+
> # forward GPT model
|
186 |
+
> tok_emb = self.transformer.wte(idx) # token embeddings
|
187 |
+
> x = tok_emb
|
188 |
+
176a188,189
|
189 |
+
> logits = self.lm_head(x)
|
190 |
+
>
|
191 |
+
178,179d190
|
192 |
+
< # if we are given some desired targets also calculate the loss
|
193 |
+
< logits = self.lm_head(x)
|
194 |
+
182,183d192
|
195 |
+
< # inference-time mini-optimization: only forward the lm_head on the very last position
|
196 |
+
< logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
|
197 |
+
186c195
|
198 |
+
< # there are performance reasons why not returning logits is prudent, if not needed
|
199 |
+
---
|
200 |
+
> # if return_logits is False, return only the loss (used for training)
|
201 |
+
188c197
|
202 |
+
< logits = None
|
203 |
+
---
|
204 |
+
> return None, loss
|
205 |
+
189a199
|
206 |
+
> # return logits and optionally the loss (used for inference and training)
|
207 |
+
201,204c211,214
|
208 |
+
< 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
|
209 |
+
< 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
|
210 |
+
< 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
|
211 |
+
< 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
|
212 |
+
---
|
213 |
+
> 'gpt2': dict(n_layer=12, n_head=12, grouped_heads=4, n_embd=768), # 124M params
|
214 |
+
> 'gpt2-medium': dict(n_layer=24, n_head=16, grouped_heads=8, n_embd=1024), # 350M params
|
215 |
+
> 'gpt2-large': dict(n_layer=36, n_head=20, grouped_heads=10, n_embd=1280), # 774M params
|
216 |
+
> 'gpt2-xl': dict(n_layer=48, n_head=25, grouped_heads=12, n_embd=1600), # 1558M params
|
217 |
+
298a309
|
218 |
+
>
|
219 |
+
378a390,407
|
220 |
+
> def get_most_likely_row(tokens, mask, logits):
|
221 |
+
> # evaluate the autoregressive loss at all positions
|
222 |
+
> shift_logits = (logits[..., :-1, :]).contiguous()
|
223 |
+
> shift_tokens = (tokens[..., 1:]).contiguous()
|
224 |
+
> flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
225 |
+
> flat_shift_tokens = shift_tokens.view(-1)
|
226 |
+
> shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
|
227 |
+
> shift_losses = shift_losses.view(tokens.size(0), -1)
|
228 |
+
> # now get the average loss just for the completion region (where mask == 1), in each row
|
229 |
+
> shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token
|
230 |
+
> masked_shift_losses = shift_losses * shift_mask
|
231 |
+
> # sum and divide by the number of 1s in the mask
|
232 |
+
> sum_loss = masked_shift_losses.sum(dim=1)
|
233 |
+
> avg_loss = sum_loss / shift_mask.sum(dim=1)
|
234 |
+
> # now we have a loss for each of the 4 completions
|
235 |
+
> # the one with the lowest loss should be the most likely
|
236 |
+
> pred_norm = avg_loss.argmin().item()
|
237 |
+
> return pred_norm
|
238 |
+
655c684
|
239 |
+
< "d12": GPTConfig(block_size=1024, vocab_size=50257, n_layer=12, n_head=12, n_embd=768),
|
240 |
+
---
|
241 |
+
> "d12": GPTConfig(block_size=2024, vocab_size=50257, n_layer=12, n_head=12, n_embd=1024),
|
242 |
+
702,705d730
|
243 |
+
< # -------------------------------------------------------------------------
|
244 |
+
< # main training loop
|
245 |
+
<
|
246 |
+
< # here we wrap model into DDP container
|
247 |
+
738a764,765
|
248 |
+
>
|
249 |
+
>
|
250 |
+
758c785,786
|
251 |
+
< _, loss = model(x, y, return_logits=False)
|
252 |
+
---
|
253 |
+
> logits, loss = model(x, y)
|
254 |
+
> print(logits.shape)
|
255 |
+
782a811,853
|
256 |
+
>
|
257 |
+
> if step in [50,5000,10000,15000,19560]:
|
258 |
+
> save_path = f"{args.output_dir}/model_checkpoint_{step}.bin"
|
259 |
+
> torch.save(model.state_dict(), save_path)
|
260 |
+
> print0(f"Model saved at step {step} to {save_path}")
|
261 |
+
>
|
262 |
+
> if (step % 250 == 0 or last_step or step == 10): #and (not use_compile):
|
263 |
+
> num_correct_norm = 0
|
264 |
+
> num_total = 0
|
265 |
+
> for i, example in enumerate(iterate_examples("val")):
|
266 |
+
> # only process examples where i % ddp_world_size == ddp_rank
|
267 |
+
> if i % ddp_world_size != ddp_rank:
|
268 |
+
> continue
|
269 |
+
> # render the example into tokens and labels
|
270 |
+
> _, tokens, mask, label = render_example(example)
|
271 |
+
> tokens = tokens.to(device)
|
272 |
+
> mask = mask.to(device)
|
273 |
+
> # get the logits
|
274 |
+
> with torch.no_grad():
|
275 |
+
>
|
276 |
+
> with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
277 |
+
> logits, loss = model(tokens)
|
278 |
+
>
|
279 |
+
> # print(f"Step {step}:")
|
280 |
+
> # print(f"tokens shape: {tokens.shape}")
|
281 |
+
> # print(f"mask shape: {mask.shape}")
|
282 |
+
> # print(f"logits shape: {logits.shape}")
|
283 |
+
> pred_norm = get_most_likely_row(tokens, mask, logits)
|
284 |
+
> num_total += 1
|
285 |
+
> num_correct_norm += int(pred_norm == label)
|
286 |
+
> # reduce the stats across all processes
|
287 |
+
> if ddp:
|
288 |
+
> num_total = torch.tensor(num_total, dtype=torch.long, device=device)
|
289 |
+
> num_correct_norm = torch.tensor(num_correct_norm, dtype=torch.long, device=device)
|
290 |
+
> dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
|
291 |
+
> dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
|
292 |
+
> num_total = num_total.item()
|
293 |
+
> num_correct_norm = num_correct_norm.item()
|
294 |
+
> acc_norm = num_correct_norm / num_total
|
295 |
+
> if master_process:
|
296 |
+
> print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")
|
297 |
+
> with open(logfile, "a") as f:
|
298 |
+
> f.write(f"{step} hella {acc_norm:.4f}\n")
|