paulpak58 committed
Commit c27bc60 (verified) · Parent(s): 6b79710

Bump transformers source v4.54.0.dev0

Files changed (3):
  1. config.json (+1, -5)
  2. modeling_lfm2.py (+0, -924)
  3. requirements.txt (+0, -2)
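
This commit drops the repo-local `modeling_lfm2.py` (and the `auto_map` entries that pointed to it), so the checkpoint is expected to load through the LFM2 implementation bundled with a transformers source build (v4.54.0.dev0, per the commit message). A minimal loading sketch under that assumption; the repo id below is a placeholder, not taken from this commit:

```python
# Minimal sketch: load the checkpoint without trust_remote_code, assuming the installed
# transformers source build already ships LFM2 natively.
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "LiquidAI/LFM2-1.2B"  # placeholder repo id; substitute the actual model repo
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id)

inputs = tokenizer("The capital of France is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```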
config.json CHANGED
@@ -42,9 +42,5 @@
   "transformers_version": "4.53.0.dev0",
   "use_cache": true,
   "use_pos_enc": true,
-  "vocab_size": 65536,
-  "auto_map": {
-    "AutoConfig": "modeling_lfm2.LFM2Config",
-    "AutoModelForCausalLM": "modeling_lfm2.LFM2ForCausalLM"
-  }
+  "vocab_size": 65536
 }
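
With `auto_map` gone, `AutoConfig`/`AutoModelForCausalLM` no longer resolve to repo-local classes, so `trust_remote_code=True` should not be needed. A hedged sketch of how that resolution could be sanity-checked (the repo id is a placeholder, and the exact in-library class name is an assumption):

```python
# Sketch: confirm AutoConfig now resolves to an in-library class rather than the
# deleted modeling_lfm2.LFM2Config. Repo id is a placeholder.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("LiquidAI/LFM2-1.2B")
print(type(config).__module__)  # expected to live under transformers.models.*, not remote code
print(config.vocab_size)        # 65536, as kept in config.json above
```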
modeling_lfm2.py DELETED
@@ -1,924 +0,0 @@
-from typing import Any, Callable, ClassVar, Optional, Union
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from transformers.cache_utils import DynamicCache
-from transformers.configuration_utils import PretrainedConfig
-from transformers.generation import GenerationMixin
-from transformers.masking_utils import create_causal_mask
-from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from transformers.modeling_layers import GradientCheckpointingLayer
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
-from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from transformers.processing_utils import Unpack
-from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging
-from transformers.utils.import_utils import is_causal_conv1d_available
-
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_fn, causal_conv1d_update = None, None
-
-
-kernel_modules = (causal_conv1d_fn, causal_conv1d_update)
-is_fast_path_available = all(kernel_modules)
-
-logger = logging.get_logger(__name__)
-
-
-# ========================================================
-# Config Class (to be removed) once integrated into
-# `transformers`. For now, allows for dynamic importing.
-# ========================================================s
-# from .configuration_lfm2 import LFM2Config
-
-
-class LFM2Config(PretrainedConfig):
-    model_type = "lfm2"
-    keys_to_ignore_at_inference: ClassVar = ["past_key_values"]
-
-    def __init__(
-        self,
-        vocab_size: int = 65536,
-        hidden_size: int = 2560,
-        num_hidden_layers: int = 32,
-        pad_token_id: int = 0,
-        bos_token_id: int = 1,
-        eos_token_id: int = 2,
-        tie_embedding: bool = True,
-        theta: float = 1000000.0,
-        max_position_embeddings: int = 128_000,
-        use_cache: bool = True,
-        norm_eps: float = 0.00001,
-        initializer_range: float = 0.02,
-        num_attention_heads: int = 32,
-        num_key_value_heads: int = 8,
-        conv_bias: bool = False,
-        conv_dim: int = 2560,
-        conv_L_cache: int = 3,
-        block_dim: int = 2560,
-        block_ff_dim: int = 12288,
-        block_multiple_of: int = 256,
-        block_ffn_dim_multiplier: float = 1.0,
-        block_auto_adjust_ff_dim: bool = True,
-        full_attn_idxs: Optional[list[int]] = None,
-        **kwargs,
-    ):
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.rope_theta = theta
-        self.max_position_embeddings = max_position_embeddings
-        self.use_cache = use_cache
-        self.norm_eps = norm_eps
-        self.initializer_range = initializer_range
-
-        # attn operator config
-        self.num_attention_heads = num_attention_heads
-        self.num_key_value_heads = num_key_value_heads
-        self.full_attn_idxs = full_attn_idxs
-
-        # custom operator config
-        self.conv_bias = conv_bias
-        self.conv_dim = conv_dim
-        self.conv_L_cache = conv_L_cache
-
-        # block config
-        self.block_dim = block_dim
-        self.block_ff_dim = block_ff_dim
-        self.block_multiple_of = block_multiple_of
-        self.block_ffn_dim_multiplier = block_ffn_dim_multiplier
-        self.block_auto_adjust_ff_dim = block_auto_adjust_ff_dim
-
-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_embedding,
-            **kwargs,
-        )
-
-    @property
-    def layers_block_type(self):
-        return ["attention" if i in self.full_attn_idxs else "conv" for i in range(self.num_hidden_layers)]
-
-
-class LFM2RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        output = self._norm(x.float())
-        return output.type_as(x) * self.weight
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors."""
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-class LFM2RotaryEmbedding(nn.Module):
-    def __init__(self, config: LFM2Config, device=None):
-        super().__init__()
-        # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            self.rope_type = "default"
-        self.max_seq_len_cached = config.max_position_embeddings
-        self.original_max_seq_len = config.max_position_embeddings
-
-        self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
-
-    @torch.no_grad()
-    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
-    def forward(self, x, position_ids):
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
-        position_ids_expanded = position_ids[:, None, :].float()
-
-        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos() * self.attention_scaling
-            sin = emb.sin() * self.attention_scaling
-
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-def eager_attention_forward(
-    module: nn.Module,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    scaling: float,
-    dropout: float = 0.0,
-    **kwargs,
-):
-    num_key_value_groups = query.shape[1] // key.shape[1]
-    key_states = repeat_kv(key, num_key_value_groups)
-    value_states = repeat_kv(value, num_key_value_groups)
-
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-    if attention_mask is not None:
-        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-        attn_weights = attn_weights + causal_mask
-    else:
-        seq_len = key_states.shape[-2]
-        causal_mask = torch.triu(
-            torch.full((seq_len, seq_len), float("-inf"), device=attn_weights.device),
-            diagonal=1,
-        )
-        attn_weights = attn_weights + causal_mask
-
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
-    attn_output = torch.matmul(attn_weights, value_states)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    return attn_output, attn_weights
-
-
-class LFM2MLP(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        ff_dim: int,
-        multiple_of: int,
-        auto_adjust_ff_dim: bool,
-        ffn_dim_multiplier: Optional[float],
-    ):
-        super().__init__()
-        if auto_adjust_ff_dim:
-            ff_dim = int(2 * ff_dim / 3)
-            # custom dim factor multiplier
-            if ffn_dim_multiplier is not None:
-                ff_dim = int(ffn_dim_multiplier * ff_dim)
-            ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
-
-        self.w1 = nn.Linear(dim, ff_dim, bias=False)
-        self.w3 = nn.Linear(dim, ff_dim, bias=False)
-        self.w2 = nn.Linear(ff_dim, dim, bias=False)
-
-    def forward(self, x):
-        return self.w2(F.silu(self.w1(x)) * self.w3(x))
-
-
-class LFM2Cache(DynamicCache):
-    """
-    Attention and conv cache for LFM2.
-
-    It stores the Key and Value states as a list of tensors, one for each layer.
-    Attention layer cache shape: `[batch_size, num_heads, seq_len, head_dim]`.
-    Conv layer cache shape: `[batch_size, conv_dim, L_cache-1]`.
-    """
-
-    def __init__(
-        self,
-        config: LFM2Config,
-        max_batch_size: int,
-        dtype: torch.dtype = torch.float32,
-        device: Union[torch.device, str, None] = None,
-    ):
-        super().__init__()  # initialize key and value cache
-        self.max_batch_size = max_batch_size
-        self.full_attn_idxs = config.full_attn_idxs
-        self.conv_L_cache = config.conv_L_cache
-        self._dtype = dtype
-
-        self.conv_cache: list[torch.Tensor] = []
-        device = torch.device(device) if device is not None else None
-
-        for _ in range(config.num_hidden_layers):
-            conv_state = torch.zeros(
-                self.max_batch_size,
-                config.conv_dim,
-                self.conv_L_cache,
-                dtype=self._dtype,
-                device=device,
-            )
-            torch._dynamo.mark_static_address(conv_state)
-            self.conv_cache.append(conv_state)
-
-    def update(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[dict[str, Any]] = None,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
-
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
-
-        Return:
-            A tuple containing the updated key and value states.
-        """
-        # Update the number of seen tokens
-        # if layer_idx == 0:
-        if layer_idx == self.full_attn_idxs[0]:
-            self._seen_tokens += key_states.shape[-2]
-
-        # Update the cache
-        if key_states is not None:
-            if len(self.key_cache) <= layer_idx:
-                # There may be skipped layers, fill them with empty lists
-                for _ in range(len(self.key_cache), layer_idx):
-                    self.key_cache.append(torch.tensor([]))
-                    self.value_cache.append(torch.tensor([]))
-                self.key_cache.append(key_states)
-                self.value_cache.append(value_states)
-            elif (
-                not self.key_cache[layer_idx].numel()  # prefers not t.numel() to len(t) == 0 to export the model
-            ):  # fills previously skipped layers; checking for tensor causes errors
-                self.key_cache[layer_idx] = key_states
-                self.value_cache[layer_idx] = value_states
-            else:
-                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
-
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
-
-    def reorder_cache(self, beam_idx: torch.LongTensor):
-        """Reorders the cache for beam search, given the selected beam indices."""
-        for layer_idx in range(len(self.key_cache)):
-            device = self.key_cache[layer_idx].device
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            device = self.value_cache[layer_idx].device
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
-
-            device = self.conv_cache[layer_idx].device
-            self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device))
-
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        # take any layer that contains cache and not empty tensor
-        layer_idx = self.full_attn_idxs[0] if layer_idx not in self.full_attn_idxs else layer_idx
-        if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0:
-            return 0
-        return self.key_cache[layer_idx].shape[-2]
-
-    def reset(self):
-        for layer_idx in range(len(self.conv_cache)):
-            # In-place ops prevent breaking the static address
-            self.conv_cache[layer_idx].zero_()
-
-
-class LFM2Attention(nn.Module):
-    def __init__(self, config: LFM2Config, layer_idx: Optional[int] = None, **kwargs):
-        super().__init__()
-        self.config = config
-        self.layer_idx = layer_idx
-        if layer_idx is None:
-            logger.warning_once(
-                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and "
-                "will lead to errors during the forward call if caching is used. Please make sure to provide a "
-                "`layer_idx` when creating this class."
-            )
-        self.head_dim = config.hidden_size // config.num_attention_heads
-        self.num_key_value_heads = config.num_key_value_heads
-        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
-        self.scaling = self.head_dim**-0.5
-        self.is_causal = True
-
-        self.q_layernorm = LFM2RMSNorm(self.head_dim, eps=config.norm_eps)
-        self.k_layernorm = LFM2RMSNorm(self.head_dim, eps=config.norm_eps)
-
-        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
-        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
-        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
-        self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        position_embeddings: tuple[torch.Tensor, torch.Tensor],
-        attention_mask: Optional[torch.Tensor],
-        past_key_value: Optional[LFM2Cache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
-        input_shape = hidden_states.shape[:-1]
-        hidden_shape = (*input_shape, -1, self.head_dim)
-
-        q = self.q_layernorm(self.q_proj(hidden_states).view(*hidden_shape)).transpose(1, 2)
-        k = self.k_layernorm(self.k_proj(hidden_states).view(*hidden_shape)).transpose(1, 2)
-        v = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
-
-        cos, sin = position_embeddings
-        q, k = apply_rotary_pos_emb(q, k, cos, sin)
-
-        if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            k, v = past_key_value.update(
-                key_states=k, value_states=v, layer_idx=self.layer_idx, cache_kwargs=cache_kwargs
-            )
-
-        attention_interface: Callable = eager_attention_forward
-        if self.config._attn_implementation != "eager":
-            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
-        attn_output, attn_weights = attention_interface(
-            self,
-            q,
-            k,
-            v,
-            attention_mask,
-            dropout=0.0,
-            scaling=self.scaling,
-            **kwargs,
-        )
-        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-        output = self.out_proj(attn_output)
-        return output, attn_weights
-
-
-class LFM2ShortConv(nn.Module):
-    def __init__(
-        self,
-        config: LFM2Config,
-        dim: int,
-        layer_idx: int,
-    ):
-        super().__init__()
-        self.config = config
-        self.layer_idx = layer_idx
-        self.L_cache = config.conv_L_cache
-        self.bias = config.conv_bias
-
-        self.conv = nn.Conv1d(
-            in_channels=dim,
-            out_channels=dim,
-            kernel_size=self.L_cache,
-            groups=dim,
-            bias=self.bias,
-            padding=self.L_cache - 1,
-        )
-        self.in_proj = nn.Linear(dim, 3 * dim, bias=self.bias)
-        self.out_proj = nn.Linear(dim, dim, bias=self.bias)
-
-    def cuda_kernels_forward(
-        self,
-        x: torch.Tensor,
-        cache_params: Optional[LFM2Cache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-    ):
-        BCx = self.in_proj(x).transpose(-1, -2)
-        B, C, x = BCx.chunk(3, dim=-2)
-
-        Bx = B * x
-
-        conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2))
-        if cache_params is not None and cache_position[0] > 0:
-            conv_out = causal_conv1d_update(
-                Bx.squeeze(-1),
-                cache_params.conv_cache[self.layer_idx],
-                conv_weights,
-                self.conv.bias,
-                None,
-            )
-            conv_out = conv_out.unsqueeze(-1)
-        else:
-            if cache_params is not None:
-                conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))
-                cache_params.conv_cache[self.layer_idx].copy_(conv_state)
-
-            conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None)
-
-        y = C * conv_out
-        y = self.out_proj(y.transpose(-1, -2).contiguous())
-        return y
-
-    def slow_forward(
-        self,
-        x: torch.Tensor,
-        cache_params: Optional[LFM2Cache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-    ):
-        seqlen = x.shape[1]
-        BCx = self.in_proj(x).transpose(-1, -2)
-        B, C, x = BCx.chunk(3, dim=-2)
-
-        Bx = B * x
-
-        if cache_params is not None and cache_position[0] > 0:
-            conv_state = cache_params.conv_cache[self.layer_idx]
-            cache_position = cache_position.clamp(0, self.L_cache - 1)
-            conv_state = conv_state.roll(shifts=-1, dims=-1)
-            conv_state[:, :, cache_position] = Bx.to(device=conv_state.device, dtype=conv_state.dtype)
-            cache_params.conv_cache[self.layer_idx].copy_(conv_state)
-            conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1)
-            if self.bias:
-                conv_out += self.conv.bias
-
-            conv_out = conv_out.unsqueeze(-1)
-        else:
-            if cache_params is not None:
-                conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))
-                cache_params.conv_cache[self.layer_idx].copy_(conv_state)
-
-            conv_out = self.conv(Bx)[..., :seqlen]
-
-        y = C * conv_out
-        y = y.transpose(-1, -2).contiguous()
-        y = self.out_proj(y)
-        return y
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        cache_params: Optional[LFM2Cache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-    ):
-        if is_fast_path_available and "cuda" in x.device.type and not torch._dynamo.is_compiling():
-            return self.cuda_kernels_forward(x, cache_params, cache_position, attention_mask)
-        return self.slow_forward(x, cache_params, cache_position, attention_mask)
-
-
-class LFM2AttentionDecoderLayer(GradientCheckpointingLayer):
-    def __init__(self, config: LFM2Config, layer_idx: int):
-        super().__init__()
-        self.self_attn = LFM2Attention(config, layer_idx)
-        self.feed_forward = LFM2MLP(
-            dim=config.block_dim,
-            ff_dim=config.block_ff_dim,
-            multiple_of=config.block_multiple_of,
-            auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
-            ffn_dim_multiplier=config.block_ffn_dim_multiplier,
-        )
-        self.operator_norm = LFM2RMSNorm(config.hidden_size, eps=config.norm_eps)
-        self.ffn_norm = LFM2RMSNorm(config.hidden_size, eps=config.norm_eps)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        position_embeddings: tuple[torch.Tensor, torch.Tensor],
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[tuple[torch.Tensor]] = None,
-        output_attentions: Optional[bool] = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        h, self_attn_weights = self.self_attn(
-            hidden_states=self.operator_norm(hidden_states),
-            position_embeddings=position_embeddings,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            cache_position=cache_position,
-            **kwargs,
-        )
-        h += hidden_states
-        out = h + self.feed_forward.forward(self.ffn_norm(h))
-
-        outputs = (out,)
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
-
-
-class LFM2ShortConvDecoderLayer(GradientCheckpointingLayer):
-    def __init__(self, config: LFM2Config, layer_idx: int):
-        super().__init__()
-        self.conv = LFM2ShortConv(
-            config=config,
-            dim=config.conv_dim,
-            layer_idx=layer_idx,
-        )
-        self.feed_forward = LFM2MLP(
-            dim=config.block_dim,
-            ff_dim=config.block_ff_dim,
-            multiple_of=config.block_multiple_of,
-            auto_adjust_ff_dim=config.block_auto_adjust_ff_dim,
-            ffn_dim_multiplier=config.block_ffn_dim_multiplier,
-        )
-        self.operator_norm = LFM2RMSNorm(config.hidden_size, eps=config.norm_eps)
-        self.ffn_norm = LFM2RMSNorm(config.hidden_size, eps=config.norm_eps)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        past_key_value: Optional[LFM2Cache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = False,
-        **kwargs,
-    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        h = self.conv(
-            self.operator_norm(hidden_states),
-            cache_params=past_key_value,
-            cache_position=cache_position,
-            attention_mask=attention_mask,
-        )
-        self_attn_weights = None
-
-        h += hidden_states
-        out = h + self.feed_forward.forward(self.ffn_norm(h))
-
-        outputs = (out,)
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
-
-
-@auto_docstring
-class LFM2PretrainedModel(PreTrainedModel):
-    config_class = LFM2Config
-    base_model_prefix = "model"
-    supports_gradient_checkpointing = True
-    _no_split_modules: ClassVar = ["LFM2AttentionDecoderLayer", "LFM2ShortConvDecoderLayer"]
-    _skip_keys_device_placement = "past_key_values"
-    _supports_flash_attn_2 = True
-    _supports_sdpa = True
-    _supports_flex_attn = True
-    _supports_cache_class = True
-    _supports_quantized_cache = True
-    _supports_static_cache = True
-    _supports_attention_backend = True
-
-    def _init_weights(self, module):
-        std = self.config.initializer_range
-        if isinstance(module, (nn.Linear, nn.Conv1d)):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, LFM2RMSNorm):
-            module.weight.data.fill_(1.0)
-
-
-class LFM2Model(LFM2PretrainedModel):
-    def __init__(self, config: LFM2Config):
-        super().__init__(config)
-        self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
-
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-
-        self.pos_emb = LFM2RotaryEmbedding(config)
-
-        decoder_layers = []
-        for i in range(config.num_hidden_layers):
-            if i in config.full_attn_idxs:
-                decoder_layers.append(LFM2AttentionDecoderLayer(config, layer_idx=i))
-            else:
-                decoder_layers.append(LFM2ShortConvDecoderLayer(config, layer_idx=i))
-        self.layers = nn.ModuleList(decoder_layers)
-
-        self.embedding_norm = LFM2RMSNorm(config.hidden_size, eps=config.norm_eps)
-
-        self.gradient_checkpointing = False
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_input_embeddings(self):
-        return self.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.embed_tokens = value
-
-    @can_return_tuple
-    @auto_docstring
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[LFM2Cache] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
-    ) -> BaseModelOutputWithPast:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-            )
-            use_cache = False
-
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
-        if use_cache and past_key_values is None:
-            batch_size = inputs_embeds.shape[0]
-            past_key_values = LFM2Cache(
-                config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device
-            )
-
-        if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-            )
-
-        if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
-
-        causal_mask = create_causal_mask(
-            config=self.config,
-            input_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-            past_key_values=past_key_values,
-        )
-        hidden_states = inputs_embeds
-
-        position_embeddings = self.pos_emb(hidden_states, position_ids)
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        for decoder_layer in self.layers:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
-                hidden_states,
-                attention_mask=causal_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                cache_position=cache_position,
-                position_embeddings=position_embeddings,
-                **flash_attn_kwargs,
-            )
-
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.embedding_norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        output = BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=past_key_values if use_cache else None,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
-        return output if return_dict else output.to_tuple()
-
-
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
-@auto_docstring
-class LFM2ForCausalLM(LFM2PretrainedModel, GenerationMixin):
-    _tied_weights_keys = ["lm_head.weight"]
-
-    def __init__(self, config: LFM2Config):
-        super().__init__(config)
-        self.model = LFM2Model(config)
-        self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.post_init()
-
-    def get_input_embeddings(self):
-        return self.model.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.model.embed_tokens = value
-
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
-
-    def set_decoder(self, decoder):
-        self.model = decoder
-
-    def get_decoder(self):
-        return self.model
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[LFM2Cache] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[KwargsForCausalLM],
-    ) -> Union[tuple, CausalLMOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        outputs: BaseModelOutputWithPast = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            cache_position=cache_position,
-            return_dict=return_dict,
-            **kwargs,
-        )
-
-        hidden_states = outputs.last_hidden_state
-        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
-
-        loss = None
-        if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
-
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-    def prepare_inputs_for_generation(
-        self,
-        input_ids,
-        past_key_values=None,
-        attention_mask=None,
-        inputs_embeds=None,
-        cache_position=None,
-        position_ids=None,
-        use_cache=True,
-        **kwargs,
-    ):
-        # Overwritten -- Support custom LFM2Cache.
-
-        empty_past_kv = past_key_values is None or (
-            isinstance(past_key_values, DynamicCache) and past_key_values._seen_tokens == 0
-        )
-
-        # Omit tokens covered by past_key_values.
-        if not empty_past_kv:
-            # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
-            # Exception 1: when passing input_embeds, input_ids may be missing entries
-            # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
-            # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
-            # (we can't check exception 3 while compiling)
-            if (
-                inputs_embeds is not None  # Exception 1
-                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
-            ):
-                input_ids = input_ids[:, -cache_position.shape[0] :]
-            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
-                input_ids = input_ids[:, cache_position]
-        else:
-            past_key_values = LFM2Cache(self.config, input_ids.shape[0], dtype=self.dtype, device=self.device)
-
-        # if attention_mask is not None and position_ids is None:
-        #     # create position_ids on the fly for batch generation
-        #     position_ids = attention_mask.long().cumsum(-1) - 1
-        #     position_ids.masked_fill_(attention_mask == 0, 1)
-        #     if not empty_past_kv:
-        #         position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and empty_past_kv:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
-
-        model_inputs.update(
-            {
-                # "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": use_cache,
-                "attention_mask": attention_mask,
-                "cache_position": cache_position,
-            }
-        )
-        return model_inputs
-
-
-__all__ = ["LFM2ForCausalLM", "LFM2Model", "LFM2PretrainedModel"]

requirements.txt DELETED
@@ -1,2 +0,0 @@
-transformers==4.53.0.dev0
-tokenizers==0.21.1
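
With the pinned `requirements.txt` removed, the only environment requirement implied by this commit is a transformers build recent enough to include LFM2. A small guard sketch, where the `4.54.0.dev0` floor is an assumption taken from the commit message rather than from any remaining file:

```python
# Sketch: fail early if the installed transformers build likely predates native LFM2 support.
# The version floor is assumed from the commit message ("Bump transformers source v4.54.0.dev0").
import transformers
from packaging import version

MIN_VERSION = "4.54.0.dev0"  # assumed floor; adjust if the model card states otherwise
if version.parse(transformers.__version__) < version.parse(MIN_VERSION):
    raise RuntimeError(
        f"transformers {transformers.__version__} may be too old for LFM2; "
        "install from source, e.g. `pip install git+https://github.com/huggingface/transformers`"
    )
```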