disham993 committed on
Commit 0009ef5 · verified · 1 Parent(s): 9a375fc

Architecture code included.

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ architecture/architecture.png filter=lfs diff=lfs merge=lfs -text
architecture/README.md ADDED
@@ -0,0 +1,82 @@
+ # Architecture Module
+
+ This module contains the main Gemma3 model implementation and configuration management.
+
+ ## Files
+
+ ### `gemma3.py`
+ The core `Gemma3Model` class implementation, featuring:
+
+ - **Token Embeddings**: Scaled embedding layer with a vocabulary size of 50,257
+ - **Transformer Blocks**: 18 layers with mixed attention patterns (sliding window and full attention)
+ - **Dual RoPE**: Two sets of rotary position embeddings for local and global context
+ - **Attention Masks**: Dynamic generation of causal and sliding window masks
+ - **Output Head**: Linear projection to vocabulary size for next-token prediction
+ - **Generation Method**: Temperature-controlled sampling with top-k filtering
+
+ Key components:
+ - `__init__`: Initializes model layers, embeddings, and precomputes RoPE parameters
+ - `_create_masks`: Generates causal and sliding window attention masks
+ - `forward`: Main forward pass with optional loss computation
+ - `generate`: Autoregressive text generation with temperature and top-k sampling
+
+ ### `model_config.py`
+ Configuration loader that reads model hyperparameters from `config/model_config.json` (a sketch of the expected keys follows below).
+
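+ The JSON file itself is not part of this commit. Based on the keys that `gemma3.py` reads and the values quoted in this README, the loaded `model_config` dictionary presumably looks roughly like the sketch below; values marked as placeholders are illustrative, only the keys read in `gemma3.py` are shown (`TransformerBlock` presumably needs more, e.g. head counts), and how `dtype` ends up as a `torch.dtype` after JSON loading is an assumption.
+
+ ```python
+ import torch
+
+ # Illustrative sketch only -- not the actual contents of config/model_config.json.
+ model_config = {
+     "vocab_size": 50_257,       # vocabulary size stated in this README
+     "n_layers": 18,             # 18 transformer blocks
+     "sliding_window": 512,      # 512-token local attention window
+     "rope_local_base": 10_000,  # local RoPE theta base
+     "rope_base": 1_000_000,     # global RoPE theta base
+     "emb_dim": 640,             # placeholder
+     "head_dim": 64,             # placeholder
+     "context_length": 1_024,    # placeholder
+     "layer_types": (["sliding"] * 5 + ["full"]) * 3,  # 18 entries; labels are placeholders
+     "dtype": torch.bfloat16,    # gemma3.py passes this straight to nn.Embedding / nn.Linear
+ }
+ ```
+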
+ ### `__init__.py`
+ Module initialization that exports:
+ - `model_config`: Dictionary containing all model hyperparameters
+ - `Gemma3Model`: The main model class
+
+ ## Model Architecture Details
+
+ ### Layer Configuration
+ The model uses a strategic mix of attention types across its 18 layers (see the sketch after these lists):
+ - **Layers 1-5**: Sliding window attention (512-token window)
+ - **Layer 6**: Full attention (checkpoint layer)
+ - **Layers 7-11**: Sliding window attention
+ - **Layer 12**: Full attention (checkpoint layer)
+ - **Layers 13-17**: Sliding window attention
+ - **Layer 18**: Full attention (final layer)
+
+ This pattern allows the model to:
+ - Efficiently process local context with sliding windows
+ - Capture long-range dependencies at strategic checkpoints
+ - Balance computational efficiency with modeling capability
+
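+ As a rough sketch, the pattern above could be expressed as an 18-entry `layer_types` list like the one below; the actual string labels consumed by `TransformerBlock` are not shown in this commit, so `"sliding"` and `"full"` are placeholders.
+
+ ```python
+ # Hypothetical construction of the 18-layer pattern: five sliding-window
+ # layers followed by one full-attention layer, repeated three times.
+ layer_types = ["full" if (i + 1) % 6 == 0 else "sliding" for i in range(18)]
+
+ assert len(layer_types) == 18
+ # Layers 6, 12 and 18 (1-indexed) are the full-attention checkpoints.
+ assert [i + 1 for i, t in enumerate(layer_types) if t == "full"] == [6, 12, 18]
+ ```
+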
+ ### Embedding and Normalization
+ - **Embedding Scaling**: Input embeddings are scaled by √(embedding_dim) for training stability
+ - **Final Normalization**: RMS normalization before the output projection
+ - **Untied Weights**: The output projection has its own weight matrix; it is not tied to the input embeddings
+
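+ A minimal sketch of these two pieces is shown below; the RMSNorm here is illustrative (the real implementation lives in `block/rms_norm.py`, which is not part of this commit), and `emb_dim = 640` is a placeholder value.
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ class GemmaStyleRMSNorm(nn.Module):
+     """RMS normalization with the (1 + weight) scaling mentioned under Design Decisions."""
+     def __init__(self, dim, eps=1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.zeros(dim))  # zero-initialized, applied as (1 + weight)
+
+     def forward(self, x):
+         # Divide by the root-mean-square over the feature dimension, then scale.
+         rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+         return x * rms * (1.0 + self.weight)
+
+ emb_dim = 640  # placeholder; the real value comes from config/model_config.json
+ tok_emb = nn.Embedding(50_257, emb_dim)
+ ids = torch.randint(0, 50_257, (1, 8))
+ x = tok_emb(ids) * (emb_dim ** 0.5)   # embedding scaling by sqrt(emb_dim)
+ x = GemmaStyleRMSNorm(emb_dim)(x)     # final RMS normalization before the output head
+ ```
+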
+ ### Position Encoding
+ The model uses dual RoPE (Rotary Position Embeddings):
+ - **Local RoPE**: θ_base = 10,000 for sliding window attention
+ - **Global RoPE**: θ_base = 1,000,000 for full attention layers
+
+ This dual approach allows different attention patterns to use position encodings optimized for their respective context ranges. A rough sketch of the frequency computation follows.
+
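+ The sketch below mirrors the two `compute_rope_params` calls in `gemma3.py`; the actual implementation lives in `block/rope.py`, which is not part of this commit, and the function name, `head_dim = 64`, and `context_length = 1_024` here are placeholders.
+
+ ```python
+ import torch
+
+ def rope_angles(head_dim, theta_base, context_length):
+     # Standard RoPE schedule: one inverse frequency per pair of dimensions,
+     # then one angle per (position, frequency), returned as cos/sin tables.
+     inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim))
+     positions = torch.arange(context_length).float()
+     angles = torch.outer(positions, inv_freq)     # (context_length, head_dim // 2)
+     angles = torch.cat([angles, angles], dim=-1)  # (context_length, head_dim)
+     return torch.cos(angles), torch.sin(angles)
+
+ # Local attention uses the smaller base; full attention uses the larger one.
+ cos_local, sin_local = rope_angles(head_dim=64, theta_base=10_000, context_length=1_024)
+ cos_global, sin_global = rope_angles(head_dim=64, theta_base=1_000_000, context_length=1_024)
+ ```
+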
+ ## Usage Example
+
+ ```python
+ from architecture import Gemma3Model, model_config
+ import torch
+
+ # Initialize model
+ model = Gemma3Model(model_config)
+
+ # Forward pass
+ input_ids = torch.randint(0, 50257, (2, 128))  # batch_size=2, seq_len=128
+ logits, loss = model(input_ids, targets=None)
+
+ # Generation
+ prompt = torch.randint(0, 50257, (1, 10))  # Single prompt
+ generated = model.generate(prompt, max_new_tokens=50, temperature=0.8, top_k=40)
+ ```
+
+ ## Design Decisions
+
+ 1. **Mixed Attention**: Combines the efficiency of sliding windows with the modeling power of full attention
+ 2. **Separate RoPE Bases**: Optimizes position encoding for different attention ranges
+ 3. **Grouped Query Attention**: Reduces KV cache memory while maintaining performance (see the sketch after this list)
+ 4. **Gemma3-style Normalization**: Uses (1 + weight) scaling for better training dynamics
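+
+ Grouped Query Attention itself is implemented inside `block/transformer.py`, which is not included in this commit; the sketch below only illustrates the general idea, and every shape and head count in it is a placeholder.
+
+ ```python
+ import torch
+
+ # GQA in a nutshell: many query heads share a smaller set of key/value heads,
+ # so the K/V tensors (and the KV cache) are n_kv_heads wide instead of n_heads wide.
+ b, seq_len, head_dim = 2, 16, 64
+ n_heads, n_kv_heads = 8, 2
+ group_size = n_heads // n_kv_heads
+
+ q = torch.randn(b, n_heads, seq_len, head_dim)
+ k = torch.randn(b, n_kv_heads, seq_len, head_dim)
+ v = torch.randn(b, n_kv_heads, seq_len, head_dim)
+
+ # Each group of group_size query heads attends to the same key/value head.
+ k = k.repeat_interleave(group_size, dim=1)  # (b, n_heads, seq_len, head_dim)
+ v = v.repeat_interleave(group_size, dim=1)
+
+ # Attention mask omitted for brevity.
+ attn = torch.softmax(q @ k.transpose(-2, -1) / head_dim ** 0.5, dim=-1) @ v
+ print(attn.shape)  # torch.Size([2, 8, 16, 64])
+ ```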
architecture/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .gemma3 import Gemma3Model
+ from .model_config import model_config
architecture/architecture.png ADDED

Git LFS Details

  • SHA256: 2139dbfc89d90f77abeb28e454bb03b000becf39fa63bf330dec8bca0010a3cf
  • Pointer size: 131 Bytes
  • Size of remote file: 260 kB
architecture/gemma3.py ADDED
@@ -0,0 +1,130 @@
+ import os, sys
+ from os.path import dirname as up
+
+ sys.path.append(os.path.abspath(os.path.join(up(__file__), os.pardir)))
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from block.transformer import TransformerBlock
+ from block.rms_norm import RMSNorm
+ from block.rope import compute_rope_params
+
+ class Gemma3Model(nn.Module):
+     def __init__(self, cfg):
+         super().__init__()
+         assert cfg["layer_types"] is not None and len(cfg["layer_types"]) == cfg["n_layers"]
+
+         # Main model parameters
+         self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
+
+         self.blocks = nn.ModuleList([
+             TransformerBlock(cfg, attn_type) for attn_type in cfg["layer_types"]
+         ])
+
+         self.final_norm = RMSNorm(cfg["emb_dim"], eps=1e-6)
+         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
+         self.cfg = cfg
+
+         # Reusable utilities
+         cos_local, sin_local = compute_rope_params(
+             head_dim=cfg["head_dim"],
+             theta_base=cfg["rope_local_base"],
+             context_length=cfg["context_length"],
+             dtype=torch.float32,
+         )
+         cos_global, sin_global = compute_rope_params(
+             head_dim=cfg["head_dim"],
+             theta_base=cfg["rope_base"],
+             context_length=cfg["context_length"],
+             dtype=torch.float32,
+         )
+         self.register_buffer("cos_local", cos_local, persistent=False)
+         self.register_buffer("sin_local", sin_local, persistent=False)
+         self.register_buffer("cos_global", cos_global, persistent=False)
+         self.register_buffer("sin_global", sin_global, persistent=False)
+
+     def _create_masks(self, seq_len, device):
+         ones = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)
+
+         # mask_global (future is masked: j > i)
+         #      j:  0 1 2 3 4 5 6 7
+         #  i
+         #      0:  0 1 1 1 1 1 1 1
+         #      1:  0 0 1 1 1 1 1 1
+         #      2:  0 0 0 1 1 1 1 1
+         #      3:  0 0 0 0 1 1 1 1
+         #      4:  0 0 0 0 0 1 1 1
+         #      5:  0 0 0 0 0 0 1 1
+         #      6:  0 0 0 0 0 0 0 1
+         #      7:  0 0 0 0 0 0 0 0
+         mask_global = torch.triu(ones, diagonal=1)
+
+         # far_past (too far back is masked: i - j >= sliding_window)
+         # where sliding_window = 4
+         #      j:  0 1 2 3 4 5 6 7
+         #  i
+         #      0:  0 0 0 0 0 0 0 0
+         #      1:  0 0 0 0 0 0 0 0
+         #      2:  0 0 0 0 0 0 0 0
+         #      3:  0 0 0 0 0 0 0 0
+         #      4:  1 0 0 0 0 0 0 0
+         #      5:  1 1 0 0 0 0 0 0
+         #      6:  1 1 1 0 0 0 0 0
+         #      7:  1 1 1 1 0 0 0 0
+         far_past = torch.triu(ones, diagonal=self.cfg["sliding_window"]).T
+
+         # Local (sliding_window) = future OR far-past
+         # mask_local
+         #      j:  0 1 2 3 4 5 6 7
+         #  i
+         #      0:  0 1 1 1 1 1 1 1
+         #      1:  0 0 1 1 1 1 1 1
+         #      2:  0 0 0 1 1 1 1 1
+         #      3:  0 0 0 0 1 1 1 1
+         #      4:  1 0 0 0 0 1 1 1
+         #      5:  1 1 0 0 0 0 1 1
+         #      6:  1 1 1 0 0 0 0 1
+         #      7:  1 1 1 1 0 0 0 0
+         mask_local = mask_global | far_past
+         return mask_global, mask_local
+
+     def forward(self, input_ids, targets=None):
+         b, seq_len = input_ids.shape
+         x = self.tok_emb(input_ids) * (self.cfg["emb_dim"] ** 0.5)
+         mask_global, mask_local = self._create_masks(seq_len, x.device)
+
+         for block in self.blocks:
+             x = block(
+                 x,
+                 mask_global=mask_global,
+                 mask_local=mask_local,
+                 cos_global=self.cos_global,
+                 sin_global=self.sin_global,
+                 cos_local=self.cos_local,
+                 sin_local=self.sin_local,
+             )
+
+         x = self.final_norm(x)
+         logits = self.out_head(x.to(self.cfg["dtype"]))
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
+         return logits, loss
+
+     @torch.no_grad()
+     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+         for _ in range(max_new_tokens):
+             ctx_len = self.cfg["context_length"]
+             idx_cond = idx if idx.size(1) <= ctx_len else idx[:, -ctx_len:]
+             logits, _ = self(idx_cond)  # targets=None by default
+             logits = logits[:, -1, :] / temperature
+             if top_k is not None:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = float("-inf")
+             probs = F.softmax(logits, dim=-1)
+             idx_next = torch.multinomial(probs, num_samples=1)
+             idx = torch.cat((idx, idx_next), dim=1)
+         return idx
+
architecture/model_config.py ADDED
@@ -0,0 +1,16 @@
+ import os, sys
+ from os.path import dirname as up
+
+ sys.path.append(os.path.abspath(os.path.join(up(__file__), os.pardir)))
+
+ import json
+
+ # Note: this path is resolved relative to the current working directory.
+ MODEL_CONFIG_PATH = 'config/model_config.json'
+
+ with open(MODEL_CONFIG_PATH, 'r') as f:
+     model_config = json.load(f)
+
+ # print(model_config)
+
+
+