import os, sys

from os.path import dirname as up

sys.path.append(os.path.abspath(os.path.join(up(__file__), os.pardir)))

import torch
import torch.nn as nn
import torch.nn.functional as F

from block.transformer import TransformerBlock
from block.rms_norm import RMSNorm
from block.rope import compute_rope_params

class Gemma3Model(nn.Module):
    """Gemma 3-style decoder-only model with per-layer local (sliding-window)
    or global attention, selected by cfg["layer_types"]."""

    def __init__(self, cfg):
        super().__init__()
        assert cfg["layer_types"] is not None and len(cfg["layer_types"]) == cfg["n_layers"]

        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])

        # One transformer block per entry in cfg["layer_types"]; each entry sets
        # the attention type (local sliding-window vs. global) for that layer.
        self.blocks = nn.ModuleList([
            TransformerBlock(cfg, attn_type) for attn_type in cfg["layer_types"]
        ])

        self.final_norm = RMSNorm(cfg["emb_dim"], eps=1e-6)
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
        self.cfg = cfg

        # Precompute RoPE tables once and share them across layers: local
        # (sliding-window) layers use `rope_local_base`, global layers use `rope_base`.
        cos_local, sin_local = compute_rope_params(
            head_dim=cfg["head_dim"],
            theta_base=cfg["rope_local_base"],
            context_length=cfg["context_length"],
            dtype=torch.float32,
        )
        cos_global, sin_global = compute_rope_params(
            head_dim=cfg["head_dim"],
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"],
            dtype=torch.float32,
        )
        self.register_buffer("cos_local", cos_local, persistent=False)
        self.register_buffer("sin_local", sin_local, persistent=False)
        self.register_buffer("cos_global", cos_global, persistent=False)
        self.register_buffer("sin_global", sin_global, persistent=False)

    def _create_masks(self, seq_len, device):
        ones = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)

        # Global (causal) mask: True marks blocked positions, i.e. every
        # future token (column index j > row index i).
        mask_global = torch.triu(ones, diagonal=1)

        # Far-past mask: True where a key token lies at least `sliding_window`
        # positions before the query token (i - j >= sliding_window).
        far_past = torch.triu(ones, diagonal=self.cfg["sliding_window"]).T

        # Local (sliding-window) mask: block both the future and the far past,
        # so each query attends to at most the `sliding_window` most recent tokens.
        mask_local = mask_global | far_past
        return mask_global, mask_local

    def forward(self, input_ids, targets=None):
        b, seq_len = input_ids.shape
        # Gemma-style embedding scaling by sqrt(emb_dim).
        x = self.tok_emb(input_ids) * (self.cfg["emb_dim"] ** 0.5)
        mask_global, mask_local = self._create_masks(seq_len, x.device)

        for block in self.blocks:
            x = block(
                x,
                mask_global=mask_global,
                mask_local=mask_local,
                cos_global=self.cos_global,
                sin_global=self.sin_global,
                cos_local=self.cos_local,
                sin_local=self.sin_local,
            )

        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            # Crop the running sequence to the model's context window.
            ctx_len = self.cfg["context_length"]
            idx_cond = idx if idx.size(1) <= ctx_len else idx[:, -ctx_len:]
            logits, _ = self(idx_cond)
            # Sample the next token from the temperature-scaled distribution
            # over the last position (temperature must be > 0).
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Mask out everything below the k-th largest logit.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
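
# Minimal, self-contained usage sketch (not part of the model class). It mirrors
# the mask construction from `_create_masks` and the temperature/top-k sampling
# step from `generate` so both can be inspected in isolation. The sizes and toy
# logits below are illustrative assumptions, not values from a real Gemma 3
# config; the full model is not instantiated here because `TransformerBlock`'s
# exact config requirements are defined elsewhere in this repo.
if __name__ == "__main__":
    # 1) Sliding-window vs. global causal masks (True = attention blocked).
    seq_len, sliding_window = 5, 3  # assumed toy values
    ones = torch.ones((seq_len, seq_len), dtype=torch.bool)
    mask_global = torch.triu(ones, diagonal=1)
    mask_local = mask_global | torch.triu(ones, diagonal=sliding_window).T
    print("global mask:\n", mask_global.int())
    print("local mask:\n", mask_local.int())

    # 2) Temperature scaling + top-k filtering on a toy logit vector,
    #    as done per step inside `generate`.
    torch.manual_seed(0)
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])  # assumed toy logits
    temperature, top_k = 0.8, 3
    logits = logits / temperature
    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
    logits[logits < v[:, [-1]]] = float("-inf")  # drop everything below the k-th best
    probs = F.softmax(logits, dim=-1)
    print("top-k probs:", probs)
    print("sampled token id:", torch.multinomial(probs, num_samples=1))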