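"""Gemma 3 style decoder-only transformer (model architecture only).

Token embeddings feed a stack of TransformerBlocks whose attention type
(sliding-window/local vs. full/global) is chosen per layer by
cfg["layer_types"], followed by a final RMSNorm and a linear output head.
Separate RoPE cos/sin tables are precomputed for the local and global
attention layers.
"""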
import os, sys
# Add the parent of this file's directory to sys.path so the `block` package resolves
from os.path import dirname as up
sys.path.append(os.path.abspath(os.path.join(up(__file__), os.pardir)))
import torch
import torch.nn as nn
import torch.nn.functional as F
from block.transformer import TransformerBlock
from block.rms_norm import RMSNorm
from block.rope import compute_rope_params
class Gemma3Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        assert cfg["layer_types"] is not None and len(cfg["layer_types"]) == cfg["n_layers"]

        # Main model parameters
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
        self.blocks = nn.ModuleList([
            TransformerBlock(cfg, attn_type) for attn_type in cfg["layer_types"]
        ])
        self.final_norm = RMSNorm(cfg["emb_dim"], eps=1e-6)
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
        self.cfg = cfg

        # Reusable utilities: precompute RoPE cos/sin tables with separate
        # frequency bases for local (sliding-window) and global attention layers
        cos_local, sin_local = compute_rope_params(
            head_dim=cfg["head_dim"],
            theta_base=cfg["rope_local_base"],
            context_length=cfg["context_length"],
            dtype=torch.float32,
        )
        cos_global, sin_global = compute_rope_params(
            head_dim=cfg["head_dim"],
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"],
            dtype=torch.float32,
        )
        self.register_buffer("cos_local", cos_local, persistent=False)
        self.register_buffer("sin_local", sin_local, persistent=False)
        self.register_buffer("cos_global", cos_global, persistent=False)
        self.register_buffer("sin_global", sin_global, persistent=False)
    def _create_masks(self, seq_len, device):
        ones = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)

        # mask_global (future is masked: j > i)
        #    j: 0 1 2 3 4 5 6 7
        #  i
        #  0:   0 1 1 1 1 1 1 1
        #  1:   0 0 1 1 1 1 1 1
        #  2:   0 0 0 1 1 1 1 1
        #  3:   0 0 0 0 1 1 1 1
        #  4:   0 0 0 0 0 1 1 1
        #  5:   0 0 0 0 0 0 1 1
        #  6:   0 0 0 0 0 0 0 1
        #  7:   0 0 0 0 0 0 0 0
        mask_global = torch.triu(ones, diagonal=1)

        # far_past (too far back is masked: i - j >= sliding_window)
        # where sliding_window = 4
        #    j: 0 1 2 3 4 5 6 7
        #  i
        #  0:   0 0 0 0 0 0 0 0
        #  1:   0 0 0 0 0 0 0 0
        #  2:   0 0 0 0 0 0 0 0
        #  3:   0 0 0 0 0 0 0 0
        #  4:   1 0 0 0 0 0 0 0
        #  5:   1 1 0 0 0 0 0 0
        #  6:   1 1 1 0 0 0 0 0
        #  7:   1 1 1 1 0 0 0 0
        far_past = torch.triu(ones, diagonal=self.cfg["sliding_window"]).T

        # Local (sliding_window) = future OR far-past
        # mask_local
        #    j: 0 1 2 3 4 5 6 7
        #  i
        #  0:   0 1 1 1 1 1 1 1
        #  1:   0 0 1 1 1 1 1 1
        #  2:   0 0 0 1 1 1 1 1
        #  3:   0 0 0 0 1 1 1 1
        #  4:   1 0 0 0 0 1 1 1
        #  5:   1 1 0 0 0 0 1 1
        #  6:   1 1 1 0 0 0 0 1
        #  7:   1 1 1 1 0 0 0 0
        mask_local = mask_global | far_past
        return mask_global, mask_local
    def forward(self, input_ids, targets=None):
        # input_ids: (batch, seq_len) token ids; targets: optional (batch, seq_len)
        # ids for the cross-entropy loss (no shifting is done here)
        b, seq_len = input_ids.shape

        # Scale token embeddings by sqrt(emb_dim), as in Gemma
        x = self.tok_emb(input_ids) * (self.cfg["emb_dim"] ** 0.5)
        mask_global, mask_local = self._create_masks(seq_len, x.device)

        for block in self.blocks:
            x = block(
                x,
                mask_global=mask_global,
                mask_local=mask_local,
                cos_global=self.cos_global,
                sin_global=self.sin_global,
                cos_local=self.cos_local,
                sin_local=self.sin_local,
            )

        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss
    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        # Autoregressive sampling: crop the context to context_length, take the
        # last position's logits, apply temperature and optional top-k filtering,
        # then sample one token and append it to the running sequence.
        for _ in range(max_new_tokens):
            ctx_len = self.cfg["context_length"]
            idx_cond = idx if idx.size(1) <= ctx_len else idx[:, -ctx_len:]
            logits, _ = self(idx_cond)  # targets=None by default
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Keep only the top_k most likely tokens; mask the rest to -inf
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
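

# A minimal smoke-test sketch (not part of the original file). The config keys
# below mirror those read by Gemma3Model above; the concrete values are
# illustrative assumptions, not released Gemma 3 hyperparameters, and
# TransformerBlock may require additional keys (e.g. head counts or MLP sizes)
# that are not shown here.
if __name__ == "__main__":
    cfg = {
        "vocab_size": 256,
        "emb_dim": 64,
        "n_layers": 2,
        # Assumed attention-type labels understood by TransformerBlock
        "layer_types": ["sliding_attention", "full_attention"],
        "head_dim": 16,
        "rope_base": 1_000_000.0,
        "rope_local_base": 10_000.0,
        "context_length": 128,
        "sliding_window": 4,
        "dtype": torch.float32,
    }
    model = Gemma3Model(cfg)
    tokens = torch.randint(0, cfg["vocab_size"], (1, 16))
    logits, loss = model(tokens, targets=tokens)
    print(logits.shape, loss.item())
    print(model.generate(tokens, max_new_tokens=8, top_k=5).shape)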