appledora committed on
Commit
048ac33
·
verified ·
1 Parent(s): 87d78a9

Upload modeling_recast_llama.py with huggingface_hub

Files changed (1)
  1. modeling_recast_llama.py +875 -0
modeling_recast_llama.py ADDED
@@ -0,0 +1,875 @@
# filename: modeling_recast_llama.py
from .configuration_recast_llama import RECAST8b_llama
from transformers import PreTrainedModel
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union, List
from transformers import AutoConfig
from transformers.utils import logging
from transformers.cache_utils import Cache, StaticCache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaRotaryEmbedding,
    LlamaRMSNorm,
    apply_rotary_pos_emb,
    repeat_kv,
)
from transformers.modeling_outputs import BaseModelOutputWithPast

logger = logging.get_logger(__name__)


class MLPTemplateBank(nn.Module):
    def __init__(self, config, coef_rows, coef_columns):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.coef_shape = (coef_rows, coef_columns)

        assert coef_columns is not None, "coef_columns must not be None"

        # Ensure divisibility for proper reshaping
        assert (
            self.hidden_size * self.intermediate_size
        ) % coef_rows == 0, f"hidden_size * intermediate_size ({self.hidden_size * self.intermediate_size}) must be divisible by coef_rows ({coef_rows})"

        template_size = self.hidden_size * self.intermediate_size // coef_rows

        self.up_templates = nn.Parameter(torch.randn(coef_columns, template_size))
        self.gate_templates = nn.Parameter(torch.randn(coef_columns, template_size))

        # Better initialization than raw randn
        nn.init.xavier_uniform_(self.up_templates)
        nn.init.xavier_uniform_(self.gate_templates)

    def forward(self, up_coeffs, gate_coeffs):
        # Compute chunked weights from the shared templates
        up_chunks = torch.matmul(up_coeffs, self.up_templates)
        gate_chunks = torch.matmul(gate_coeffs, self.gate_templates)

        # Reshape to final weight matrices
        up_weights = up_chunks.reshape(self.intermediate_size, self.hidden_size)
        gate_weights = gate_chunks.reshape(self.intermediate_size, self.hidden_size)

        return up_weights, gate_weights

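# Shape sketch for MLPTemplateBank (illustrative only; the Llama-3.1-8B sizes and the
# coefficient dimensions below are assumptions, not values fixed by this file):
#   hidden_size=4096, intermediate_size=14336, coef_rows=64, coef_columns=256
#   -> template_size = 4096 * 14336 // 64 = 917,504
#   up_coeffs (64, 256) @ up_templates (256, 917504) -> up_chunks (64, 917504),
#   reshaped to up_weights (14336, 4096), i.e. a full dense up_proj weight matrix.
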

class SharedLlamaMLP(nn.Module):
    def __init__(self, config, bank):
        super().__init__()
        self.config = config
        self.bank = bank
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.down_proj = nn.Linear(
            config.intermediate_size, config.hidden_size, bias=False
        )

        # Per-layer coefficients over the shared template bank
        self.up_coefficients = nn.Parameter(torch.randn(bank.coef_shape))
        self.gate_coefficients = nn.Parameter(torch.randn(bank.coef_shape))

        # Orthogonal initialization keeps the coefficient rows well conditioned
        nn.init.orthogonal_(self.up_coefficients)
        nn.init.orthogonal_(self.gate_coefficients)

        if config.mlp_bias:
            self.gate_bias = nn.Parameter(torch.zeros(self.intermediate_size))
            self.up_bias = nn.Parameter(torch.zeros(self.intermediate_size))
        else:
            self.register_parameter("gate_bias", None)
            self.register_parameter("up_bias", None)

        self.act_fn = F.silu

    def forward(self, x):
        # Generate this layer's projection weights from the shared template bank
        up_weights, gate_weights = self.bank(
            self.up_coefficients, self.gate_coefficients
        )

        # SwiGLU: SiLU(x @ gate^T) * (x @ up^T), followed by the down projection
        hidden_states = self.act_fn(
            F.linear(x, gate_weights, self.gate_bias)
        ) * F.linear(x, up_weights, self.up_bias)
        output = self.down_proj(hidden_states)

        return output

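# Parameter-count sketch (illustrative, using the assumed sizes above): a dense Llama MLP
# stores up_proj and gate_proj with 2 * 14336 * 4096 ≈ 117M weights per layer, whereas each
# SharedLlamaMLP stores only two (64, 256) coefficient matrices plus the unshared down_proj;
# the up/gate capacity lives in the group's MLPTemplateBank and is reused by every layer
# assigned to that group.
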

class AttTemplateBank(nn.Module):
    def __init__(self, config, coef_rows, coef_columns):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_key_value_heads = getattr(
            config, "num_key_value_heads", config.num_attention_heads
        )
        self.kv_dim = self.num_key_value_heads * self.head_dim
        self.coef_shape = (coef_rows, coef_columns)

        # Ensure divisibility
        assert (
            self.hidden_size * self.hidden_size
        ) % coef_rows == 0, "Q projection size must be divisible by coef_rows"
        assert (
            self.kv_dim * self.hidden_size
        ) % coef_rows == 0, "K/V projection size must be divisible by coef_rows"

        # Create templates for Q, K, V
        self.q_templates = nn.Parameter(
            torch.randn(coef_columns, self.hidden_size * self.hidden_size // coef_rows)
        )
        self.k_templates = nn.Parameter(
            torch.randn(coef_columns, self.kv_dim * self.hidden_size // coef_rows)
        )
        self.v_templates = nn.Parameter(
            torch.randn(coef_columns, self.kv_dim * self.hidden_size // coef_rows)
        )

        # Initialize templates
        nn.init.xavier_uniform_(self.q_templates)
        nn.init.xavier_uniform_(self.k_templates)
        nn.init.xavier_uniform_(self.v_templates)

    def forward(self, q_coeffs, k_coeffs, v_coeffs):
        # Compute chunked weights from the shared templates
        q_chunks = torch.matmul(q_coeffs, self.q_templates)
        k_chunks = torch.matmul(k_coeffs, self.k_templates)
        v_chunks = torch.matmul(v_coeffs, self.v_templates)

        # Reshape to final weight matrices
        q_weights = q_chunks.reshape(self.hidden_size, self.hidden_size)
        k_weights = k_chunks.reshape(self.kv_dim, self.hidden_size)
        v_weights = v_chunks.reshape(self.kv_dim, self.hidden_size)

        return q_weights, k_weights, v_weights

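# Shape sketch for AttTemplateBank (illustrative; assumes hidden_size=4096,
# num_attention_heads=32, num_key_value_heads=8, so head_dim=128 and kv_dim=1024):
#   q_templates: (coef_columns, 4096 * 4096 // coef_rows)
#   k_templates / v_templates: (coef_columns, 1024 * 4096 // coef_rows)
#   forward() returns q_weights (4096, 4096) and k/v_weights (1024, 4096), matching the
#   dense q_proj / k_proj / v_proj shapes of a grouped-query Llama attention layer.
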

class SharedLlamaAttention(nn.Module):
    def __init__(
        self,
        config,
        layer_idx: Optional[int] = None,
        bank: Optional[AttTemplateBank] = None,
    ):
        super().__init__()
        self.config = config
        self.bank = bank
        self.layer_idx = layer_idx
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = getattr(
            config, "num_key_value_heads", config.num_attention_heads
        )
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = getattr(config, "rope_theta", 10000.0)
        self.is_causal = True

        self.o_proj = nn.Linear(
            self.hidden_size,
            self.hidden_size,
            bias=getattr(config, "attention_bias", False),
        )
        self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

        # Per-layer coefficients over the shared template bank
        self.q_coefficients = nn.Parameter(torch.randn(bank.coef_shape))
        self.k_coefficients = nn.Parameter(torch.randn(bank.coef_shape))
        self.v_coefficients = nn.Parameter(torch.randn(bank.coef_shape))

        # Orthogonal initialization for the coefficients
        nn.init.orthogonal_(self.q_coefficients)
        nn.init.orthogonal_(self.k_coefficients)
        nn.init.orthogonal_(self.v_coefficients)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        past_key_value=None,
        cache_position=None,
        position_embeddings=None,
        position_ids=None,
        output_attentions=False,
        use_cache=False,
        **kwargs,
    ):
        bsz, q_len, _ = hidden_states.size()

        # Generate this layer's projection weights from the shared template bank
        q_weights, k_weights, v_weights = self.bank(
            self.q_coefficients, self.k_coefficients, self.v_coefficients
        )

        # Apply projections
        query_states = F.linear(hidden_states, q_weights)
        key_states = F.linear(hidden_states, k_weights)
        value_states = F.linear(hidden_states, v_weights)

        # Reshape for multi-head attention
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        # Apply rotary embeddings
        if position_embeddings is None:
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        # Handle past key values
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # Repeat key/value heads for grouped-query attention
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Compute attention
        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # Apply softmax and dropout
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

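# Attention shape sketch (illustrative; assumes batch=2, q_len=16, 32 heads, 8 KV heads,
# head_dim=128): query_states is (2, 32, 16, 128), key/value_states start as (2, 8, 16, 128)
# and are expanded by repeat_kv to (2, 32, 16, 128) before the (2, 32, 16, 16) score matrix
# is formed; only the eager attention path is implemented here.
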

def fixed_cross_entropy(
    source,
    target,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    **kwargs,
):
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(
        source, target, ignore_index=ignore_index, reduction=reduction
    )
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss

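# Reduction sketch: with num_items_in_batch=None the call is an ordinary mean over
# non-ignored tokens; with num_items_in_batch=N the summed loss is divided by N instead,
# so the caller controls the normalizer (useful, for example, when a batch is split
# across gradient-accumulation steps).
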

class RECAST8b_llamaModel(PreTrainedModel):
    config_class = RECAST8b_llama
    base_model_prefix = "llama"
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )

        original_config = AutoConfig.from_pretrained(
            "meta-llama/Llama-3.1-8b", trust_remote_code=True
        )
        self.rotary_emb = LlamaRotaryEmbedding(
            config=original_config,
        )

        # Create template banks first
        self.mlp_banks = []
        self.attn_banks = []
        layers_per_group = config.num_hidden_layers // config.num_groups
        # Explicitly calculate coef_width if not provided in config
        if hasattr(config, "coef_width") and config.coef_width is not None:
            coef_width = config.coef_width
        else:
            coef_width = config.coef_height * layers_per_group
            config.coef_width = coef_width
        print(
            f"Model config: num_groups={config.num_groups}, layers_per_group={layers_per_group}"
        )
        print(f"Coefficient shape: ({config.coef_height}, {config.coef_width})")
        mlp_banks = nn.ModuleList(
            [
                MLPTemplateBank(
                    config=config, coef_rows=config.coef_height, coef_columns=coef_width
                )
                for _ in range(config.num_groups)
            ]
        )

        attn_banks = nn.ModuleList(
            [
                AttTemplateBank(
                    config=config, coef_rows=config.coef_height, coef_columns=coef_width
                )
                for _ in range(config.num_groups)
            ]
        )
        self.mlp_banks = mlp_banks
        self.attn_banks = attn_banks
        # Create layers using LlamaDecoderLayer, then swap in the shared MLP and attention modules
        self.layers = nn.ModuleList()
        for layer_idx in range(config.num_hidden_layers):
            # Create standard LlamaDecoderLayer
            decoder_layer = LlamaDecoderLayer(config, layer_idx)

            # Replace its MLP and attention with the shared variants from this layer's group
            group_idx = layer_idx // layers_per_group
            decoder_layer.mlp = SharedLlamaMLP(config, self.mlp_banks[group_idx])
            decoder_layer.self_attn = SharedLlamaAttention(
                config, layer_idx, self.attn_banks[group_idx]
            )

            self.layers.append(decoder_layer)

        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Set up cache position if not provided
        if cache_position is None:
            past_seen_tokens = (
                0
                if past_key_values is None
                else (
                    past_key_values.get_seq_length()
                    if isinstance(past_key_values, Cache)
                    else past_key_values[0][0].size(-2) if past_key_values else 0
                )
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        # Set up position IDs if not provided
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Get updated causal mask
        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )
        hidden_states = inputs_embeds

        # Create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # Initialize outputs
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        # Process through layers
        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Final layer norm
        hidden_states = self.norm(hidden_states)

        # Add last hidden state
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        if isinstance(
            pretrained_model_name_or_path, str
        ) and pretrained_model_name_or_path.endswith(".pt"):
            print("Loading from local checkpoint")
            # Load from local checkpoint
            config = kwargs.get("config", None)
            if config is None:
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=True
                )

            model = cls(config)
            checkpoint = torch.load(pretrained_model_name_or_path, map_location="cpu")
            state_dict = checkpoint["model_state_dict"]
            logger.info(
                f"Loaded checkpoint from epoch {checkpoint.get('epoch')} with loss {checkpoint.get('loss')}"
            )

            missing_keys, unexpected_keys = model.load_state_dict(
                state_dict, strict=False
            )

            if len(missing_keys) > 0:
                logger.warning(f"Missing keys: {missing_keys}")
            if len(unexpected_keys) > 0:
                logger.warning(f"Unexpected keys: {unexpected_keys}")

            return model
        else:
            print("Loading from hub")
            # Load from hub using parent's from_pretrained
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we rely on its `is_causal` argument instead of its `attn_mask` argument,
        # in order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as
        # SDPA will fail to infer the attention mask.
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output_attentions is True, the sdpa implementation's forward falls back to the eager one
        if (
            self.config._attn_implementation == "sdpa"
            and not using_static_cache
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention's memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask, min_dtype
            )

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=device,
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(
                target_length, device=device
            ) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = (
                    causal_mask.clone()
                )  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = (
                    causal_mask[:, :, :, :mask_length]
                    + attention_mask[:, None, None, :]
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)

        return causal_mask

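# Mask sketch for _prepare_4d_causal_attention_mask_with_cache_position (illustrative):
# with sequence_length=3, target_length=5 and cache_position=[2, 3, 4] (two cached tokens),
# the three rows allow key positions 0..2, 0..3, and 0..4 respectively; disallowed slots
# hold torch.finfo(dtype).min and the result is expanded to (batch_size, 1, 3, 5).
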

class RECAST8b_LlamaForCausalLM(PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    config_class = RECAST8b_llama
    base_model_prefix = "llama"
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.model = RECAST8b_llamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def loss_function(
        self,
        logits,
        labels,
        vocab_size: int,
        num_items_in_batch: Optional[int] = None,
        ignore_index: int = -100,
        **kwargs,
    ):
        # Upcast to float to avoid potential precision issues when computing the loss
        logits = logits.float()
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        shift_logits = shift_logits.view(-1, vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = fixed_cross_entropy(
            shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
        )
        return loss

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should be in
                `[0, ..., config.vocab_size]` or -100 (masked tokens).
            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate all logits.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute the logits that are actually needed
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            # Use the batch size as num_items_in_batch for loss normalization
            num_items_in_batch = (
                input_ids.size(0) if input_ids is not None else inputs_embeds.size(0)
            )
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                num_items_in_batch=num_items_in_batch,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        if isinstance(
            pretrained_model_name_or_path, str
        ) and pretrained_model_name_or_path.endswith(".pt"):
            print("Loading from local checkpoint")
            config = kwargs.get("config", None)
            if config is None:
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=True
                )
            # The .pt file is expected to contain the full pickled model (not a state dict)
            model = torch.load(pretrained_model_name_or_path, map_location="cpu")
            return model
        else:
            print("Loading from hub")
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
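
# Usage sketch (illustrative; "<user>/<recast-llama-repo>" is a placeholder repo id, and it
# assumes the Hub repo ships configuration_recast_llama.py and the model weights alongside
# this file):
#
#   from transformers import AutoTokenizer
#
#   repo_id = "<user>/<recast-llama-repo>"
#   model = RECAST8b_LlamaForCausalLM.from_pretrained(repo_id)
#   tokenizer = AutoTokenizer.from_pretrained(repo_id)
#   inputs = tokenizer("Hello", return_tensors="pt")
#   print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))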