committed on
Upload Train_GPT2_diff.txt
Browse files
- Train_GPT2_diff.txt +298 -0
@@ -0,0 +1,298 @@
1 |
2 |
> from hellaswag import render_example, iterate_examples
3 |
4 |
> import torch._dynamo
5 |
> torch._dynamo.config.suppress_errors = True
6 |
7 |
< class CausalSelfAttention(nn.Module):
8 |
9 |
> class NewGELU(nn.Module):
10 |
> """Careful there are a few versions of GeLU, this one is the exact one used by OpenAI"""
11 |
> def forward(self, input):
12 |
> return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
13 |
14 |
> # Rotary Position Embedding
15 |
> def apply_rotary_pos_emb(q, k, sin, cos):
16 |
> q_embed = (q * cos) + rotate_half(q) * sin
17 |
> k_embed = (k * cos) + rotate_half(k) * sin
18 |
> return q_embed, k_embed
19 |
20 |
> def rotate_half(x):
21 |
> x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
22 |
> return torch.cat((-x2, x1), dim=-1)
23 |
24 |
> class RotaryEmbedding(nn.Module):
25 |
> def __init__(self, dim):
26 |
> super().__init__()
27 |
> inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
28 |
> self.register_buffer("inv_freq", inv_freq)
29 |
30 |
> def forward(self, seq_len, device):
31 |
> t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
32 |
> freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
33 |
> emb = torch.cat((freqs, freqs), dim=-1)
34 |
> sin, cos = emb.sin(), emb.cos()
35 |
> return sin, cos
36 |
37 |
> class CausalSelfAttention(nn.Module):
38 |
39 |
< # key, query, value projections for all heads, but in a batch
40 |
< self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
41 |
< # output projection
42 |
< self.c_proj = nn.Linear(config.n_embd, config.n_embd)
43 |
< self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1
44 |
< # regularization
45 |
46 |
< # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
47 |
48 |
> self.grouped_heads = config.grouped_heads
49 |
> self.head_dim = config.n_embd // config.n_head
50 |
51 |
> self.c_attn = nn.Linear(config.n_embd, (2 * config.n_head + self.grouped_heads) * self.head_dim)
52 |
> self.c_proj = nn.Linear(config.n_embd, config.n_embd)
53 |
54 |
> self.rotary_embedding = RotaryEmbedding(self.head_dim)
55 |
56 |
57 |
< .view(1, 1, config.block_size, config.block_size))
58 |
59 |
> .view(1, 1, config.block_size, config.block_size))
60 |
61 |
< B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
62 |
< # calculate query, key, values for all heads in batch and move head forward to be the batch dim
63 |
64 |
> B, T, C = x.size()
65 |
66 |
< q, k, v = qkv.split(self.n_embd, dim=2)
67 |
< k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
68 |
< q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
69 |
< v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
70 |
< if FLASH:
71 |
< # flashattention
72 |
< y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
73 |
< else:
74 |
< # manual implementation of attention
75 |
< # this materializes the large (T,T) matrix for all the queries and keys
76 |
< att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
77 |
< att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
78 |
< att = F.softmax(att, dim=-1)
79 |
< y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
80 |
< y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
81 |
< # output projection
82 |
83 |
> q, k, v = torch.split(qkv, [self.grouped_heads * self.head_dim, self.n_head * self.head_dim, self.n_head * self.head_dim], dim=2)
84 |
85 |
> # Reshape for multi-head attention
86 |
> q = q.view(B, T, self.grouped_heads, self.head_dim).transpose(1, 2)
87 |
> k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
88 |
> v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
89 |
90 |
> # Apply RoPE
91 |
> sin, cos = self.rotary_embedding(T, x.device)
92 |
> q, k = apply_rotary_pos_emb(q, k, sin, cos)
93 |
94 |
> # Expand q to match the number of key/value heads
95 |
> q = q.repeat_interleave(self.n_head // self.grouped_heads, dim=1)
96 |
97 |
> # Attention computation
98 |
> att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
99 |
> att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
100 |
> att = F.softmax(att, dim=-1)
101 |
> y = att @ v
102 |
103 |
> # Reshape output
104 |
> y = y.transpose(1, 2).contiguous().view(B, T, C)
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
< self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
114 |
< self.gelu = NewGELU()
115 |
< self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
116 |
< self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1
117 |
118 |
> self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
119 |
> self.gelu = NewGELU() # Using GeLU activation
120 |
> self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
121 |
122 |
< x = self.gelu(x)
123 |
124 |
> x = self.gelu(x) # GeLU activation
125 |
126 |
127 |
128 |
< # -----------------------------------------------------------------------------
129 |
< # The main GPT-2 model
130 |
131 |
132 |
< block_size: int = 1024
133 |
134 |
> block_size: int = 2048
135 |
136 |
< n_layer: int = 12
137 |
< n_head: int = 12
138 |
< n_embd: int = 768
139 |
140 |
> n_layer: int = 16
141 |
> n_head: int = 16
142 |
> grouped_heads: int = 4 # Number of grouped heads for GQA
143 |
> n_embd: int = 1024
144 |
145 |
146 |
147 |
< wte = nn.Embedding(config.vocab_size, config.n_embd),
148 |
< wpe = nn.Embedding(config.block_size, config.n_embd),
149 |
150 |
> wte = nn.Embedding(config.vocab_size, config.n_embd), # Token embedding only, no wpe
151 |
152 |
< self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
153 |
< self.lm_head.LLMC_SKIP_INIT = 1 # don't init this one, we will tie weights
154 |
< self.transformer.wte.weight = self.lm_head.weight #
155 |
156 |
< # init all weights, use a torch rng object to be very careful
157 |
< self.init_rng = torch.Generator()
158 |
< self.init_rng.manual_seed(42)
159 |
< self.apply(self._init_weights)
160 |
161 |
< def _init_weights(self, module):
162 |
< if isinstance(module, nn.Linear):
163 |
< # apply special scaled init to the residual projections, per GPT-2 paper
164 |
< std = 0.02 if not hasattr(module, 'LLMC_RESIDUAL_SCALE_FLAG') else 0.02/math.sqrt(2 * self.config.n_layer)
165 |
< # we want to skip initializing lm_head, which shares parameters with wte
166 |
< # and wte was already initialized down below during the Embedding init
167 |
< if not hasattr(module, 'LLMC_SKIP_INIT'):
168 |
< torch.nn.init.normal_(module.weight, mean=0.0, std=std, generator=self.init_rng)
169 |
< if module.bias is not None:
170 |
< torch.nn.init.zeros_(module.bias)
171 |
< elif isinstance(module, nn.Embedding):
172 |
< torch.nn.init.normal_(module.weight, mean=0.0, std=0.02, generator=self.init_rng)
173 |
174 |
> self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
175 |
> self.lm_head.LLMC_SKIP_INIT = 1 # Weight tying
176 |
> self.transformer.wte.weight = self.lm_head.weight
177 |
178 |
< pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
179 |
180 |
< # forward the GPT model itself
181 |
< tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
182 |
< pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
183 |
< x = tok_emb + pos_emb
184 |
185 |
> # forward GPT model
186 |
> tok_emb = self.transformer.wte(idx) # token embeddings
187 |
> x = tok_emb
188 |
189 |
> logits = self.lm_head(x)
190 |
191 |
192 |
< # if we are given some desired targets also calculate the loss
193 |
< logits = self.lm_head(x)
194 |
195 |
< # inference-time mini-optimization: only forward the lm_head on the very last position
196 |
< logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
197 |
198 |
< # there are performance reasons why not returning logits is prudent, if not needed
199 |
200 |
> # if return_logits is False, return only the loss (used for training)
201 |
202 |
< logits = None
203 |
204 |
> return None, loss
205 |
206 |
> # return logits and optionally the loss (used for inference and training)
207 |
208 |
< 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
209 |
< 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
210 |
< 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
211 |
< 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
212 |
213 |
> 'gpt2': dict(n_layer=12, n_head=12, grouped_heads=4, n_embd=768), # 124M params
214 |
> 'gpt2-medium': dict(n_layer=24, n_head=16, grouped_heads=8, n_embd=1024), # 350M params
215 |
> 'gpt2-large': dict(n_layer=36, n_head=20, grouped_heads=10, n_embd=1280), # 774M params
216 |
> 'gpt2-xl': dict(n_layer=48, n_head=25, grouped_heads=12, n_embd=1600), # 1558M params
217 |
218 |
219 |
220 |
> def get_most_likely_row(tokens, mask, logits):
221 |
> # evaluate the autoregressive loss at all positions
222 |
> shift_logits = (logits[..., :-1, :]).contiguous()
223 |
> shift_tokens = (tokens[..., 1:]).contiguous()
224 |
> flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
225 |
> flat_shift_tokens = shift_tokens.view(-1)
226 |
> shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
227 |
> shift_losses = shift_losses.view(tokens.size(0), -1)
228 |
> # now get the average loss just for the completion region (where mask == 1), in each row
229 |
> shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token
230 |
> masked_shift_losses = shift_losses * shift_mask
231 |
> # sum and divide by the number of 1s in the mask
232 |
> sum_loss = masked_shift_losses.sum(dim=1)
233 |
> avg_loss = sum_loss / shift_mask.sum(dim=1)
234 |
> # now we have a loss for each of the 4 completions
235 |
> # the one with the lowest loss should be the most likely
236 |
> pred_norm = avg_loss.argmin().item()
237 |
> return pred_norm
238 |
239 |
< "d12": GPTConfig(block_size=1024, vocab_size=50257, n_layer=12, n_head=12, n_embd=768),
240 |
241 |
> "d12": GPTConfig(block_size=2024, vocab_size=50257, n_layer=12, n_head=12, n_embd=1024),
242 |
243 |
< # -------------------------------------------------------------------------
244 |
< # main training loop
245 |
246 |
< # here we wrap model into DDP container
247 |
248 |
249 |
250 |
251 |
< _, loss = model(x, y, return_logits=False)
252 |
253 |
> logits, loss = model(x, y)
254 |
> print(logits.shape)
255 |
256 |
257 |
> if step in [50,5000,10000,15000,19560]:
258 |
> save_path = f"{args.output_dir}/model_checkpoint_{step}.bin"
259 |
> torch.save(model.state_dict(), save_path)
260 |
> print0(f"Model saved at step {step} to {save_path}")
261 |
262 |
> if (step % 250 == 0 or last_step or step == 10): #and (not use_compile):
263 |
> num_correct_norm = 0
264 |
> num_total = 0
265 |
> for i, example in enumerate(iterate_examples("val")):
266 |
> # only process examples where i % ddp_world_size == ddp_rank
267 |
> if i % ddp_world_size != ddp_rank:
268 |
> continue
269 |
> # render the example into tokens and labels
270 |
> _, tokens, mask, label = render_example(example)
271 |
> tokens = tokens.to(device)
272 |
> mask = mask.to(device)
273 |
> # get the logits
274 |
> with torch.no_grad():
275 |
276 |
> with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
277 |
> logits, loss = model(tokens)
278 |
279 |
> # print(f"Step {step}:")
280 |
> # print(f"tokens shape: {tokens.shape}")
281 |
> # print(f"mask shape: {mask.shape}")
282 |
> # print(f"logits shape: {logits.shape}")
283 |
> pred_norm = get_most_likely_row(tokens, mask, logits)
284 |
> num_total += 1
285 |
> num_correct_norm += int(pred_norm == label)
286 |
> # reduce the stats across all processes
287 |
> if ddp:
288 |
> num_total = torch.tensor(num_total, dtype=torch.long, device=device)
289 |
> num_correct_norm = torch.tensor(num_correct_norm, dtype=torch.long, device=device)
290 |
> dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
291 |
> dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
292 |
> num_total = num_total.item()
293 |
> num_correct_norm = num_correct_norm.item()
294 |
> acc_norm = num_correct_norm / num_total
295 |
> if master_process:
296 |
> print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")
297 |
> with open(logfile, "a") as f:
298 |
> f.write(f"{step} hella {acc_norm:.4f}\n")