Architecture code included.
- .gitattributes +1 -0
- architecture/README.md +82 -0
- architecture/__init__.py +2 -0
- architecture/architecture.png +3 -0
- architecture/gemma3.py +130 -0
- architecture/model_config.py +16 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+architecture/architecture.png filter=lfs diff=lfs merge=lfs -text
architecture/README.md
ADDED
@@ -0,0 +1,82 @@
# Architecture Module

This module contains the main Gemma3 model implementation and configuration management.

## Files

### `gemma3.py`
The core Gemma3Model class implementation featuring:

- **Token Embeddings**: Scaled embedding layer with a vocabulary size of 50,257
- **Transformer Blocks**: 18 layers with mixed attention patterns (sliding window and full attention)
- **Dual RoPE**: Two sets of rotary position embeddings for local and global context
- **Attention Masks**: Dynamic generation of causal and sliding window masks
- **Output Head**: Linear projection to vocabulary size for next-token prediction
- **Generation Method**: Temperature-controlled sampling with top-k filtering

Key components:
- `__init__`: Initializes model layers, embeddings, and precomputes RoPE parameters
- `_create_masks`: Generates causal and sliding window attention masks
- `forward`: Main forward pass with optional loss computation
- `generate`: Autoregressive text generation with temperature and top-k sampling

### `model_config.py`
Configuration loader that reads model hyperparameters from `config/model_config.json`.

### `__init__.py`
Module initialization that exports:
- `model_config`: Dictionary containing all model hyperparameters
- `Gemma3Model`: The main model class

## Model Architecture Details

### Layer Configuration
The model uses a strategic mix of attention types across 18 layers (a configuration sketch follows this list):
- **Layers 1-5**: Sliding window attention (512-token window)
- **Layer 6**: Full attention (checkpoint layer)
- **Layers 7-11**: Sliding window attention
- **Layer 12**: Full attention (checkpoint layer)
- **Layers 13-17**: Sliding window attention
- **Layer 18**: Full attention (final layer)

This pattern allows the model to:
- Efficiently process local context with sliding windows
- Capture long-range dependencies at strategic checkpoints
- Balance computational efficiency with modeling capability

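For illustration, the per-layer attention pattern above is driven by the `layer_types` list that `Gemma3Model` reads from the configuration. Below is a minimal sketch of the relevant keys as they would appear in the loaded `model_config` dictionary; the key names come from `gemma3.py`, but the attention-type string values are assumptions, since `TransformerBlock` is not part of this commit:

```python
# Illustrative excerpt only -- not the actual contents of config/model_config.json.
model_config_excerpt = {
    "n_layers": 18,
    "sliding_window": 512,
    "layer_types": [
        # layers 1-5: sliding window, layer 6: full attention
        "sliding_attention", "sliding_attention", "sliding_attention",
        "sliding_attention", "sliding_attention", "full_attention",
        # layers 7-11: sliding window, layer 12: full attention
        "sliding_attention", "sliding_attention", "sliding_attention",
        "sliding_attention", "sliding_attention", "full_attention",
        # layers 13-17: sliding window, layer 18: full attention
        "sliding_attention", "sliding_attention", "sliding_attention",
        "sliding_attention", "sliding_attention", "full_attention",
    ],
}
```
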
### Embedding and Normalization
- **Embedding Scaling**: Input embeddings are scaled by √(embedding_dim) for training stability
- **Final Normalization**: RMS normalization before the output projection
- **No Weight Tying**: The output projection uses its own weight matrix, separate from the input embeddings

### Position Encoding
The model uses dual RoPE (Rotary Position Embeddings):
- **Local RoPE**: θ_base = 10,000 for sliding window attention
- **Global RoPE**: θ_base = 1,000,000 for full attention layers

This dual approach allows different attention patterns to use position encodings optimized for their respective context ranges.

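The cosine/sine tables for both RoPE variants are precomputed by `compute_rope_params` from `block/rope.py`, which is not included in this commit. For reference, here is a minimal sketch of what such a helper typically computes under the standard RoPE formulation; the signature mirrors the call in `gemma3.py`, but the body is illustrative rather than the repository's implementation:

```python
import torch

def compute_rope_params(head_dim, theta_base, context_length, dtype=torch.float32):
    # Inverse frequencies for each pair of head dimensions.
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype) / head_dim))
    # Rotation angle for every (position, frequency) pair.
    positions = torch.arange(context_length, dtype=dtype)
    angles = positions[:, None] * inv_freq[None, :]   # (context_length, head_dim // 2)
    angles = torch.cat([angles, angles], dim=-1)      # (context_length, head_dim)
    return torch.cos(angles), torch.sin(angles)
```
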
## Usage Example

```python
from architecture import Gemma3Model, model_config
import torch

# Initialize model
model = Gemma3Model(model_config)

# Forward pass
input_ids = torch.randint(0, 50257, (2, 128))  # batch_size=2, seq_len=128
logits, loss = model(input_ids, targets=None)

# Generation
prompt = torch.randint(0, 50257, (1, 10))  # Single prompt
generated = model.generate(prompt, max_new_tokens=50, temperature=0.8, top_k=40)
```

## Design Decisions

1. **Mixed Attention**: Combines the efficiency of sliding windows with the modeling power of full attention
2. **Separate RoPE Bases**: Optimizes position encoding for different attention ranges
3. **Grouped Query Attention**: Reduces KV cache memory while maintaining performance
4. **Gemma3-style Normalization**: Uses (1 + weight) scaling for better training dynamics (see the sketch after this list)
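
The `(1 + weight)` scaling in point 4 refers to the Gemma-style RMSNorm variant implemented in `block/rms_norm.py`, which is not part of this commit. A minimal sketch of that formulation, under the assumption that the repository follows the common Gemma convention of a zero-initialized scale parameter (the class name here is illustrative, not the repository's):

```python
import torch
import torch.nn as nn

class GemmaStyleRMSNorm(nn.Module):
    """Illustrative sketch of RMSNorm with (1 + weight) scaling."""

    def __init__(self, emb_dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Zero-initialized so (1 + weight) starts as identity scaling.
        self.weight = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        # Normalize by the root-mean-square of the features, then apply (1 + weight).
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x * rms) * (1.0 + self.weight)
```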
architecture/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .gemma3 import Gemma3Model
from .model_config import model_config
architecture/architecture.png
ADDED
(architecture diagram image; stored with Git LFS)
architecture/gemma3.py
ADDED
@@ -0,0 +1,130 @@
import os, sys
from os.path import dirname as up

sys.path.append(os.path.abspath(os.path.join(up(__file__), os.pardir)))

import torch
import torch.nn as nn
import torch.nn.functional as F

from block.transformer import TransformerBlock
from block.rms_norm import RMSNorm
from block.rope import compute_rope_params

class Gemma3Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        assert cfg["layer_types"] is not None and len(cfg["layer_types"]) == cfg["n_layers"]

        # Main model parameters
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])

        self.blocks = nn.ModuleList([
            TransformerBlock(cfg, attn_type) for attn_type in cfg["layer_types"]
        ])

        self.final_norm = RMSNorm(cfg["emb_dim"], eps=1e-6)
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
        self.cfg = cfg

        # Reusable utilities: precompute RoPE tables for local and global attention
        cos_local, sin_local = compute_rope_params(
            head_dim=cfg["head_dim"],
            theta_base=cfg["rope_local_base"],
            context_length=cfg["context_length"],
            dtype=torch.float32,
        )
        cos_global, sin_global = compute_rope_params(
            head_dim=cfg["head_dim"],
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"],
            dtype=torch.float32,
        )
        self.register_buffer("cos_local", cos_local, persistent=False)
        self.register_buffer("sin_local", sin_local, persistent=False)
        self.register_buffer("cos_global", cos_global, persistent=False)
        self.register_buffer("sin_global", sin_global, persistent=False)

    def _create_masks(self, seq_len, device):
        ones = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)

        # mask_global (future is masked: j > i)
        #     j:  0 1 2 3 4 5 6 7
        #  i
        #  0:     0 1 1 1 1 1 1 1
        #  1:     0 0 1 1 1 1 1 1
        #  2:     0 0 0 1 1 1 1 1
        #  3:     0 0 0 0 1 1 1 1
        #  4:     0 0 0 0 0 1 1 1
        #  5:     0 0 0 0 0 0 1 1
        #  6:     0 0 0 0 0 0 0 1
        #  7:     0 0 0 0 0 0 0 0
        mask_global = torch.triu(ones, diagonal=1)

        # far_past (too far back is masked: i - j >= sliding_window)
        # where sliding_window = 4
        #     j:  0 1 2 3 4 5 6 7
        #  i
        #  0:     0 0 0 0 0 0 0 0
        #  1:     0 0 0 0 0 0 0 0
        #  2:     0 0 0 0 0 0 0 0
        #  3:     0 0 0 0 0 0 0 0
        #  4:     1 0 0 0 0 0 0 0
        #  5:     1 1 0 0 0 0 0 0
        #  6:     1 1 1 0 0 0 0 0
        #  7:     1 1 1 1 0 0 0 0
        far_past = torch.triu(ones, diagonal=self.cfg["sliding_window"]).T

        # Local (sliding_window) = future OR far-past
        # mask_local
        #     j:  0 1 2 3 4 5 6 7
        #  i
        #  0:     0 1 1 1 1 1 1 1
        #  1:     0 0 1 1 1 1 1 1
        #  2:     0 0 0 1 1 1 1 1
        #  3:     0 0 0 0 1 1 1 1
        #  4:     1 0 0 0 0 1 1 1
        #  5:     1 1 0 0 0 0 1 1
        #  6:     1 1 1 0 0 0 0 1
        #  7:     1 1 1 1 0 0 0 0
        mask_local = mask_global | far_past
        return mask_global, mask_local

    def forward(self, input_ids, targets=None):
        b, seq_len = input_ids.shape
        # Scale token embeddings by sqrt(emb_dim) for training stability
        x = self.tok_emb(input_ids) * (self.cfg["emb_dim"] ** 0.5)
        mask_global, mask_local = self._create_masks(seq_len, x.device)

        for block in self.blocks:
            x = block(
                x,
                mask_global=mask_global,
                mask_local=mask_local,
                cos_global=self.cos_global,
                sin_global=self.sin_global,
                cos_local=self.cos_local,
                sin_local=self.sin_local,
            )

        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            # Crop the running sequence to the model's maximum context length
            ctx_len = self.cfg["context_length"]
            idx_cond = idx if idx.size(1) <= ctx_len else idx[:, -ctx_len:]
            logits, _ = self(idx_cond)  # targets=None by default
            # Take the last time step and apply temperature scaling
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Keep only the top_k logits; mask the rest to -inf before sampling
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
architecture/model_config.py
ADDED
@@ -0,0 +1,16 @@
import os, sys
from os.path import dirname as up

sys.path.append(os.path.abspath(os.path.join(up(__file__), os.pardir)))

import json

MODEL_CONFIG_PATH = 'config/model_config.json'

with open(MODEL_CONFIG_PATH, 'r') as f:
    model_config = json.load(f)

# print(model_config)
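
Note that `json.load` yields plain Python values, while `Gemma3Model` passes `cfg["dtype"]` directly to `nn.Embedding` and `nn.Linear`, which expect a `torch.dtype`. How the repository bridges this is not shown in this commit; below is a minimal sketch of one possible post-processing step, where the string-to-dtype mapping and the helper name are purely illustrative assumptions:

```python
import torch

# Hypothetical helper: convert a string "dtype" entry from the JSON config
# into the torch.dtype object that Gemma3Model expects.
_DTYPE_MAP = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}

def resolve_dtype(config: dict) -> dict:
    dtype = config.get("dtype")
    if isinstance(dtype, str):
        config["dtype"] = _DTYPE_MAP[dtype]
    return config
```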