35a36
> from hellaswag import render_example, iterate_examples
36a38,39
> import torch._dynamo
> torch._dynamo.config.suppress_errors = True
48c51,54
< class CausalSelfAttention(nn.Module):
---
> class NewGELU(nn.Module):
>     """Careful there are a few versions of GeLU, this one is the exact one used by OpenAI"""
>     def forward(self, input):
>         return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
49a56,79
> # Rotary Position Embedding
> def apply_rotary_pos_emb(q, k, sin, cos):
>     q_embed = (q * cos) + rotate_half(q) * sin
>     k_embed = (k * cos) + rotate_half(k) * sin
>     return q_embed, k_embed
>
> def rotate_half(x):
>     x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
>     return torch.cat((-x2, x1), dim=-1)
>
> class RotaryEmbedding(nn.Module):
>     def __init__(self, dim):
>         super().__init__()
>         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
>         self.register_buffer("inv_freq", inv_freq)
>
>     def forward(self, seq_len, device):
>         t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
>         freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
>         emb = torch.cat((freqs, freqs), dim=-1)
>         sin, cos = emb.sin(), emb.cos()
>         return sin, cos
>
> class CausalSelfAttention(nn.Module):
53,58d82
<         # key, query, value projections for all heads, but in a batch
<         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
<         # output projection
<         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
<         self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1
<         # regularization
61c85,92
<         # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
---
>         self.grouped_heads = config.grouped_heads
>         self.head_dim = config.n_embd // config.n_head
>
>         self.c_attn = nn.Linear(config.n_embd, (2 * config.n_head + self.grouped_heads) * self.head_dim)
>         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
>
>         self.rotary_embedding = RotaryEmbedding(self.head_dim)
>
63c94
<                                      .view(1, 1, config.block_size, config.block_size))
---
>             .view(1, 1, config.block_size, config.block_size))
66,67c97
<         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
<         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
---
>         B, T, C = x.size()
69,84c99,120
<         q, k, v = qkv.split(self.n_embd, dim=2)
<         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
<         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
<         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
<         if FLASH:
<             # flashattention
<             y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
<         else:
<             # manual implementation of attention
<             # this materializes the large (T,T) matrix for all the queries and keys
<             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
<             att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
<             att = F.softmax(att, dim=-1)
<             y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
<         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
<         # output projection
---
>         q, k, v = torch.split(qkv, [self.grouped_heads * self.head_dim, self.n_head * self.head_dim, self.n_head * self.head_dim], dim=2)
>
>         # Reshape for multi-head attention
>         q = q.view(B, T, self.grouped_heads, self.head_dim).transpose(1, 2)
>         k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
>         v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
>
>         # Apply RoPE
>         sin, cos = self.rotary_embedding(T, x.device)
>         q, k = apply_rotary_pos_emb(q, k, sin, cos)
>
>         # Expand q to match the number of key/value heads
>         q = q.repeat_interleave(self.n_head // self.grouped_heads, dim=1)
>
>         # Attention computation
>         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
>         att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
>         att = F.softmax(att, dim=-1)
>         y = att @ v
>
>         # Reshape output
>         y = y.transpose(1, 2).contiguous().view(B, T, C)
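
For reference, a standalone shape check of the grouped-query attention path in the hunk above, assuming the new defaults from the diff (n_embd=1024, n_head=16, grouped_heads=4) and an arbitrary B and T. Note that in this variant the query projection carries the smaller grouped head count and is expanded with repeat_interleave to match the key/value heads; the causal mask is omitted from this sketch.

import torch

B, T = 2, 8
n_embd, n_head, grouped_heads = 1024, 16, 4                 # new GPTConfig defaults in the diff
head_dim = n_embd // n_head                                 # 64

# c_attn now projects to (2*n_head + grouped_heads)*head_dim = 2304 features instead of 3*n_embd = 3072
qkv = torch.randn(B, T, (2 * n_head + grouped_heads) * head_dim)
q, k, v = torch.split(qkv, [grouped_heads * head_dim, n_head * head_dim, n_head * head_dim], dim=2)
q = q.view(B, T, grouped_heads, head_dim).transpose(1, 2)   # (B, 4, T, 64)
k = k.view(B, T, n_head, head_dim).transpose(1, 2)          # (B, 16, T, 64)
v = v.view(B, T, n_head, head_dim).transpose(1, 2)          # (B, 16, T, 64)

# rotary angles, as in RotaryEmbedding.forward above
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.einsum("i,j->ij", torch.arange(T).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
sin, cos = emb.sin(), emb.cos()                             # (T, 64), broadcasts over batch and heads

def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

q = q * cos + rotate_half(q) * sin
k = k * cos + rotate_half(k) * sin

# each query head is repeated to line up with the 16 key/value heads;
# this requires n_head % grouped_heads == 0
q = q.repeat_interleave(n_head // grouped_heads, dim=1)     # (B, 16, T, 64)

att = (q @ k.transpose(-2, -1)) / head_dim ** 0.5           # (B, 16, T, T), causal mask omitted here
y = att.softmax(dim=-1) @ v                                 # (B, 16, T, 64)
print(y.transpose(1, 2).reshape(B, T, n_embd).shape)        # torch.Size([2, 8, 1024])
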
87c123,124
<
---
>
>
89d125
<
92,95c128,130
<         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
<         self.gelu = NewGELU()
<         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
<         self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1
---
>         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
>         self.gelu = NewGELU() # Using GeLU activation
>         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
99c134
<         x = self.gelu(x)
---
>         x = self.gelu(x) # GeLU activation
104d138
<
117,119d150
< # -----------------------------------------------------------------------------
< # The main GPT-2 model
<
122c153
<     block_size: int = 1024
---
>     block_size: int = 2048
124,126c155,158
<     n_layer: int = 12
<     n_head: int = 12
<     n_embd: int = 768
---
>     n_layer: int = 16
>     n_head: int = 16
>     grouped_heads: int = 4 # Number of grouped heads for GQA
>     n_embd: int = 1024
129d160
<
135,136c166
<             wte = nn.Embedding(config.vocab_size, config.n_embd),
<             wpe = nn.Embedding(config.block_size, config.n_embd),
---
>             wte = nn.Embedding(config.vocab_size, config.n_embd), # Token embedding only, no wpe
140,142d169
<         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
<         self.lm_head.LLMC_SKIP_INIT = 1 # don't init this one, we will tie weights
<         self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
144,160c171,173
<         # init all weights, use a torch rng object to be very careful
<         self.init_rng = torch.Generator()
<         self.init_rng.manual_seed(42)
<         self.apply(self._init_weights)
<
<     def _init_weights(self, module):
<         if isinstance(module, nn.Linear):
<             # apply special scaled init to the residual projections, per GPT-2 paper
<             std = 0.02 if not hasattr(module, 'LLMC_RESIDUAL_SCALE_FLAG') else 0.02/math.sqrt(2 * self.config.n_layer)
<             # we want to skip initializing lm_head, which shares parameters with wte
<             # and wte was already initialized down below during the Embedding init
<             if not hasattr(module, 'LLMC_SKIP_INIT'):
<                 torch.nn.init.normal_(module.weight, mean=0.0, std=std, generator=self.init_rng)
<             if module.bias is not None:
<                 torch.nn.init.zeros_(module.bias)
<         elif isinstance(module, nn.Embedding):
<             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02, generator=self.init_rng)
---
>         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
>         self.lm_head.LLMC_SKIP_INIT = 1 # Weight tying
>         self.transformer.wte.weight = self.lm_head.weight
166d178
<         pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
168,171c180,182
<         # forward the GPT model itself
<         tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
<         pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
<         x = tok_emb + pos_emb
---
>         # forward GPT model
>         tok_emb = self.transformer.wte(idx) # token embeddings
>         x = tok_emb
176a188,189
>         logits = self.lm_head(x)
>
178,179d190
<             # if we are given some desired targets also calculate the loss
<             logits = self.lm_head(x)
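
The hunks above drop the learned positional embedding table (wpe), so token embeddings go straight into the blocks and position information comes only from the rotary embeddings inside attention, while lm_head stays weight-tied to wte. A minimal sketch of that tying, using illustrative variable names that are not taken from the training script:

import torch
import torch.nn as nn

vocab_size, n_embd = 50257, 1024
wte = nn.Embedding(vocab_size, n_embd)              # token embeddings; no wpe table anymore
lm_head = nn.Linear(n_embd, vocab_size, bias=False)
wte.weight = lm_head.weight                         # weight tying: one shared (50257, 1024) tensor

idx = torch.randint(0, vocab_size, (2, 8))          # (B, T) token ids
x = wte(idx)                                        # (2, 8, 1024); nothing positional is added here
assert wte.weight.data_ptr() == lm_head.weight.data_ptr()
print(lm_head(x).shape)                             # torch.Size([2, 8, 50257])
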
182,183d192
<             # inference-time mini-optimization: only forward the lm_head on the very last position
<             logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
186c195
<         # there are performance reasons why not returning logits is prudent, if not needed
---
>         # if return_logits is False, return only the loss (used for training)
188c197
<             logits = None
---
>             return None, loss
189a199
>         # return logits and optionally the loss (used for inference and training)
201,204c211,214
<             'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
<             'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
<             'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
<             'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
---
>             'gpt2': dict(n_layer=12, n_head=12, grouped_heads=4, n_embd=768), # 124M params
>             'gpt2-medium': dict(n_layer=24, n_head=16, grouped_heads=8, n_embd=1024), # 350M params
>             'gpt2-large': dict(n_layer=36, n_head=20, grouped_heads=10, n_embd=1280), # 774M params
>             'gpt2-xl': dict(n_layer=48, n_head=25, grouped_heads=12, n_embd=1600), # 1558M params
298a309
>
378a390,407
> def get_most_likely_row(tokens, mask, logits):
>     # evaluate the autoregressive loss at all positions
>     shift_logits = (logits[..., :-1, :]).contiguous()
>     shift_tokens = (tokens[..., 1:]).contiguous()
>     flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
>     flat_shift_tokens = shift_tokens.view(-1)
>     shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
>     shift_losses = shift_losses.view(tokens.size(0), -1)
>     # now get the average loss just for the completion region (where mask == 1), in each row
>     shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token
>     masked_shift_losses = shift_losses * shift_mask
>     # sum and divide by the number of 1s in the mask
>     sum_loss = masked_shift_losses.sum(dim=1)
>     avg_loss = sum_loss / shift_mask.sum(dim=1)
>     # now we have a loss for each of the 4 completions
>     # the one with the lowest loss should be the most likely
>     pred_norm = avg_loss.argmin().item()
>     return pred_norm
655c684
<             "d12": GPTConfig(block_size=1024, vocab_size=50257, n_layer=12, n_head=12, n_embd=768),
---
>             "d12": GPTConfig(block_size=2024, vocab_size=50257, n_layer=12, n_head=12, n_embd=1024),
702,705d730
<     # -------------------------------------------------------------------------
<     # main training loop
<
<     # here we wrap model into DDP container
738a764,765
>
>
758c785,786
<                 _, loss = model(x, y, return_logits=False)
---
>                 logits, loss = model(x, y)
>                 print(logits.shape)
782a811,853
>
>     if step in [50,5000,10000,15000,19560]:
>         save_path = f"{args.output_dir}/model_checkpoint_{step}.bin"
>         torch.save(model.state_dict(), save_path)
>         print0(f"Model saved at step {step} to {save_path}")
>
>     if (step % 250 == 0 or last_step or step == 10): #and (not use_compile):
>         num_correct_norm = 0
>         num_total = 0
>         for i, example in enumerate(iterate_examples("val")):
>             # only process examples where i % ddp_world_size == ddp_rank
>             if i % ddp_world_size != ddp_rank:
>                 continue
>             # render the example into tokens and labels
>             _, tokens, mask, label = render_example(example)
>             tokens = tokens.to(device)
>             mask = mask.to(device)
>             # get the logits
>             with torch.no_grad():
>
>                 with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
>                     logits, loss = model(tokens)
>
>             # print(f"Step {step}:")
>             # print(f"tokens shape: {tokens.shape}")
>             # print(f"mask shape: {mask.shape}")
>             # print(f"logits shape: {logits.shape}")
>             pred_norm = get_most_likely_row(tokens, mask, logits)
>             num_total += 1
>             num_correct_norm += int(pred_norm == label)
>         # reduce the stats across all processes
>         if ddp:
>             num_total = torch.tensor(num_total, dtype=torch.long, device=device)
>             num_correct_norm = torch.tensor(num_correct_norm, dtype=torch.long, device=device)
>             dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
>             dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
>             num_total = num_total.item()
>             num_correct_norm = num_correct_norm.item()
>         acc_norm = num_correct_norm / num_total
>         if master_process:
>             print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")
>             with open(logfile, "a") as f:
>                 f.write(f"{step} hella {acc_norm:.4f}\n")
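
For reference, a small self-contained check of the HellaSwag scoring path added above: four candidate rows with random logits, where one row's completion tokens are made very likely, so get_most_likely_row should select it. The helper is condensed here so the snippet runs on its own, and the tensors are synthetic stand-ins for the output of render_example.

import torch
import torch.nn.functional as F

def get_most_likely_row(tokens, mask, logits):
    # same logic as the helper in the diff: average completion-region loss per row, argmin wins
    shift_logits = logits[..., :-1, :].contiguous()
    shift_tokens = tokens[..., 1:].contiguous()
    shift_losses = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
                                   shift_tokens.view(-1), reduction='none').view(tokens.size(0), -1)
    shift_mask = mask[..., 1:].contiguous()
    avg_loss = (shift_losses * shift_mask).sum(dim=1) / shift_mask.sum(dim=1)
    return avg_loss.argmin().item()

# synthetic example: 4 candidates, vocab of 100, completion region = last 3 tokens of each row
tokens = torch.randint(0, 100, (4, 10))
mask = torch.zeros(4, 10)
mask[:, 7:] = 1.0
logits = torch.randn(4, 10, 100)
# make row 2's completion tokens very likely under its own logits
logits[2, 6:9].scatter_(-1, tokens[2, 7:10].unsqueeze(-1), 20.0)
print(get_most_likely_row(tokens, mask, logits)) # should print 2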