Trd-Bobo242 committed
Commit 788b21a · verified · 1 Parent(s): 53303c3

Create model.py

Files changed (1)
  1. model.py +292 -0
model.py ADDED
@@ -0,0 +1,292 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig


class SuperConfig(PretrainedConfig):
    """
    Configuration class for the Super model
    """
    model_type = "super"

    def __init__(
        self,
        vocab_size=50257,
        n_embd=768,
        n_layer=12,
        n_head=12,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        scale_attn_weights=True,
        use_cache=True,
        bos_token_id=50256,
        eos_token_id=50256,
        apply_residual_connection_post_layernorm=False,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        **kwargs
    ):
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout


class SuperAttention(nn.Module):
    """
    Multi-head causal self-attention for the Super model
    """
    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.n_head = config.n_head
        self.head_size = self.n_embd // self.n_head
        self.scale = self.head_size ** -0.5

        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x, attention_mask=None):
        B, T, C = x.size()

        # Query, key, value projections
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)

        # Reshape for multi-head attention: (B, n_head, T, head_size)
        q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)

        # Scaled dot-product attention scores
        att = (q @ k.transpose(-2, -1)) * self.scale

        # Causal mask: each position may attend only to itself and earlier positions
        causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
        att = att.masked_fill(~causal_mask, float('-inf'))

        # Additive padding mask prepared by SuperModel
        # (0.0 for tokens to keep, a large negative value at padding positions)
        if attention_mask is not None:
            att = att + attention_mask

        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        # Weighted sum of values
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # Output projection
        y = self.resid_dropout(self.c_proj(y))
        return y


class SuperMLP(nn.Module):
    """
    Feed-forward network for Super model
    """
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class SuperBlock(nn.Module):
    """
    Transformer block for Super model
    """
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = SuperAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = SuperMLP(config)

    def forward(self, x, attention_mask=None):
        x = x + self.attn(self.ln_1(x), attention_mask)
        x = x + self.mlp(self.ln_2(x))
        return x


class SuperModel(PreTrainedModel):
    """
    The Super model implementation
    """
    config_class = SuperConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(1024, config.n_embd)  # learned positional embeddings (up to 1024 positions)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([SuperBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        device = input_ids.device
        b, t = input_ids.size()

        if position_ids is None:
            position_ids = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)

        # Token and position embeddings
        tok_emb = self.wte(input_ids)
        pos_emb = self.wpe(position_ids)
        x = self.drop(tok_emb + pos_emb)

        # Convert the (B, T) padding mask into an additive mask of shape (B, 1, 1, T):
        # 0.0 where tokens should be attended to, a large negative value at padding positions
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = attention_mask.to(dtype=torch.float32)
            attention_mask = (1.0 - attention_mask) * torch.finfo(torch.float32).min

        # Transformer blocks
        for block in self.h:
            x = block(x, attention_mask)

        x = self.ln_f(x)
        return x


class SuperForCausalLM(PreTrainedModel):
    """
    Super model for causal language modeling
    """
    config_class = SuperConfig

    def __init__(self, config):
        super().__init__(config)
        self.transformer = SuperModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Tie the LM head to the token embedding matrix
        self.lm_head.weight = self.transformer.wte.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, input_ids, attention_mask=None, labels=None):
        hidden_states = self.transformer(input_ids, attention_mask)
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens at position < n predict the token at position n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return {
            'loss': loss,
            'logits': lm_logits,
            'hidden_states': hidden_states
        }

    def generate(self, input_ids, max_length=100, temperature=1.0, top_k=50, top_p=0.95):
        """
        Simple sampling-based generation with temperature, top-k and top-p (nucleus) filtering
        """
        for _ in range(max_length - input_ids.size(1)):
            with torch.no_grad():
                outputs = self.forward(input_ids)
            next_token_logits = outputs['logits'][:, -1, :] / temperature

            # Apply top-k filtering: keep only the top_k highest-scoring tokens
            if top_k > 0:
                indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                next_token_logits[indices_to_remove] = -float('Inf')

            # Apply top-p (nucleus) filtering: keep the smallest set of tokens
            # whose cumulative probability exceeds top_p
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token above the threshold is kept
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                next_token_logits[indices_to_remove] = -float('Inf')

            # Sample the next token and append it to the running sequence
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)

        return input_ids


# Example usage and model initialization
def create_super_model(vocab_size=50257, n_embd=768, n_layer=12, n_head=12):
    """
    Helper function to create a Super model instance
    """
    config = SuperConfig(
        vocab_size=vocab_size,
        n_embd=n_embd,
        n_layer=n_layer,
        n_head=n_head
    )
    return SuperForCausalLM(config)


if __name__ == "__main__":
    # Example usage
    model = create_super_model()
    print(f"Super model created with {sum(p.numel() for p in model.parameters()):,} parameters")

    # Test forward pass
    input_ids = torch.tensor([[1, 2, 3, 4, 5]])
    outputs = model(input_ids)
    print(f"Output logits shape: {outputs['logits'].shape}")
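
For reference, a minimal sketch of how the committed classes might be wired into the transformers Auto classes and exercised for sampling. It is not part of the commit: it assumes model.py is importable as `model`, that the GPT-2 tokenizer (which matches the default vocab_size of 50257 and BOS/EOS id 50256) is cached locally or downloadable, and that the "super" model type is not already registered. The prompt string is illustrative, and since the weights are untrained the sampled continuation will be random tokens.

# usage_sketch.py -- illustrative only, not part of this commit
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from model import SuperConfig, SuperForCausalLM, create_super_model

# Register the custom classes so the Auto classes can resolve model_type "super"
AutoConfig.register("super", SuperConfig)
AutoModelForCausalLM.register(SuperConfig, SuperForCausalLM)

model = create_super_model()
model.eval()  # disable dropout for generation

# GPT-2's tokenizer matches the default vocabulary and special-token ids above
tokenizer = AutoTokenizer.from_pretrained("gpt2")
input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids

# Sample a short continuation with the file's simple generate() method;
# the model is randomly initialized, so the output is not meaningful text.
output_ids = model.generate(input_ids, max_length=30, temperature=1.0, top_k=50, top_p=0.95)
print(tokenizer.decode(output_ids[0]))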