|
35a36 |
|
> from hellaswag import render_example, iterate_examples |
|
36a38,39 |
|
> import torch._dynamo |
|
> torch._dynamo.config.suppress_errors = True |
|
48c51,54 |
|
< class CausalSelfAttention(nn.Module): |
|
--- |
|
> class NewGELU(nn.Module): |
|
> """Careful there are a few versions of GeLU, this one is the exact one used by OpenAI""" |
|
> def forward(self, input): |
|
> return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) |
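
For context (not part of the patch): this is the tanh approximation of GeLU, which PyTorch also exposes as F.gelu(..., approximate='tanh'). A quick equivalence check, assuming only torch:

    import math
    import torch
    import torch.nn.functional as F

    x = torch.randn(8, 16)
    new_gelu = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
    # should match the built-in tanh-approximate GeLU to float precision
    assert torch.allclose(new_gelu, F.gelu(x, approximate='tanh'), atol=1e-5)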
|
49a56,79 |
|
> # Rotary Position Embedding |
|
> def apply_rotary_pos_emb(q, k, sin, cos): |
|
> q_embed = (q * cos) + rotate_half(q) * sin |
|
> k_embed = (k * cos) + rotate_half(k) * sin |
|
> return q_embed, k_embed |
|
> |
|
> def rotate_half(x): |
|
> x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] |
|
> return torch.cat((-x2, x1), dim=-1) |
|
> |
|
> class RotaryEmbedding(nn.Module): |
|
> def __init__(self, dim): |
|
> super().__init__() |
|
> inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) |
|
> self.register_buffer("inv_freq", inv_freq) |
|
> |
|
> def forward(self, seq_len, device): |
|
> t = torch.arange(seq_len, device=device).type_as(self.inv_freq) |
|
> freqs = torch.einsum("i , j -> i j", t, self.inv_freq) |
|
> emb = torch.cat((freqs, freqs), dim=-1) |
|
> sin, cos = emb.sin(), emb.cos() |
|
> return sin, cos |
|
> |
|
> class CausalSelfAttention(nn.Module): |
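
A standalone shape check for the rotary embedding added above (illustrative only; batch/head/sequence sizes are made up): sin and cos come back as (seq_len, head_dim) and broadcast over the batch and head dimensions when applied to q and k via apply_rotary_pos_emb.

    import torch

    head_dim, B, H, T = 64, 2, 4, 8
    rope = RotaryEmbedding(head_dim)                  # class defined in the hunk above
    sin, cos = rope(T, device="cpu")                  # each of shape (T, head_dim)
    q = torch.randn(B, H, T, head_dim)
    k = torch.randn(B, H, T, head_dim)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, sin, cos)
    print(sin.shape, q_rot.shape)                     # torch.Size([8, 64]) torch.Size([2, 4, 8, 64])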
|
53,58d82 |
|
< # key, query, value projections for all heads, but in a batch |
|
< self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) |
|
< # output projection |
|
< self.c_proj = nn.Linear(config.n_embd, config.n_embd) |
|
< self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 |
|
< # regularization |
|
61c85,92 |
|
< # not really a 'bias', more of a mask, but following the OpenAI/HF naming though |
|
--- |
|
> self.grouped_heads = config.grouped_heads |
|
> self.head_dim = config.n_embd // config.n_head |
|
> |
|
> self.c_attn = nn.Linear(config.n_embd, (2 * config.n_head + self.grouped_heads) * self.head_dim) |
|
> self.c_proj = nn.Linear(config.n_embd, config.n_embd) |
|
> |
|
> self.rotary_embedding = RotaryEmbedding(self.head_dim) |
|
> |
|
63c94 |
|
< .view(1, 1, config.block_size, config.block_size)) |
|
--- |
|
> .view(1, 1, config.block_size, config.block_size)) |
|
66,67c97 |
|
< B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) |
|
< # calculate query, key, values for all heads in batch and move head forward to be the batch dim |
|
--- |
|
> B, T, C = x.size() |
|
69,84c99,120 |
|
< q, k, v = qkv.split(self.n_embd, dim=2) |
|
< k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) |
|
< q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) |
|
< v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) |
|
< if FLASH: |
|
< # flashattention |
|
< y = F.scaled_dot_product_attention(q, k, v, is_causal=True) |
|
< else: |
|
< # manual implementation of attention |
|
< # this materializes the large (T,T) matrix for all the queries and keys |
|
< att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) |
|
< att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) |
|
< att = F.softmax(att, dim=-1) |
|
< y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) |
|
< y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side |
|
< # output projection |
|
--- |
|
> q, k, v = torch.split(qkv, [self.grouped_heads * self.head_dim, self.n_head * self.head_dim, self.n_head * self.head_dim], dim=2) |
|
> |
|
> # Reshape for multi-head attention |
|
> q = q.view(B, T, self.grouped_heads, self.head_dim).transpose(1, 2) |
|
> k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2) |
|
> v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2) |
|
> |
|
> # Apply RoPE |
|
> sin, cos = self.rotary_embedding(T, x.device) |
|
> q, k = apply_rotary_pos_emb(q, k, sin, cos) |
|
> |
|
> # Expand q to match the number of key/value heads |
|
> q = q.repeat_interleave(self.n_head // self.grouped_heads, dim=1) |
|
> |
|
> # Attention computation |
|
> att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim)) |
|
> att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf')) |
|
> att = F.softmax(att, dim=-1) |
|
> y = att @ v |
|
> |
|
> # Reshape output |
|
> y = y.transpose(1, 2).contiguous().view(B, T, C) |
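
Note that in this patch it is the query heads that are grouped (projected to grouped_heads heads and then expanded with repeat_interleave), while k and v keep all n_head heads; this is the reverse of the usual GQA layout, where the key/value heads are the shared ones. A shape walk-through with the new default sizes (illustrative, hypothetical batch and sequence length):

    import torch

    B, T, n_embd, n_head, grouped_heads = 2, 8, 1024, 16, 4
    head_dim = n_embd // n_head                                          # 64
    qkv = torch.randn(B, T, (2 * n_head + grouped_heads) * head_dim)     # (2, 8, 2304)
    q, k, v = torch.split(qkv, [grouped_heads * head_dim, n_head * head_dim, n_head * head_dim], dim=2)
    q = q.view(B, T, grouped_heads, head_dim).transpose(1, 2)            # (2, 4, 8, 64)
    k = k.view(B, T, n_head, head_dim).transpose(1, 2)                   # (2, 16, 8, 64)
    v = v.view(B, T, n_head, head_dim).transpose(1, 2)                   # (2, 16, 8, 64)
    q = q.repeat_interleave(n_head // grouped_heads, dim=1)              # (2, 16, 8, 64), now matches k and v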
|
87c123,124 |
|
< |
|
--- |
|
> |
|
> |
|
89d125 |
|
< |
|
92,95c128,130 |
|
< self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd) |
|
< self.gelu = NewGELU() |
|
< self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd) |
|
< self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1 |
|
--- |
|
> self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd) |
|
> self.gelu = NewGELU() # Using GeLU activation |
|
> self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd) |
|
99c134 |
|
< x = self.gelu(x) |
|
--- |
|
> x = self.gelu(x) # GeLU activation |
|
104d138 |
|
< |
|
117,119d150 |
|
< # ----------------------------------------------------------------------------- |
|
< # The main GPT-2 model |
|
< |
|
122c153 |
|
< block_size: int = 1024 |
|
--- |
|
> block_size: int = 2048 |
|
124,126c155,158 |
|
< n_layer: int = 12 |
|
< n_head: int = 12 |
|
< n_embd: int = 768 |
|
--- |
|
> n_layer: int = 16 |
|
> n_head: int = 16 |
|
> grouped_heads: int = 4 # number of grouped query heads for GQA (must divide n_head; expanded back to n_head in attention)
|
> n_embd: int = 1024 |
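
The new defaults give head_dim = 1024 // 16 = 64 and a 4x expansion of the grouped query heads. Two constraints are implicit in the attention code above; a minimal sketch, assuming the GPTConfig dataclass from this script:

    cfg = GPTConfig()                           # block_size=2048, n_layer=16, n_head=16, grouped_heads=4, n_embd=1024
    assert cfg.n_embd % cfg.n_head == 0         # head_dim = n_embd // n_head must be exact
    assert cfg.n_head % cfg.grouped_heads == 0  # repeat_interleave(n_head // grouped_heads) must restore n_head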
|
129d160 |
|
< |
|
135,136c166 |
|
< wte = nn.Embedding(config.vocab_size, config.n_embd), |
|
< wpe = nn.Embedding(config.block_size, config.n_embd), |
|
--- |
|
> wte = nn.Embedding(config.vocab_size, config.n_embd), # Token embedding only, no wpe |
|
140,142d169 |
|
< self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) |
|
< self.lm_head.LLMC_SKIP_INIT = 1 # don't init this one, we will tie weights |
|
< self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying |
|
144,160c171,173 |
|
< # init all weights, use a torch rng object to be very careful |
|
< self.init_rng = torch.Generator() |
|
< self.init_rng.manual_seed(42) |
|
< self.apply(self._init_weights) |
|
< |
|
< def _init_weights(self, module): |
|
< if isinstance(module, nn.Linear): |
|
< # apply special scaled init to the residual projections, per GPT-2 paper |
|
< std = 0.02 if not hasattr(module, 'LLMC_RESIDUAL_SCALE_FLAG') else 0.02/math.sqrt(2 * self.config.n_layer) |
|
< # we want to skip initializing lm_head, which shares parameters with wte |
|
< # and wte was already initialized down below during the Embedding init |
|
< if not hasattr(module, 'LLMC_SKIP_INIT'): |
|
< torch.nn.init.normal_(module.weight, mean=0.0, std=std, generator=self.init_rng) |
|
< if module.bias is not None: |
|
< torch.nn.init.zeros_(module.bias) |
|
< elif isinstance(module, nn.Embedding): |
|
< torch.nn.init.normal_(module.weight, mean=0.0, std=0.02, generator=self.init_rng) |
|
--- |
|
> self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) |
|
> self.lm_head.LLMC_SKIP_INIT = 1 # don't init separately; weight is tied to wte below
|
> self.transformer.wte.weight = self.lm_head.weight |
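
The tying here makes lm_head and wte literally the same Parameter, so one (vocab_size, n_embd) tensor serves as both the input embedding and the output projection. A quick check with a hypothetical small config (assuming the surrounding GPT class):

    model = GPT(GPTConfig(block_size=128, vocab_size=50257, n_layer=2, n_head=4, grouped_heads=2, n_embd=128))
    assert model.lm_head.weight is model.transformer.wte.weight   # one shared tensor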
|
166d178 |
|
< pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) |
|
168,171c180,182 |
|
< # forward the GPT model itself |
|
< tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) |
|
< pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) |
|
< x = tok_emb + pos_emb |
|
--- |
|
> # forward GPT model |
|
> tok_emb = self.transformer.wte(idx) # token embeddings |
|
> x = tok_emb |
|
176a188,189 |
|
> logits = self.lm_head(x) |
|
> |
|
178,179d190 |
|
< # if we are given some desired targets also calculate the loss |
|
< logits = self.lm_head(x) |
|
182,183d192 |
|
< # inference-time mini-optimization: only forward the lm_head on the very last position |
|
< logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim |
|
186c195 |
|
< # there are performance reasons why not returning logits is prudent, if not needed |
|
--- |
|
> # if return_logits is False, return only the loss (used for training) |
|
188c197 |
|
< logits = None |
|
--- |
|
> return None, loss |
|
189a199 |
|
> # return logits and optionally the loss (used for inference and training) |
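
With this change the lm_head is applied at every position even when no targets are given (the original last-position-only shortcut is gone), which is what the HellaSwag helper below relies on. Illustrative call patterns, assuming the surrounding GPT class and a small hypothetical config:

    import torch

    cfg = GPTConfig(block_size=128, vocab_size=50257, n_layer=2, n_head=4, grouped_heads=2, n_embd=128)
    model = GPT(cfg)
    x = torch.randint(0, cfg.vocab_size, (2, 16))
    y = torch.randint(0, cfg.vocab_size, (2, 16))
    logits, loss = model(x, y)                          # logits: (2, 16, 50257), loss: scalar
    logits, loss = model(x)                             # no targets: full-length logits, loss is None
    _, loss = model(x, y, return_logits=False)          # training fast path: returns (None, loss)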
|
201,204c211,214 |
|
< 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params |
|
< 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params |
|
< 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params |
|
< 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params |
|
--- |
|
> 'gpt2': dict(n_layer=12, n_head=12, grouped_heads=4, n_embd=768), # 124M params |
|
> 'gpt2-medium': dict(n_layer=24, n_head=16, grouped_heads=8, n_embd=1024), # 350M params |
|
> 'gpt2-large': dict(n_layer=36, n_head=20, grouped_heads=10, n_embd=1280), # 774M params |
|
> 'gpt2-xl': dict(n_layer=48, n_head=25, grouped_heads=5, n_embd=1600), # 1558M params (grouped_heads must divide n_head)
|
298a309 |
|
> |
|
378a390,407 |
|
> def get_most_likely_row(tokens, mask, logits): |
|
> # evaluate the autoregressive loss at all positions |
|
> shift_logits = (logits[..., :-1, :]).contiguous() |
|
> shift_tokens = (tokens[..., 1:]).contiguous() |
|
> flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1)) |
|
> flat_shift_tokens = shift_tokens.view(-1) |
|
> shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none') |
|
> shift_losses = shift_losses.view(tokens.size(0), -1) |
|
> # now get the average loss just for the completion region (where mask == 1), in each row |
|
> shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token |
|
> masked_shift_losses = shift_losses * shift_mask |
|
> # sum and divide by the number of 1s in the mask |
|
> sum_loss = masked_shift_losses.sum(dim=1) |
|
> avg_loss = sum_loss / shift_mask.sum(dim=1) |
|
> # now we have a loss for each of the 4 completions |
|
> # the one with the lowest loss should be the most likely |
|
> pred_norm = avg_loss.argmin().item() |
|
> return pred_norm |
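
A toy invocation with dummy tensors to pin down the expected shapes (values are made up): tokens and mask are (num_candidates, T), logits is (num_candidates, T, vocab), and the returned index is the candidate whose completion region (mask == 1) has the lowest mean token loss.

    import torch

    num_candidates, T, vocab = 4, 12, 50257
    tokens = torch.randint(0, vocab, (num_candidates, T))
    mask = torch.zeros(num_candidates, T, dtype=torch.long)
    mask[:, 6:] = 1                                      # pretend the completion starts at position 6
    logits = torch.randn(num_candidates, T, vocab)
    pred = get_most_likely_row(tokens, mask, logits)     # int in [0, num_candidates)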
|
655c684 |
|
< "d12": GPTConfig(block_size=1024, vocab_size=50257, n_layer=12, n_head=12, n_embd=768), |
|
--- |
|
> "d12": GPTConfig(block_size=2024, vocab_size=50257, n_layer=12, n_head=12, n_embd=1024), |
|
702,705d730 |
|
< # ------------------------------------------------------------------------- |
|
< # main training loop |
|
< |
|
< # here we wrap model into DDP container |
|
738a764,765 |
|
> |
|
> |
|
758c785,786 |
|
< _, loss = model(x, y, return_logits=False) |
|
--- |
|
> logits, loss = model(x, y) |
|
> # print(logits.shape) # debug only; prints every micro-step if enabled
|
782a811,853 |
|
> |
|
> if step in [50,5000,10000,15000,19560]: |
|
> save_path = f"{args.output_dir}/model_checkpoint_{step}.bin" |
|
> torch.save(model.state_dict(), save_path) |
|
> print0(f"Model saved at step {step} to {save_path}") |
|
> |
|
> if (step % 250 == 0 or last_step or step == 10): #and (not use_compile): |
|
> num_correct_norm = 0 |
|
> num_total = 0 |
|
> for i, example in enumerate(iterate_examples("val")): |
|
> # only process examples where i % ddp_world_size == ddp_rank |
|
> if i % ddp_world_size != ddp_rank: |
|
> continue |
|
> # render the example into tokens and labels |
|
> _, tokens, mask, label = render_example(example) |
|
> tokens = tokens.to(device) |
|
> mask = mask.to(device) |
|
> # get the logits |
|
> with torch.no_grad(): |
|
> |
|
> with torch.autocast(device_type=device_type, dtype=torch.bfloat16): |
|
> logits, loss = model(tokens) |
|
> |
|
> # print(f"Step {step}:") |
|
> # print(f"tokens shape: {tokens.shape}") |
|
> # print(f"mask shape: {mask.shape}") |
|
> # print(f"logits shape: {logits.shape}") |
|
> pred_norm = get_most_likely_row(tokens, mask, logits) |
|
> num_total += 1 |
|
> num_correct_norm += int(pred_norm == label) |
|
> # reduce the stats across all processes |
|
> if ddp: |
|
> num_total = torch.tensor(num_total, dtype=torch.long, device=device) |
|
> num_correct_norm = torch.tensor(num_correct_norm, dtype=torch.long, device=device) |
|
> dist.all_reduce(num_total, op=dist.ReduceOp.SUM) |
|
> dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM) |
|
> num_total = num_total.item() |
|
> num_correct_norm = num_correct_norm.item() |
|
> acc_norm = num_correct_norm / num_total |
|
> if master_process: |
|
> print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}") |
|
> with open(logfile, "a") as f: |
|
> f.write(f"{step} hella {acc_norm:.4f}\n") |