import os
import sys

# Make the repository root importable so the `block.*` modules below resolve
# when this file is run directly from its own directory.
from os.path import dirname as up

sys.path.append(os.path.abspath(os.path.join(up(__file__), os.pardir)))
import torch
import torch.nn as nn
import torch.nn.functional as F
from block.transformer import TransformerBlock
from block.rms_norm import RMSNorm
from block.rope import compute_rope_params
class Gemma3Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        assert cfg["layer_types"] is not None and len(cfg["layer_types"]) == cfg["n_layers"]

        # Main model parameters
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
        self.blocks = nn.ModuleList([
            TransformerBlock(cfg, attn_type) for attn_type in cfg["layer_types"]
        ])
        self.final_norm = RMSNorm(cfg["emb_dim"], eps=1e-6)
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
        self.cfg = cfg

        # Reusable utilities
        cos_local, sin_local = compute_rope_params(
            head_dim=cfg["head_dim"],
            theta_base=cfg["rope_local_base"],
            context_length=cfg["context_length"],
            dtype=torch.float32,
        )
        cos_global, sin_global = compute_rope_params(
            head_dim=cfg["head_dim"],
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"],
            dtype=torch.float32,
        )
        self.register_buffer("cos_local", cos_local, persistent=False)
        self.register_buffer("sin_local", sin_local, persistent=False)
        self.register_buffer("cos_global", cos_global, persistent=False)
        self.register_buffer("sin_global", sin_global, persistent=False)
    def _create_masks(self, seq_len, device):
        ones = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)

        # mask_global (future is masked: j > i)
        #  j:  0 1 2 3 4 5 6 7
        # i
        # 0:   0 1 1 1 1 1 1 1
        # 1:   0 0 1 1 1 1 1 1
        # 2:   0 0 0 1 1 1 1 1
        # 3:   0 0 0 0 1 1 1 1
        # 4:   0 0 0 0 0 1 1 1
        # 5:   0 0 0 0 0 0 1 1
        # 6:   0 0 0 0 0 0 0 1
        # 7:   0 0 0 0 0 0 0 0
        mask_global = torch.triu(ones, diagonal=1)

        # far_past (too far back is masked: i - j >= sliding_window)
        # where sliding_window = 4
        #  j:  0 1 2 3 4 5 6 7
        # i
        # 0:   0 0 0 0 0 0 0 0
        # 1:   0 0 0 0 0 0 0 0
        # 2:   0 0 0 0 0 0 0 0
        # 3:   0 0 0 0 0 0 0 0
        # 4:   1 0 0 0 0 0 0 0
        # 5:   1 1 0 0 0 0 0 0
        # 6:   1 1 1 0 0 0 0 0
        # 7:   1 1 1 1 0 0 0 0
        far_past = torch.triu(ones, diagonal=self.cfg["sliding_window"]).T

        # Local (sliding_window) = future OR far-past
        # mask_local
        #  j:  0 1 2 3 4 5 6 7
        # i
        # 0:   0 1 1 1 1 1 1 1
        # 1:   0 0 1 1 1 1 1 1
        # 2:   0 0 0 1 1 1 1 1
        # 3:   0 0 0 0 1 1 1 1
        # 4:   1 0 0 0 0 1 1 1
        # 5:   1 1 0 0 0 0 1 1
        # 6:   1 1 1 0 0 0 0 1
        # 7:   1 1 1 1 0 0 0 0
        mask_local = mask_global | far_past
        return mask_global, mask_local
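    # Quick sanity check (illustrative only, not part of the model): with
    # seq_len = 8 and sliding_window = 4, the standalone equivalent
    #     ones = torch.ones((8, 8), dtype=torch.bool)
    #     mg = torch.triu(ones, diagonal=1)
    #     ml = mg | torch.triu(ones, diagonal=4).T
    # reproduces the two 8x8 diagrams above exactly.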
    def forward(self, input_ids, targets=None):
        b, seq_len = input_ids.shape

        # Gemma scales token embeddings by sqrt(emb_dim)
        x = self.tok_emb(input_ids) * (self.cfg["emb_dim"] ** 0.5)
        mask_global, mask_local = self._create_masks(seq_len, x.device)

        for block in self.blocks:
            x = block(
                x,
                mask_global=mask_global,
                mask_local=mask_local,
                cos_global=self.cos_global,
                sin_global=self.sin_global,
                cos_local=self.cos_local,
                sin_local=self.sin_local,
            )

        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))

        loss = None
        if targets is not None:
            # Flatten (batch, seq) so each position is one classification target
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss
    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            # Crop the running sequence to the model's context window
            ctx_len = self.cfg["context_length"]
            idx_cond = idx if idx.size(1) <= ctx_len else idx[:, -ctx_len:]

            logits, _ = self(idx_cond)  # targets=None by default
            # Keep only the last position and apply temperature scaling
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Mask out everything below the k-th largest logit
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
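

# --- Usage sketch ---
# A minimal smoke test, assuming a toy config. The keys below mirror the ones
# this file reads, plus a few (marked "assumed key") that a typical
# TransformerBlock in this kind of from-scratch Gemma 3 codebase consumes;
# block/transformer.py is not shown here, so adjust the set to whatever it
# actually expects. The values are made up for demonstration and do not match
# any released Gemma 3 checkpoint.
if __name__ == "__main__":
    torch.manual_seed(0)
    cfg = {
        "vocab_size": 256,
        "context_length": 64,
        "emb_dim": 64,
        "n_layers": 2,
        "n_heads": 4,                 # assumed key
        "n_kv_groups": 1,             # assumed key
        "hidden_dim": 128,            # assumed key
        "head_dim": 16,
        "qk_norm": True,              # assumed key
        "query_pre_attn_scalar": 16,  # assumed key
        "rope_base": 1_000_000.0,
        "rope_local_base": 10_000.0,
        "sliding_window": 4,
        "layer_types": ["sliding_attention", "full_attention"],
        "dtype": torch.float32,
    }
    model = Gemma3Model(cfg)
    prompt = torch.randint(0, cfg["vocab_size"], (1, 8))
    out = model.generate(prompt, max_new_tokens=4, temperature=1.0, top_k=5)
    print(out.shape)  # expected: torch.Size([1, 12])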