Train_GPT2_diff.txt
35a36
> from hellaswag import render_example, iterate_examples
36a38,39
> import torch._dynamo
> torch._dynamo.config.suppress_errors = True
48c51,54
< class CausalSelfAttention(nn.Module):
---
> class NewGELU(nn.Module):
>     """Careful there are a few versions of GeLU, this one is the exact one used by OpenAI"""
>     def forward(self, input):
>         return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
49a56,79
> # Rotary Position Embedding
> def apply_rotary_pos_emb(q, k, sin, cos):
>     q_embed = (q * cos) + rotate_half(q) * sin
>     k_embed = (k * cos) + rotate_half(k) * sin
>     return q_embed, k_embed
>
> def rotate_half(x):
>     x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
>     return torch.cat((-x2, x1), dim=-1)
>
> class RotaryEmbedding(nn.Module):
>     def __init__(self, dim):
>         super().__init__()
>         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
>         self.register_buffer("inv_freq", inv_freq)
>
>     def forward(self, seq_len, device):
>         t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
>         freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
>         emb = torch.cat((freqs, freqs), dim=-1)
>         sin, cos = emb.sin(), emb.cos()
>         return sin, cos
>
> class CausalSelfAttention(nn.Module):
53,58d82
<         # key, query, value projections for all heads, but in a batch
<         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
<         # output projection
<         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
<         self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1
<         # regularization
61c85,92
<         # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
---
>         self.grouped_heads = config.grouped_heads
>         self.head_dim = config.n_embd // config.n_head
>
>         self.c_attn = nn.Linear(config.n_embd, (2 * config.n_head + self.grouped_heads) * self.head_dim)
>         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
>
>         self.rotary_embedding = RotaryEmbedding(self.head_dim)
>
63c94
<                                      .view(1, 1, config.block_size, config.block_size))
---
> .view(1, 1, config.block_size, config.block_size))
66,67c97
<         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
<         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
---
>         B, T, C = x.size()
69,84c99,120
<         q, k, v = qkv.split(self.n_embd, dim=2)
<         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
<         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
<         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
<         if FLASH:
<             # flashattention
<             y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
<         else:
<             # manual implementation of attention
<             # this materializes the large (T,T) matrix for all the queries and keys
<             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
<             att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
<             att = F.softmax(att, dim=-1)
<             y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
<         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
<         # output projection
---
>         q, k, v = torch.split(qkv, [self.grouped_heads * self.head_dim, self.n_head * self.head_dim, self.n_head * self.head_dim], dim=2)
>
>         # Reshape for multi-head attention
>         q = q.view(B, T, self.grouped_heads, self.head_dim).transpose(1, 2)
>         k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
>         v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
>
>         # Apply RoPE
>         sin, cos = self.rotary_embedding(T, x.device)
>         q, k = apply_rotary_pos_emb(q, k, sin, cos)
>
>         # Expand q to match the number of key/value heads
>         q = q.repeat_interleave(self.n_head // self.grouped_heads, dim=1)
>
>         # Attention computation
>         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
>         att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
>         att = F.softmax(att, dim=-1)
>         y = att @ v
>
>         # Reshape output
>         y = y.transpose(1, 2).contiguous().view(B, T, C)
87c123,124
<
---
>
>
89d125
<
92,95c128,130
<         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
<         self.gelu = NewGELU()
<         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
<         self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1
---
>         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
>         self.gelu = NewGELU() # Using GeLU activation
>         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
99c134
<         x = self.gelu(x)
---
>         x = self.gelu(x) # GeLU activation
104d138
<
117,119d150
< # -----------------------------------------------------------------------------
< # The main GPT-2 model
<
122c153
<     block_size: int = 1024
---
>     block_size: int = 2048
124,126c155,158
<     n_layer: int = 12
<     n_head: int = 12
<     n_embd: int = 768
---
>     n_layer: int = 16
>     n_head: int = 16
>     grouped_heads: int = 4 # Number of grouped heads for GQA
>     n_embd: int = 1024
129d160
<
135,136c166
<             wte = nn.Embedding(config.vocab_size, config.n_embd),
<             wpe = nn.Embedding(config.block_size, config.n_embd),
---
>             wte = nn.Embedding(config.vocab_size, config.n_embd), # Token embedding only, no wpe
140,142d169
<         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
<         self.lm_head.LLMC_SKIP_INIT = 1 # don't init this one, we will tie weights
<         self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
144,160c171,173
<         # init all weights, use a torch rng object to be very careful
<         self.init_rng = torch.Generator()
<         self.init_rng.manual_seed(42)
<         self.apply(self._init_weights)
<
<     def _init_weights(self, module):
<         if isinstance(module, nn.Linear):
<             # apply special scaled init to the residual projections, per GPT-2 paper
<             std = 0.02 if not hasattr(module, 'LLMC_RESIDUAL_SCALE_FLAG') else 0.02/math.sqrt(2 * self.config.n_layer)
<             # we want to skip initializing lm_head, which shares parameters with wte
<             # and wte was already initialized down below during the Embedding init
<             if not hasattr(module, 'LLMC_SKIP_INIT'):
<                 torch.nn.init.normal_(module.weight, mean=0.0, std=std, generator=self.init_rng)
<             if module.bias is not None:
<                 torch.nn.init.zeros_(module.bias)
<         elif isinstance(module, nn.Embedding):
<             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02, generator=self.init_rng)
---
>         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
>         self.lm_head.LLMC_SKIP_INIT = 1 # Weight tying
>         self.transformer.wte.weight = self.lm_head.weight
166d178
<         pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
168,171c180,182
<         # forward the GPT model itself
<         tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
<         pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
<         x = tok_emb + pos_emb
---
>         # forward GPT model
>         tok_emb = self.transformer.wte(idx) # token embeddings
>         x = tok_emb
176a188,189
>         logits = self.lm_head(x)
>
178,179d190
<             # if we are given some desired targets also calculate the loss
<             logits = self.lm_head(x)
182,183d192
<             # inference-time mini-optimization: only forward the lm_head on the very last position
<             logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
186c195
<         # there are performance reasons why not returning logits is prudent, if not needed
---
>         # if return_logits is False, return only the loss (used for training)
188c197
<             logits = None
---
>             return None, loss
189a199
>         # return logits and optionally the loss (used for inference and training)
201,204c211,214
<             'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
<             'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
<             'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
<             'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
---
>             'gpt2': dict(n_layer=12, n_head=12, grouped_heads=4, n_embd=768), # 124M params
>             'gpt2-medium': dict(n_layer=24, n_head=16, grouped_heads=8, n_embd=1024), # 350M params
>             'gpt2-large': dict(n_layer=36, n_head=20, grouped_heads=10, n_embd=1280), # 774M params
>             'gpt2-xl': dict(n_layer=48, n_head=25, grouped_heads=12, n_embd=1600), # 1558M params
298a309
>
378a390,407
> def get_most_likely_row(tokens, mask, logits):
>     # evaluate the autoregressive loss at all positions
>     shift_logits = (logits[..., :-1, :]).contiguous()
>     shift_tokens = (tokens[..., 1:]).contiguous()
>     flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
>     flat_shift_tokens = shift_tokens.view(-1)
>     shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
>     shift_losses = shift_losses.view(tokens.size(0), -1)
>     # now get the average loss just for the completion region (where mask == 1), in each row
>     shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token
>     masked_shift_losses = shift_losses * shift_mask
>     # sum and divide by the number of 1s in the mask
>     sum_loss = masked_shift_losses.sum(dim=1)
>     avg_loss = sum_loss / shift_mask.sum(dim=1)
>     # now we have a loss for each of the 4 completions
>     # the one with the lowest loss should be the most likely
>     pred_norm = avg_loss.argmin().item()
>     return pred_norm
655c684
< "d12": GPTConfig(block_size=1024, vocab_size=50257, n_layer=12, n_head=12, n_embd=768),
---
> "d12": GPTConfig(block_size=2024, vocab_size=50257, n_layer=12, n_head=12, n_embd=1024),
702,705d730
<     # -------------------------------------------------------------------------
<     # main training loop
<
<     # here we wrap model into DDP container
738a764,765
>
>
758c785,786
<                 _, loss = model(x, y, return_logits=False)
---
>                 logits, loss = model(x, y)
>                 print(logits.shape)
782a811,853
>
>         if step in [50,5000,10000,15000,19560]:
>             save_path = f"{args.output_dir}/model_checkpoint_{step}.bin"
>             torch.save(model.state_dict(), save_path)
>             print0(f"Model saved at step {step} to {save_path}")
>
>         if (step % 250 == 0 or last_step or step == 10): #and (not use_compile):
>             num_correct_norm = 0
>             num_total = 0
>             for i, example in enumerate(iterate_examples("val")):
>                 # only process examples where i % ddp_world_size == ddp_rank
>                 if i % ddp_world_size != ddp_rank:
>                     continue
>                 # render the example into tokens and labels
>                 _, tokens, mask, label = render_example(example)
>                 tokens = tokens.to(device)
>                 mask = mask.to(device)
>                 # get the logits
>                 with torch.no_grad():
>
>                     with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
>                         logits, loss = model(tokens)
>
>                 # print(f"Step {step}:")
>                 # print(f"tokens shape: {tokens.shape}")
>                 # print(f"mask shape: {mask.shape}")
>                 # print(f"logits shape: {logits.shape}")
>                 pred_norm = get_most_likely_row(tokens, mask, logits)
>                 num_total += 1
>                 num_correct_norm += int(pred_norm == label)
>             # reduce the stats across all processes
>             if ddp:
>                 num_total = torch.tensor(num_total, dtype=torch.long, device=device)
>                 num_correct_norm = torch.tensor(num_correct_norm, dtype=torch.long, device=device)
>                 dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
>                 dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
>                 num_total = num_total.item()
>                 num_correct_norm = num_correct_norm.item()
>             acc_norm = num_correct_norm / num_total
>             if master_process:
>                 print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")
>                 with open(logfile, "a") as f:
>                     f.write(f"{step} hella {acc_norm:.4f}\n")