from dataclasses import dataclass
from typing import Literal

import torch
from torch.nn import functional as F


def gelu_approx(x):
    return F.gelu(x, approximate="tanh")


@dataclass
class LinearWeights:
    weight: torch.Tensor
    bias: torch.Tensor


def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
    return F.linear(x, w.weight, w.bias)


@dataclass
class LayerNormWeights:
    weight: torch.Tensor
    bias: torch.Tensor


def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
    return F.layer_norm(x, w.bias.shape, w.weight, w.bias)


@dataclass
class MLPWeights:
    fc1: LinearWeights
    fc2: LinearWeights
    act: Literal["gelu_approx"] = "gelu_approx"


def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
    x = linear(x, w.fc1)
    x = gelu_approx(x)
    x = linear(x, w.fc2)
    return x


@dataclass
class AttentionWeights:
    qkv: LinearWeights  # fused query/key/value projection
    proj: LinearWeights  # output projection


def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
    bsz, q_len, d_model = x.shape
    head_dim = d_model // n_heads

    # Project once to q, k, v, then split heads: (bsz, n_heads, q_len, head_dim).
    q, k, v = [
        t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
        for t in linear(x, w.qkv).chunk(3, dim=-1)
    ]
    out = F.scaled_dot_product_attention(q, k, v)
    # Merge heads back to (bsz, q_len, d_model) and apply the output projection.
    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
    out = linear(out, w.proj)
    return out
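

# Minimal smoke test (not part of the original listing): a hedged sketch of how the
# functional blocks above can be wired together. The dimensions (d_model=64, n_heads=4),
# the `rand_linear` helper, the random weights, and the pre-norm residual composition
# are illustrative assumptions only, not something the original code specifies.
if __name__ == "__main__":
    torch.manual_seed(0)
    d_model, n_heads, bsz, seq = 64, 4, 2, 8

    def rand_linear(in_dim: int, out_dim: int) -> LinearWeights:
        # Hypothetical helper: random weights stand in for a real checkpoint.
        return LinearWeights(
            weight=torch.randn(out_dim, in_dim) * 0.02,
            bias=torch.zeros(out_dim),
        )

    x = torch.randn(bsz, seq, d_model)
    attn_w = AttentionWeights(
        qkv=rand_linear(d_model, 3 * d_model),
        proj=rand_linear(d_model, d_model),
    )
    mlp_w = MLPWeights(
        fc1=rand_linear(d_model, 4 * d_model),
        fc2=rand_linear(4 * d_model, d_model),
    )
    # A single LayerNormWeights is reused for both norms here purely for brevity.
    ln_w = LayerNormWeights(weight=torch.ones(d_model), bias=torch.zeros(d_model))

    # One common pre-norm residual block: x + attn(ln(x)), then x + mlp(ln(x)).
    h = x + attn(layer_norm(x, ln_w), attn_w, n_heads)
    h = h + mlp(layer_norm(h, ln_w), mlp_w)
    print(h.shape)  # torch.Size([2, 8, 64])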