# import for colab/kaggle # !pip install datasets transformers wandb -q # !pip install pytorch-lightning lightning tiktoken -q import os import math from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from datasets import load_dataset from transformers import GPT2Tokenizer import pytorch_lightning as pl from pytorch_lightning.callbacks import LearningRateMonitor, RichProgressBar from pytorch_lightning.loggers import WandbLogger from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme from pytorch_lightning.callbacks import ModelCheckpoint block_size = 512 batch_size = 8 max_lr = 1e-3 warmup_steps = 10 max_steps = 25000 log_every_n_steps = 100 save_checkpoints_every_n_steps = 10 effective_batch_size = 32 tokenizer: GPT2Tokenizer = GPT2Tokenizer.from_pretrained( "HuggingFaceTB/cosmo2-tokenizer" ) tokenizer.pad_token = tokenizer.eos_token vocab_size = tokenizer.vocab_size def load_cosmopedia_dataset(batch_size=8, seq_length=1024): """ Returns a torch dataloader for the cosmopedia dataset """ try: dataset = load_dataset( "HuggingFaceTB/smollm-corpus", name="cosmopedia-v2", split="train", streaming=True, ) def encode(examples): tokens = tokenizer( examples["text"], truncation=True, padding="max_length", max_length=seq_length + 1, return_tensors="pt", ) input_ids = tokens["input_ids"].squeeze(0).clone().detach() input_ids = torch.clamp(input_ids, min=0, max=tokenizer.vocab_size - 1) labels = input_ids.clone().detach() labels = labels[1:].to(torch.int64) input_ids = input_ids[:-1].to(torch.int64) return {"input_ids": input_ids, "labels": labels} dataset = dataset.map(encode, remove_columns=["text"], batched=False) dataset = dataset.with_format("torch") dataloader = DataLoader(dataset, batch_size=batch_size) return dataloader except Exception as e: print(e) return None @dataclass class SmolLMConfig: block_size = 1024 vocab_size = 49152 n_layers = 30 n_heads = 9 n_embed = 576 dropout = 0.1 mlp_hidden_dim = 1536 attention_dropout = 0.0 dropout = 0.1 n_key_value_heads = 3 rms_norm_eps = 1e-5 ## Function which enables K and V to have less heads than Q. ## it repeats the K and V heads n_rep times def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" bs, n_kv_heads, slen, head_dim = x.shape if n_rep == 1: return x return ( x[:, :, :, None, :] .expand(bs, n_kv_heads, slen, n_rep, head_dim) .reshape(bs, n_kv_heads * n_rep, slen, head_dim) ) class RMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float = 1e-6): """ Initialize the RMSNorm normalization layer. Args: dim (int): The dimension of the input tensor. eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. Attributes: eps (float): A small value added to the denominator for numerical stability. weight (nn.Parameter): Learnable scaling parameter. """ super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def _norm(self, x): """ Apply the RMSNorm normalization to the input tensor. Args: x (torch.Tensor): The input tensor. Returns: torch.Tensor: The normalized tensor. """ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): """ Forward pass through the RMSNorm layer. Args: x (torch.Tensor): The input tensor. Returns: torch.Tensor: The output tensor after applying RMSNorm. """ output = self._norm(x.float()).type_as(x) return output * self.weight class CausalMultiHeadAttention(nn.Module): def __init__(self, config: SmolLMConfig): super().__init__() self.config = config self.n_head = config.n_heads self.n_embd = config.n_embed # Linear projections for Q, K, V # self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed) # [n_embd, 3 * n_embd] self.w_q = nn.Linear(config.n_embed, config.n_embed, bias=False) self.w_k = nn.Linear( config.n_embed, config.n_embed // config.n_key_value_heads, bias=False ) self.w_v = nn.Linear( config.n_embed, config.n_embed // config.n_key_value_heads, bias=False ) self.c_proj = nn.Linear( config.n_embed, config.n_embed, bias=False ) # [n_embd, n_embd] self.c_proj.NANGPT_SCALE_INIT = 1 self.n_rep = self.config.n_heads // self.config.n_key_value_heads self.resid_dropout = nn.Dropout(config.dropout) self.register_buffer( "bias", torch.tril(torch.ones(config.block_size, config.block_size)).view( 1, 1, config.block_size, config.block_size ), ) def forward(self, x): B, T, C = x.size() # [B, T, n_embd] # Linear projection and split into Q, K, V # q, k, v = self.c_attn(x).split(self.n_embd, dim=2) # [B, T, n_embd] each q = self.w_q(x) # [B, T, 576] k = self.w_k(x) # [B, T, 192] v = self.w_v(x) # [B, T, 192] # Reshape for multi-head attention k = k.view( B, T, self.config.n_key_value_heads, k.size(-1) // self.config.n_key_value_heads, ).transpose( 1, 2 ) # [B, 3, T, 64] q = q.view( B, T, self.config.n_heads, q.size(-1) // self.config.n_heads ).transpose( 1, 2 ) # [B, 9, T, 64] v = v.view( B, T, self.config.n_key_value_heads, v.size(-1) // self.config.n_key_value_heads, ).transpose( 1, 2 ) # [B, 3, T, 64] # repeat k and v for each head k = repeat_kv(k, self.n_rep) v = repeat_kv(v, self.n_rep) # # Attention scores # att = (q @ k.transpose(-2, -1)) * ( # 1.0 / (k.size(-1) ** 0.5) # ) # [B, n_head, T, T] # att = att.masked_fill( # self.bias[:, :, :T, :T] == 0, float("-inf") # ) # [B, n_head, T, T] # att = F.softmax(att, dim=-1) # [B, n_head, T, T] # # Weighted sum of values # y = att @ v # [B, n_head, T, n_embd/n_head] # Flash attention y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # Flash attention # Reshape and project y = y.transpose(1, 2).contiguous().view(B, T, C) # [B, T, n_embd] y = self.c_proj(y) # [B, T, n_embd] y = self.resid_dropout(y) # [B, T, n_embd] return y class MLP(nn.Module): def __init__(self, config: SmolLMConfig): super().__init__() self.c_fc = nn.Linear(config.n_embed, config.mlp_hidden_dim, bias=False) self.silu = nn.SiLU() self.c_proj = nn.Linear(config.mlp_hidden_dim, config.n_embed, bias=False) self.c_proj.NANOGPT_SCALE_INIT = 1 def forward(self, x): x = self.c_fc(x) x = self.silu(x) x = self.c_proj(x) return x class LlamaMLP(nn.Module): def __init__(self, config: SmolLMConfig): super().__init__() self.hidden_dim = config.mlp_hidden_dim # 1536 self.w1 = nn.Linear(config.n_embed, self.hidden_dim, bias=False) self.w2 = nn.Linear(self.hidden_dim, config.n_embed, bias=False) self.w3 = nn.Linear(config.n_embed, self.hidden_dim, bias=False) def forward(self, x): return self.w2(F.silu(self.w1(x)) * self.w3(x)) class DecoderBlockWithRMSNorm(nn.Module): def __init__(self, config: SmolLMConfig): super().__init__() self.config = config self.rms_1 = RMSNorm(self.config.n_embed, eps=self.config.rms_norm_eps) self.attn = CausalMultiHeadAttention(config) self.rms_2 = RMSNorm(self.config.n_embed, eps=self.config.rms_norm_eps) self.mlp = LlamaMLP(config) def forward(self, x): x = x + self.attn(self.rms_1(x)) x = x + self.mlp(self.rms_2(x)) return x class DecoderBlockWithLayerNorm(nn.Module): def __init__(self, config: SmolLMConfig): super().__init__() self.ln_1 = nn.LayerNorm(config.n_embed) self.attn = CausalMultiHeadAttention(config) self.ln_2 = nn.LayerNorm(config.n_embed) self.mlp = MLP(config) def forward(self, x): x = x + self.attn(self.ln_1(x)) x = x + self.mlp(self.ln_2(x)) return x class SmolLM(nn.Module): def __init__(self, config: SmolLMConfig): super().__init__() self.config = config self.wte = nn.Embedding( config.vocab_size, config.n_embed ) # [vocab_size, n_embd] self.wpe = nn.Embedding( config.block_size, config.n_embed ) # [max_seq_len, n_embd] self.drop = nn.Dropout(config.dropout) self.blocks = nn.ModuleList( [DecoderBlockWithRMSNorm(config) for _ in range(config.n_layers)] ) self.rms_norm = RMSNorm(config.n_embed, eps=config.rms_norm_eps) # [n_embd] self.lm_head = nn.Linear( config.n_embed, config.vocab_size, bias=False ) # [n_embd, vocab_size] # weight sharing self.wte.weight = self.lm_head.weight self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): std = 0.02 if hasattr(module, "NANGPT_SCALE_INIT"): std *= (2 * self.config.n_layers) ** -0.5 torch.nn.init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward(self, idx, targets=None): # idx is of shape (B, T) B, T = idx.size() assert ( T <= self.config.block_size ), f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}" pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T) pos_emb = self.wpe(pos) # position embeddings of shape (T, n_embd) x = self.wte(idx) # token embeddings of shape (B, T, n_embd) x = x + pos_emb # forward the blocks of the transformer for block in self.blocks: x = block(x) # forward the final layernorm and the classifier x = self.rms_norm(x) logits = self.lm_head(x) # (B, T, vocab_size) loss = None if targets is not None: loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) return logits, loss @torch.no_grad() def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): """ Generate text given a starting sequence of tokens. Args: idx (torch.Tensor): Starting token indices, shape (B, T) max_new_tokens (int): Number of tokens to generate temperature (float): Sampling temperature (1.0 = no change, < 1.0 = less random, > 1.0 = more random) top_k (int): If specified, only sample from the top k most probable tokens """ for _ in range(max_new_tokens): # if the sequence context is growing too long we must crop it at block_size idx_cond = ( idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size :] ) # forward the model to get the logits for the index in the sequence logits, _ = self(idx_cond) # pluck the logits at the final step and scale by desired temperature logits = logits[:, -1, :] / temperature # optionally crop the logits to only the top k options if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float("Inf") # apply softmax to convert logits to (normalized) probabilities probs = F.softmax(logits, dim=-1) # sample from the distribution idx_next = torch.multinomial(probs, num_samples=1) # append sampled index to the running sequence idx = torch.cat((idx, idx_next), dim=1) return idx class SmolLMLightning(pl.LightningModule): def __init__(self, config: SmolLMConfig, lr, warmup_steps, max_steps): super().__init__() self.save_hyperparameters() self.config = config self.model = SmolLM(self.config) self.criterion = nn.CrossEntropyLoss() self.tokenizer = tokenizer self.generation_prompt = "Once upon a time" self._generating = False def forward(self, x): return self.model(x) def training_step(self, batch, batch_idx): input_ids = batch["input_ids"] target_ids = batch["labels"] logits, _ = self(input_ids) loss = self.criterion(logits.view(-1, logits.size(-1)), target_ids.view(-1)) # Log the loss with 4 decimal precision self.log( "train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, logger=True ) # Generate text every n steps, but only if we're not already generating if (self.global_step) % log_every_n_steps == 0 and not self._generating: self._generating = True self.generate_and_log_sample() self._generating = False return loss def generate_and_log_sample(self): """Generate and log a sample of text from the model""" try: # Encode the prompt prompt_ids = self.tokenizer.encode( self.generation_prompt, return_tensors="pt" ).to(self.device) # Generate new tokens generated_ids = self.model.generate( prompt_ids, max_new_tokens=50, temperature=0.8, top_k=40 ) # Decode the generated tokens generated_text = self.tokenizer.decode(generated_ids[0].tolist()) # Create a formatted message message = ( f"\n{'='*40}\n" f"Step {self.global_step} generation:\n" f"Prompt: {self.generation_prompt}\n" f"Generated: {generated_text}\n" f"{'='*40}\n" ) print(message) # Log to WandB if hasattr(self.logger, "experiment"): self.logger.experiment.log( {"generated_text": generated_text, "global_step": self.global_step} ) except Exception as e: print(f"Generation failed with error: {str(e)}") def configure_optimizers(self): optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr) def lr_lambda(current_step): if current_step < self.hparams.warmup_steps: return self.hparams.lr * (current_step + 1) / self.hparams.warmup_steps elif current_step > self.hparams.max_steps: return self.hparams.lr * 0.1 decay_ratio = (current_step - self.hparams.warmup_steps) / ( self.hparams.max_steps - self.hparams.warmup_steps ) coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) return self.hparams.lr * 0.1 + coeff * ( self.hparams.lr - self.hparams.lr * 0.1 ) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) return [optimizer], [scheduler] if __name__ == "__main__": torch.set_float32_matmul_precision("high") dataloader = load_cosmopedia_dataset(batch_size=batch_size, seq_length=block_size) # Check if checkpoint exists checkpoint_path = "checkpoints/best-checkpoint.ckpt" if os.path.exists(checkpoint_path): print(f"Loading model from checkpoint: {checkpoint_path}") model = SmolLMLightning.load_from_checkpoint( checkpoint_path, config=SmolLMConfig(), lr=max_lr, warmup_steps=warmup_steps, max_steps=max_steps, ) else: print("Starting training from scratch") model = SmolLMLightning(SmolLMConfig(), max_lr, warmup_steps, max_steps) # Replace TensorBoard logger with WandB logger wandb_logger = WandbLogger( project="smollm", # your project name name="transformer_experiment", # name of the run log_model=True, # log model checkpoints ) os.makedirs("checkpoints", exist_ok=True) checkpoint_callback = ModelCheckpoint( dirpath="checkpoints/", filename="best-checkpoint", verbose=True, every_n_train_steps=save_checkpoints_every_n_steps, ) device = "cpu" if torch.cuda.is_available(): device = "cuda" elif torch.backends.mps.is_available(): device = "mps" print(f"using device: {device}") progress_bar = RichProgressBar( refresh_rate=1, leave=False, theme=RichProgressBarTheme( description="", progress_bar="#6206E0", progress_bar_finished="#6206E0", progress_bar_pulse="#6206E0", batch_progress="", time="dim", processing_speed="dim underline", metrics="italic", metrics_text_delimiter=" ", metrics_format=".3f", ), console_kwargs=None, ) trainer = pl.Trainer( max_steps=max_steps, accelerator=device, devices=1, callbacks=[ LearningRateMonitor(logging_interval="step"), progress_bar, checkpoint_callback, ], precision="bf16-mixed", log_every_n_steps=1, enable_progress_bar=True, enable_model_summary=True, logger=wandb_logger, accumulate_grad_batches=effective_batch_size // batch_size, ) trainer.fit(model, dataloader)