ffurfaro committed on
Commit 717a938 · verified · 1 Parent(s): ee8c3d7

Upload model + init tptt code

README.md ADDED
@@ -0,0 +1,77 @@
1
+ ---
2
+ language: en
3
+ license: apache-2.0
4
+ library_name: transformers
5
+ tags:
6
+ - tptt
7
+ - peft
8
+ - trust_remote_code
9
+ pipeline_tag: text-generation
10
+ base_model: allenai/OLMoE-1B-7B-0924
11
+ datasets:
12
+ - yahma/alpaca-cleaned
13
+ ---
14
+
15
+ # Titans-v2-OLMoE-1B-7B-0924
16
+
17
+ <p align="center">
18
+ <a href="https://arxiv.org/abs/2506.17671">
19
+ <img alt="arXiv" src="https://img.shields.io/badge/arXiv-tptt-blueviolet.svg">
20
+ </a>
21
+ <a href="https://pypi.org/project/tptt/">
22
+ <img alt="PyPI" src="https://img.shields.io/pypi/v/tptt?color=orange">
23
+ </a>
24
+ <a href="https://github.com/fabienfrfr/tptt/">
25
+ <img alt="Release" src="https://img.shields.io/github/v/release/fabienfrfr/tptt?color=brightgreen">
26
+ </a>
27
+ <a href="https://fabienfrfr.github.io/tptt/">
28
+ <img alt="Documentation" src="https://img.shields.io/badge/docs-online-blue">
29
+ </a>
30
+ <a href="https://huggingface.co/ffurfaro">
31
+ <img alt="HuggingFace" src="https://img.shields.io/badge/hf-ffurfaro-yellow">
32
+ </a>
33
+ </p>
34
+
35
+ Titanesque version of `allenai/OLMoE-1B-7B-0924` with parallel linearized attention (TPTT 😊) and PEFT.
36
+
37
+ The architecture was presented in the paper [TPTT](https://huggingface.co/papers/2506.17671).
38
+
39
+
40
+ ## Model list
41
+
42
+ Classic model parameters with LiZA injection:
43
+
44
+ | Subfolder | Max Self Attn Length | Mag Weight | Cross Gate | Max Chunk Size | Bidirectional | LoRA | Description |
45
+ |-------------------------------|----------------------|------------|------------|----------------|---------------|------|-------------------------------------------------------|
46
+ | delta_rule | 8192 (default) | 0.5 | False | 64 | False | Yes | Parallel linearized attention with delta_rule operator|
47
+ | delta_rule_gelu | 8192 (default) | 0.5 | False | 64 | False | Yes | Non-linear operator with gelu activation |
48
+ | delta_product | 8192 (default) | 0.5 | False | 64 | False | Yes | Second order operator with derivative trick |
49
+ | delta_product_r | 8192 (default) | 0.5 | False | 64 | False | Yes | Second order operator with rotative trick |
50
+ | delta_product_c | 8192 (default) | 0.5 | False | 64 | False | Yes | Second order operator with combined trick |
51
+
52
+ ## Usage
53
+
54
+ ```python
55
+ from transformers import AutoModelForCausalLM, AutoTokenizer
56
+
57
+ model = AutoModelForCausalLM.from_pretrained(
58
+ "ffurfaro/Titans-v2-OLMoE-1B-7B-0924",
59
+ subfolder="tptt_subfolder", # one of the subfolders listed above (see the repo file tree)
60
+ trust_remote_code=True
61
+ )
62
+ tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
63
+
64
+ prompt = "Your prompt here"
65
+ inputs = tokenizer(prompt, return_tensors="pt")
66
+ outputs = model.generate(**inputs, max_new_tokens=100)
67
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
68
+
69
+ ```
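+
+ For example, to load the `delta_product` variant from the model list above (a minimal sketch; it assumes a subfolder with exactly that name exists in the repo tree):
+
+ ```python
+ from transformers import AutoModelForCausalLM
+
+ # "delta_product" is one of the subfolder names listed in the table above.
+ model = AutoModelForCausalLM.from_pretrained(
+     "ffurfaro/Titans-v2-OLMoE-1B-7B-0924",
+     subfolder="delta_product",
+     trust_remote_code=True,
+ )
+ ```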
70
+
71
+
72
+ ## Citation & Contact
73
+
74
+ If you use TPTT in your academic work, please cite [Furfaro](https://huggingface.co/ffurfaro). For questions or support, please open an issue on the [GitHub repository](https://github.com/fabienfrfr/tptt) or contact the maintainer.
75
+
76
+
77
+ ---
__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ """
2
+ This module implements the TPTT model with linear attention (LiZA) and LoRA support.
3
+ """
4
+
5
+ from .configuration_tptt import (TpttConfig, generate_model_card,
6
+ parse_mode_name)
7
+ from .modeling_tptt import (LCache, LinearAttention, LinearAttentionOp,
8
+ LiZAttention, TpttModel, get_tptt_model,
9
+ load_tptt_safetensors, save_tptt_safetensors)
10
+ from .train_tptt import LiZACallback, SaveBestModelCallback
11
+
12
+ __all__ = [
13
+ "TpttConfig",
14
+ "TpttModel",
15
+ "get_tptt_model",
16
+ "LiZACallback",
17
+ "SaveBestModelCallback",
18
+ "LCache",
19
+ "LinearAttentionOp",
20
+ "LiZAttention",
21
+ "generate_model_card",
22
+ "LinearAttention",
23
+ "parse_mode_name",
24
+ "load_tptt_safetensors",
25
+ "save_tptt_safetensors",
26
+ ]
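The exports above make up the package's public API. A minimal sketch of wrapping a backbone with LiZA attention, assuming the PyPI `tptt` package exposes the same names as this `__init__.py` and using only illustrative arguments (remaining parameters keep the defaults from `modeling_tptt.py`):

```python
from transformers import AutoModelForCausalLM
from tptt import LCache, get_tptt_model

# Load a backbone and inject LiZA linear attention into its attention modules.
base = AutoModelForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924")
model, cache = get_tptt_model(
    base,
    base_config=base.config,
    linear_cache=LCache(),
    operator_mode="delta_rule",
    mag_weight=0.5,
)
```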
configuration_tptt.py ADDED
@@ -0,0 +1,297 @@
1
+ # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-instance-attributes, too-many-locals
2
+ """
3
+ Author : Fabien FURFARO
4
+ """
5
+
6
+ import logging
7
+ import os
8
+ import re
9
+ from typing import Any, Dict, List, Optional, Union
10
+ from jinja2 import Environment, FileSystemLoader
11
+
12
+ import torch
13
+ from transformers import AutoConfig, PretrainedConfig
14
+
15
+ logger = logging.getLogger(__name__) # monitoring
16
+
17
+
18
+ def convert_sets_to_lists(obj):
19
+ """Convert sets to list for LoRA serialized config"""
20
+ if isinstance(obj, set):
21
+ return list(obj)
22
+ if isinstance(obj, dict):
23
+ return {k: convert_sets_to_lists(v) for k, v in obj.items()}
24
+ if isinstance(obj, (list, tuple)):
25
+ return [convert_sets_to_lists(x) for x in obj]
26
+ return obj
27
+
28
+
29
+ class TpttConfig(PretrainedConfig):
30
+ """
31
+ Configuration class for the TPTT model.
32
+ This class merges the backbone config (e.g., Llama) with custom TPTT parameters,
33
+ """
34
+
35
+ model_type = "tptt"
36
+ auto_map = {
37
+ "AutoModelForCausalLM": "modeling_tptt.TpttModel",
38
+ "AutoConfig": "configuration_tptt.TpttConfig",
39
+ }
40
+ architectures = ["TpttModel"]
41
+
42
+ RECURRENT_MODES = {
43
+ "delta_rule": {
44
+ "order": 1,
45
+ "gate_type": "k",
46
+ "linear": True,
47
+ "trick": "derivative",
48
+ },
49
+ "delta_rule_v": {
50
+ "order": 1,
51
+ "gate_type": "v",
52
+ "linear": True,
53
+ "trick": "derivative",
54
+ },
55
+ "delta_rule_kv": {
56
+ "order": 1,
57
+ "gate_type": "kv",
58
+ "linear": True,
59
+ "trick": "derivative",
60
+ },
61
+ "delta_rule_gelu": {
62
+ "order": 1,
63
+ "gate_type": "k",
64
+ "linear": False,
65
+ "trick": "derivative",
66
+ },
67
+ "delta_product": {
68
+ "order": 2,
69
+ "gate_type": "k",
70
+ "linear": True,
71
+ "trick": "derivative",
72
+ },
73
+ "delta_product_r": {
74
+ "order": 2,
75
+ "gate_type": "k",
76
+ "linear": True,
77
+ "trick": "rotative",
78
+ },
79
+ "delta_product_c": {
80
+ "order": 2,
81
+ "gate_type": "k",
82
+ "linear": True,
83
+ "trick": "combined",
84
+ },
85
+ } # Tested modes, see parse_mode_name if you want to add more
86
+
87
+ def __init__(
88
+ self,
89
+ base_model_config: Optional[Union[dict, PretrainedConfig]] = None,
90
+ base_model_name: str = "meta-llama/Llama-3.2-1B",
91
+ base_model_subfolder: Optional[str] = None,
92
+ name_or_path: Optional[str] = None,
93
+ target_modules_names: Optional[List[str]] = None,
94
+ operator_mode: str = "delta_rule",
95
+ max_self_attn_length: Optional[
96
+ int
97
+ ] = None, # not needed with SWA; otherwise 8192 is a standard value
98
+ base_scale_attn: bool = False,
99
+ mag_weight: float = 0.5, # if 1.0, use only linear operator
100
+ cross_gate: bool = False, # nonlinear mixing strategy
101
+ max_chunk_size: int = 64,
102
+ linear_precision: Union[str, torch.dtype] = "float32",
103
+ lora_config: Optional[dict] = None, # only serialized accepted
104
+ padding_side: Optional[str] = None, # for tokenizer, default "right"
105
+ bidirectional: bool = False, # if True, use bidirectional attention
106
+ pooling_config: Optional[Dict[str, Any]] = None,
107
+ **kwargs,
108
+ ):
109
+ # If base_model_config is provided, load it and merge with this config
110
+ if base_model_config is not None:
111
+ if isinstance(base_model_config, PretrainedConfig):
112
+ base_model_config = base_model_config.to_dict()
113
+ else:
114
+ # Load config from Hugging Face Hub or a local path
115
+ base_model_config = AutoConfig.from_pretrained(
116
+ base_model_name, **kwargs
117
+ ).to_dict()
118
+ # Merge all backbone fields into this config
119
+ for k, v in base_model_config.items():
120
+ setattr(self, k, v)
121
+
122
+ self.base_model_name = base_model_name
123
+ self.base_model_subfolder = base_model_subfolder
124
+
125
+ if name_or_path is not None:
126
+ self._name_or_path = name_or_path
127
+ else:
128
+ if "/" in base_model_name:
129
+ self._name_or_path = "Titans-" + base_model_name.split("/", 1)[1]
130
+ else:
131
+ self._name_or_path = "Titans-" + base_model_name
132
+
133
+ self.target_modules_names = target_modules_names or [
134
+ "attn",
135
+ "self_attn",
136
+ "attention",
137
+ ]
138
+ self.operator_mode = operator_mode
139
+ self.base_scale_attn = base_scale_attn
140
+ self.mag_weight = mag_weight
141
+ self.cross_gate = cross_gate
142
+ self.max_chunk_size = max_chunk_size
143
+ self.max_self_attn_length = max_self_attn_length
144
+ if isinstance(linear_precision, torch.dtype):
145
+ linear_precision = str(linear_precision).replace("torch.", "")
146
+ self.linear_precision = linear_precision
147
+
148
+ self.lora_config = lora_config
149
+ if lora_config is not None:
150
+ if hasattr(self.lora_config.get("peft_type"), "value"):
151
+ self.lora_config["peft_type"] = self.lora_config["peft_type"].value
152
+ self.lora_config = convert_sets_to_lists(self.lora_config)
153
+
154
+ self.padding_side = padding_side
155
+ self.bidirectional = bidirectional
156
+ if self.bidirectional:
157
+ logger.warning("Bidirectional attention is enabled; inputs must be non-causal and unpadded.")
158
+ self.pooling_config = pooling_config
159
+
160
+ super().__init__(**kwargs) # flush inconsistent pretrained parameters
161
+ # Copy class attributes to instance for serialization (save dict)
162
+ self.model_type = self.__class__.model_type
163
+ self.auto_map = self.__class__.auto_map
164
+ self.architectures = self.__class__.architectures
165
+ # Padding side configuration if not set
166
+ if self.padding_side is None:
167
+ self.padding_side = "right"
168
+ logger.warning("padding_side is None, defaulting to 'right'.")
169
+ # set recurrent configuration from operator mode
170
+ if operator_mode not in self.__class__.RECURRENT_MODES:
171
+ self.recurrent_config = parse_mode_name(operator_mode)
172
+ else:
173
+ self.recurrent_config = self.__class__.RECURRENT_MODES[operator_mode]
174
+ logger.info("Using recurrent mode: %s", get_mode_name(**self.recurrent_config))
175
+
176
+
177
+ TpttConfig.register_for_auto_class()
178
+
179
+
180
+ def parse_mode_name(name: str) -> dict:
181
+ """Parse mode to recurrent config"""
182
+ if name.startswith("delta_product"):
183
+ parts = name.split("_")
184
+ # Prefix is always two words: 'delta' and 'product'
185
+ base_len = 2
186
+ order = 2
187
+ gate_type = "k"
188
+ linear = True
189
+ trick = "derivative"
190
+
191
+ idx = base_len
192
+ # Check for order (immediately after the prefix)
193
+ if len(parts) > idx and parts[idx].isdigit():
194
+ order = int(parts[idx])
195
+ idx += 1
196
+
197
+ remaining = parts[idx:]
198
+ # Trick (r/c) is always at the far right if present
199
+ if remaining and remaining[-1] in ("r", "c"):
200
+ trick = {"r": "rotative", "c": "combined"}[remaining[-1]]
201
+ remaining = remaining[:-1]
202
+ # 'gelu' comes just before the trick if present
203
+ if remaining and remaining[-1] == "gelu":
204
+ linear = False
205
+ remaining = remaining[:-1]
206
+ # If anything remains, it's the gate_type
207
+ if remaining:
208
+ gate_type = "_".join(remaining)
209
+ return {
210
+ "order": order,
211
+ "gate_type": gate_type,
212
+ "linear": linear,
213
+ "trick": trick,
214
+ }
215
+
216
+ # delta_rule[_gate][_gelu]
217
+ m = re.match(r"^delta_rule(?:_(kv|v|k))?(_gelu)?$", name)
218
+ if m:
219
+ return {
220
+ "order": 1,
221
+ "gate_type": m.group(1) if m.group(1) else "k",
222
+ "linear": not bool(m.group(2)),
223
+ "trick": "derivative",
224
+ }
225
+ raise ValueError(f"Unknown mode: {name}")
226
+
227
+
228
+ def get_mode_name(
229
+ order: int = 1, gate_type: str = "k", linear: bool = True, trick: str = "derivative"
230
+ ) -> str:
231
+ """Get recurrent mode name from parameter"""
232
+ base = (
233
+ "delta_rule"
234
+ if order == 1
235
+ else ("delta_product" if order == 2 else f"delta_product_{order}")
236
+ )
237
+ parts = []
238
+ if gate_type != "k":
239
+ parts.append(gate_type)
240
+ if not linear:
241
+ parts.append("gelu")
242
+ if order >= 2 and trick != "derivative":
243
+ parts.append({"rotative": "r", "combined": "c"}.get(trick, trick))
244
+ return base + (("_" + "_".join(parts)) if parts else "")
245
+
246
+
247
+ def render_template(template_path: str, variables: dict) -> str:
248
+ """Load and render a Jinja2 template from any file path."""
249
+ env = Environment(loader=FileSystemLoader(os.path.dirname(template_path)))
250
+ template = env.get_template(os.path.basename(template_path))
251
+ return template.render(**variables)
252
+
253
+
254
+ def write_model_card(output_path: str, content: str):
255
+ """Write the generated content into README.md."""
256
+ os.makedirs(output_path, exist_ok=True)
257
+ readme_path = os.path.join(output_path, "README.md")
258
+ with open(readme_path, "w", encoding="utf-8") as f:
259
+ f.write(content)
260
+
261
+
262
+ def generate_model_card(
263
+ output_path: str,
264
+ config: Union[dict, object],
265
+ template: Optional[
266
+ str
267
+ ], # can be "model_card" OR an absolute/relative path to a .md file
268
+ extra_variables: Optional[Dict] = None,
269
+ ):
270
+ """
271
+ Generate a README.md file from a Jinja2 template and a configuration.
272
+
273
+ - template can be either:
274
+ * a full path to a template file
275
+ * a short name (e.g., "model_card") -> will be looked up inside default_templates_dir
276
+ """
277
+ if template is None:
278
+ template = "model_card_template" # default template name
279
+ # Locate the template
280
+ if os.path.exists(template): # direct file path provided
281
+ template_path = template
282
+ else:
283
+ default_templates_dir = os.path.join(os.path.dirname(__file__), "templates")
284
+ template_path = os.path.join(default_templates_dir, f"{template}.md")
285
+
286
+ if not os.path.exists(template_path):
287
+ raise FileNotFoundError(f"Template not found: {template_path}")
288
+
289
+ variables = {
290
+ "model_id": os.path.basename(output_path),
291
+ "config": config,
292
+ }
293
+ if extra_variables:
294
+ variables.update(extra_variables)
295
+
296
+ content = render_template(template_path, variables)
297
+ write_model_card(output_path, content)
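To illustrate the mode-name scheme implemented above, a short sketch (it assumes `configuration_tptt.py` is importable from the working directory); the expected values follow directly from `RECURRENT_MODES`, `parse_mode_name`, and `get_mode_name`:

```python
from configuration_tptt import get_mode_name, parse_mode_name

# Round-trip between a mode name and its recurrent config.
cfg = parse_mode_name("delta_product_r")
assert cfg == {"order": 2, "gate_type": "k", "linear": True, "trick": "rotative"}
assert get_mode_name(**cfg) == "delta_product_r"

# A non-linear, kv-gated, first-order variant.
assert parse_mode_name("delta_rule_kv_gelu") == {
    "order": 1, "gate_type": "kv", "linear": False, "trick": "derivative",
}
```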
lora_delta_product_m0.5_gradual_t10/README.md ADDED
@@ -0,0 +1,96 @@
1
+ ---
2
+ language: en
3
+ license: apache-2.0
4
+ library_name: transformers
5
+ tags:
6
+ - tptt
7
+ - peft
8
+ - trust_remote_code
9
+ pipeline_tag: text-generation
10
+ base_model: allenai/OLMoE-1B-7B-0924
11
+ datasets:
12
+ - yahma/alpaca-cleaned
13
+ ---
14
+
15
+ # lora_delta_product_m0.5_gradual_t10
16
+
17
+ <p align="center">
18
+ <a href="https://arxiv.org/abs/2506.17671">
19
+ <img alt="arXiv" src="https://img.shields.io/badge/arXiv-tptt-blueviolet.svg">
20
+ </a>
21
+ <a href="https://pypi.org/project/tptt/">
22
+ <img alt="PyPI" src="https://img.shields.io/pypi/v/tptt?color=orange">
23
+ </a>
24
+ <a href="https://github.com/fabienfrfr/tptt/">
25
+ <img alt="Release" src="https://img.shields.io/github/v/release/fabienfrfr/tptt?color=brightgreen">
26
+ </a>
27
+ <a href="https://fabienfrfr.github.io/tptt/">
28
+ <img alt="Documentation" src="https://img.shields.io/badge/docs-online-blue">
29
+ </a>
30
+ <a href="https://huggingface.co/ffurfaro">
31
+ <img alt="HuggingFace" src="https://img.shields.io/badge/hf-ffurfaro-yellow">
32
+ </a>
33
+ </p>
34
+
35
+ Titanesque version of `allenai/OLMoE-1B-7B-0924` with parallel linearized attention (TPTT 😊) and PEFT.
36
+
37
+ The architecture was presented in the paper [TPTT](https://huggingface.co/papers/2506.17671).
38
+
39
+
40
+ ## Model Details
41
+
42
+ - **Architecture:** TpttModel
43
+ - **Base model:** allenai/OLMoE-1B-7B-0924
44
+ - **LiZA config:** operator=delta_product, mag=0.5
45
+ - **LoRA config:** r=8, alpha=16, dropout=0.05
46
+ - **torch_dtype:** bfloat16
47
+
48
+ ## Usage
49
+
50
+
51
+ ```python
52
+ from transformers import AutoModelForCausalLM, AutoTokenizer
53
+
54
+ model = AutoModelForCausalLM.from_pretrained(
55
+ "ffurfaro/lora_delta_product_m0.5_gradual_t10",
56
+ trust_remote_code=True
57
+ )
58
+ tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
59
+
60
+ prompt = "Your prompt here"
61
+ inputs = tokenizer(prompt, return_tensors="pt")
62
+ outputs = model.generate(**inputs, max_new_tokens=100)
63
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
64
+
65
+ ```
66
+
67
+ > [!IMPORTANT]
68
+ > You must specify the `subfolder` if the repo contains multiple models; see the homepage for details and the sketch below.
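+
+ A minimal sketch of that case, assuming this adapter is addressed as a subfolder of the parent repo shown on the homepage:
+
+ ```python
+ from transformers import AutoModelForCausalLM
+
+ model = AutoModelForCausalLM.from_pretrained(
+     "ffurfaro/Titans-v2-OLMoE-1B-7B-0924",
+     subfolder="lora_delta_product_m0.5_gradual_t10",
+     trust_remote_code=True,
+ )
+ ```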
69
+
70
+ ## Training
71
+
72
+ - **Dataset:** yahma/alpaca-cleaned
73
+ - **Platform:** Kaggle
74
+ - **Hardware:** NVIDIA 2xT4
75
+ - **Batch size:** 2
76
+ - **Epochs:** 1.0
77
+ - **Learning rate (final):** N/A
78
+ - **Loss (final):** 3.255523986816406
79
+ - **Training runtime:** 329.6156 sec
80
+ - **Samples per second:** 0.3
81
+ - **Steps per second:** 0.152
82
+ - **Total FLOPs:** 526406736936960.0
83
+ - **Gradient norm (final):** N/A
84
+
85
+ ## Evaluation
86
+
87
+ - **Metrics:** Training loss only (no evaluation yet; benchmark table coming soon: PiQA, ARC, HellaSwag, Winogrande, GSM8K, MMLU)
88
+ - **Results:** Final training loss: 3.255523986816406
89
+
90
+
91
+ ## Citation & Contact
92
+
93
+ If you use TPTT in your academic work, please cite [Furfaro](https://huggingface.co/ffurfaro). For questions or support, please open an issue on the [GitHub repository](https://github.com/fabienfrfr/tptt) or contact the maintainer.
94
+
95
+
96
+ ---
lora_delta_product_m0.5_gradual_t10/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:059c42de63faae87d93d7ac349c927db6398c934c31ad4189906976e4ed10b93
3
+ size 8406296
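This LFS pointer references the LoRA adapter weights that `load_tptt_safetensors` (in `modeling_tptt.py` below) consumes. A minimal sketch of inspecting the file directly, assuming it is fetched from the parent repo used in the top-level README:

```python
from huggingface_hub import hf_hub_download
from safetensors import safe_open

# Download the adapter for this subfolder and list a few tensor keys.
path = hf_hub_download(
    "ffurfaro/Titans-v2-OLMoE-1B-7B-0924",
    "lora_delta_product_m0.5_gradual_t10/adapter_model.safetensors",
)
with safe_open(path, framework="pt") as f:
    print(list(f.keys())[:5])
```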
lora_delta_product_m0.5_gradual_t10/config.json ADDED
@@ -0,0 +1,90 @@
1
+ {
2
+ "architectures": [
3
+ "TpttModel"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_tptt.TpttConfig",
9
+ "AutoModelForCausalLM": "modeling_tptt.TpttModel"
10
+ },
11
+ "base_model_name": "allenai/OLMoE-1B-7B-0924",
12
+ "base_model_subfolder": null,
13
+ "base_scale_attn": false,
14
+ "bidirectional": false,
15
+ "clip_qkv": null,
16
+ "cross_gate": false,
17
+ "hidden_act": "silu",
18
+ "hidden_size": 2048,
19
+ "initializer_range": 0.02,
20
+ "intermediate_size": 1024,
21
+ "linear_precision": "bfloat16",
22
+ "lora_config": {
23
+ "alpha_pattern": {},
24
+ "auto_mapping": null,
25
+ "base_model_name_or_path": null,
26
+ "bias": "none",
27
+ "eva_config": null,
28
+ "exclude_modules": null,
29
+ "fan_in_fan_out": false,
30
+ "inference_mode": false,
31
+ "init_lora_weights": true,
32
+ "layer_replication": null,
33
+ "layers_pattern": null,
34
+ "layers_to_transform": null,
35
+ "loftq_config": {},
36
+ "lora_alpha": 16,
37
+ "lora_bias": false,
38
+ "lora_dropout": 0.05,
39
+ "megatron_config": null,
40
+ "megatron_core": "megatron.core",
41
+ "modules_to_save": null,
42
+ "peft_type": "LORA",
43
+ "r": 8,
44
+ "rank_pattern": {},
45
+ "revision": null,
46
+ "target_modules": [
47
+ "k_proj",
48
+ "o_proj",
49
+ "q_proj",
50
+ "v_proj"
51
+ ],
52
+ "task_type": "CAUSAL_LM",
53
+ "use_dora": false,
54
+ "use_rslora": false
55
+ },
56
+ "mag_weight": 0.5,
57
+ "max_chunk_size": 32,
58
+ "max_position_embeddings": 4096,
59
+ "max_self_attn_length": null,
60
+ "model_type": "tptt",
61
+ "norm_topk_prob": false,
62
+ "num_attention_heads": 16,
63
+ "num_experts": 64,
64
+ "num_experts_per_tok": 8,
65
+ "num_hidden_layers": 16,
66
+ "num_key_value_heads": 16,
67
+ "operator_mode": "delta_product",
68
+ "output_router_logits": false,
69
+ "padding_side": "right",
70
+ "pooling_config": null,
71
+ "recurrent_config": {
72
+ "gate_type": "k",
73
+ "linear": true,
74
+ "order": 2,
75
+ "trick": "derivative"
76
+ },
77
+ "rms_norm_eps": 1e-05,
78
+ "rope_scaling": null,
79
+ "rope_theta": 10000.0,
80
+ "router_aux_loss_coef": 0.01,
81
+ "target_modules_names": [
82
+ "attn",
83
+ "self_attn",
84
+ "attention"
85
+ ],
86
+ "torch_dtype": "bfloat16",
87
+ "transformers_version": "4.49.0",
88
+ "use_cache": true,
89
+ "vocab_size": 50304
90
+ }
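A short sketch of reading this configuration through the `auto_map` declared above (it assumes the subfolder resolves on the parent repo as in the usage examples):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "ffurfaro/Titans-v2-OLMoE-1B-7B-0924",
    subfolder="lora_delta_product_m0.5_gradual_t10",
    trust_remote_code=True,
)
print(config.operator_mode)     # "delta_product"
print(config.recurrent_config)  # {"order": 2, "gate_type": "k", ...}
print(config.lora_config["r"])  # 8
```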
lora_delta_product_m0.5_gradual_t10/configuration_tptt.py ADDED
@@ -0,0 +1,297 @@
1
+ # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-instance-attributes, too-many-locals
2
+ """
3
+ Author : Fabien FURFARO
4
+ """
5
+
6
+ import logging
7
+ import os
8
+ import re
9
+ from typing import Any, Dict, List, Optional, Union
10
+ from jinja2 import Environment, FileSystemLoader
11
+
12
+ import torch
13
+ from transformers import AutoConfig, PretrainedConfig
14
+
15
+ logger = logging.getLogger(__name__) # monitoring
16
+
17
+
18
+ def convert_sets_to_lists(obj):
19
+ """Convert sets to list for LoRA serialized config"""
20
+ if isinstance(obj, set):
21
+ return list(obj)
22
+ if isinstance(obj, dict):
23
+ return {k: convert_sets_to_lists(v) for k, v in obj.items()}
24
+ if isinstance(obj, (list, tuple)):
25
+ return [convert_sets_to_lists(x) for x in obj]
26
+ return obj
27
+
28
+
29
+ class TpttConfig(PretrainedConfig):
30
+ """
31
+ Configuration class for the TPTT model.
32
+ This class merges the backbone config (e.g., Llama) with custom TPTT parameters,
33
+ """
34
+
35
+ model_type = "tptt"
36
+ auto_map = {
37
+ "AutoModelForCausalLM": "modeling_tptt.TpttModel",
38
+ "AutoConfig": "configuration_tptt.TpttConfig",
39
+ }
40
+ architectures = ["TpttModel"]
41
+
42
+ RECURRENT_MODES = {
43
+ "delta_rule": {
44
+ "order": 1,
45
+ "gate_type": "k",
46
+ "linear": True,
47
+ "trick": "derivative",
48
+ },
49
+ "delta_rule_v": {
50
+ "order": 1,
51
+ "gate_type": "v",
52
+ "linear": True,
53
+ "trick": "derivative",
54
+ },
55
+ "delta_rule_kv": {
56
+ "order": 1,
57
+ "gate_type": "kv",
58
+ "linear": True,
59
+ "trick": "derivative",
60
+ },
61
+ "delta_rule_gelu": {
62
+ "order": 1,
63
+ "gate_type": "k",
64
+ "linear": False,
65
+ "trick": "derivative",
66
+ },
67
+ "delta_product": {
68
+ "order": 2,
69
+ "gate_type": "k",
70
+ "linear": True,
71
+ "trick": "derivative",
72
+ },
73
+ "delta_product_r": {
74
+ "order": 2,
75
+ "gate_type": "k",
76
+ "linear": True,
77
+ "trick": "rotative",
78
+ },
79
+ "delta_product_c": {
80
+ "order": 2,
81
+ "gate_type": "k",
82
+ "linear": True,
83
+ "trick": "combined",
84
+ },
85
+ } # Tested modes, see parse_mode_name if you want to add more
86
+
87
+ def __init__(
88
+ self,
89
+ base_model_config: Optional[Union[dict, PretrainedConfig]] = None,
90
+ base_model_name: str = "meta-llama/Llama-3.2-1B",
91
+ base_model_subfolder: Optional[str] = None,
92
+ name_or_path: Optional[str] = None,
93
+ target_modules_names: Optional[List[str]] = None,
94
+ operator_mode: str = "delta_rule",
95
+ max_self_attn_length: Optional[
96
+ int
97
+ ] = None, # not needed with SWA; otherwise 8192 is a standard value
98
+ base_scale_attn: bool = False,
99
+ mag_weight: float = 0.5, # if 1.0, use only linear operator
100
+ cross_gate: bool = False, # nonlinear mixing strategy
101
+ max_chunk_size: int = 64,
102
+ linear_precision: Union[str, torch.dtype] = "float32",
103
+ lora_config: Optional[dict] = None, # only serialized accepted
104
+ padding_side: Optional[str] = None, # for tokenizer, default "right"
105
+ bidirectional: bool = False, # if True, use bidirectional attention
106
+ pooling_config: Optional[Dict[str, Any]] = None,
107
+ **kwargs,
108
+ ):
109
+ # If base_model_config is provided, load it and merge with this config
110
+ if base_model_config is not None:
111
+ if isinstance(base_model_config, PretrainedConfig):
112
+ base_model_config = base_model_config.to_dict()
113
+ else:
114
+ # Load config from Hugging Face Hub or a local path
115
+ base_model_config = AutoConfig.from_pretrained(
116
+ base_model_name, **kwargs
117
+ ).to_dict()
118
+ # Merge all backbone fields into this config
119
+ for k, v in base_model_config.items():
120
+ setattr(self, k, v)
121
+
122
+ self.base_model_name = base_model_name
123
+ self.base_model_subfolder = base_model_subfolder
124
+
125
+ if name_or_path is not None:
126
+ self._name_or_path = name_or_path
127
+ else:
128
+ if "/" in base_model_name:
129
+ self._name_or_path = "Titans-" + base_model_name.split("/", 1)[1]
130
+ else:
131
+ self._name_or_path = "Titans-" + base_model_name
132
+
133
+ self.target_modules_names = target_modules_names or [
134
+ "attn",
135
+ "self_attn",
136
+ "attention",
137
+ ]
138
+ self.operator_mode = operator_mode
139
+ self.base_scale_attn = base_scale_attn
140
+ self.mag_weight = mag_weight
141
+ self.cross_gate = cross_gate
142
+ self.max_chunk_size = max_chunk_size
143
+ self.max_self_attn_length = max_self_attn_length
144
+ if isinstance(linear_precision, torch.dtype):
145
+ linear_precision = str(linear_precision).replace("torch.", "")
146
+ self.linear_precision = linear_precision
147
+
148
+ self.lora_config = lora_config
149
+ if lora_config is not None:
150
+ if hasattr(self.lora_config.get("peft_type"), "value"):
151
+ self.lora_config["peft_type"] = self.lora_config["peft_type"].value
152
+ self.lora_config = convert_sets_to_lists(self.lora_config)
153
+
154
+ self.padding_side = padding_side
155
+ self.bidirectional = bidirectional
156
+ if self.bidirectional:
157
+ logger.warning("Bidirectional attention is enabled; inputs must be non-causal and unpadded.")
158
+ self.pooling_config = pooling_config
159
+
160
+ super().__init__(**kwargs) # flush inconsistent pretrained parameters
161
+ # Copy class attributes to instance for serialization (save dict)
162
+ self.model_type = self.__class__.model_type
163
+ self.auto_map = self.__class__.auto_map
164
+ self.architectures = self.__class__.architectures
165
+ # Padding side configuration if not set
166
+ if self.padding_side is None:
167
+ self.padding_side = "right"
168
+ logger.warning("padding_side is None, defaulting to 'right'.")
169
+ # set recurrent configuration from operator mode
170
+ if operator_mode not in self.__class__.RECURRENT_MODES:
171
+ self.recurrent_config = parse_mode_name(operator_mode)
172
+ else:
173
+ self.recurrent_config = self.__class__.RECURRENT_MODES[operator_mode]
174
+ logger.info("Using recurrent mode: %s", get_mode_name(**self.recurrent_config))
175
+
176
+
177
+ TpttConfig.register_for_auto_class()
178
+
179
+
180
+ def parse_mode_name(name: str) -> dict:
181
+ """Parse mode to recurrent config"""
182
+ if name.startswith("delta_product"):
183
+ parts = name.split("_")
184
+ # Prefix is always two words: 'delta' and 'product'
185
+ base_len = 2
186
+ order = 2
187
+ gate_type = "k"
188
+ linear = True
189
+ trick = "derivative"
190
+
191
+ idx = base_len
192
+ # Check for order (immediately after the prefix)
193
+ if len(parts) > idx and parts[idx].isdigit():
194
+ order = int(parts[idx])
195
+ idx += 1
196
+
197
+ remaining = parts[idx:]
198
+ # Trick (r/c) is always at the far right if present
199
+ if remaining and remaining[-1] in ("r", "c"):
200
+ trick = {"r": "rotative", "c": "combined"}[remaining[-1]]
201
+ remaining = remaining[:-1]
202
+ # 'gelu' comes just before the trick if present
203
+ if remaining and remaining[-1] == "gelu":
204
+ linear = False
205
+ remaining = remaining[:-1]
206
+ # If anything remains, it's the gate_type
207
+ if remaining:
208
+ gate_type = "_".join(remaining)
209
+ return {
210
+ "order": order,
211
+ "gate_type": gate_type,
212
+ "linear": linear,
213
+ "trick": trick,
214
+ }
215
+
216
+ # delta_rule[_gate][_gelu]
217
+ m = re.match(r"^delta_rule(?:_(kv|v|k))?(_gelu)?$", name)
218
+ if m:
219
+ return {
220
+ "order": 1,
221
+ "gate_type": m.group(1) if m.group(1) else "k",
222
+ "linear": not bool(m.group(2)),
223
+ "trick": "derivative",
224
+ }
225
+ raise ValueError(f"Unknown mode: {name}")
226
+
227
+
228
+ def get_mode_name(
229
+ order: int = 1, gate_type: str = "k", linear: bool = True, trick: str = "derivative"
230
+ ) -> str:
231
+ """Get recurrent mode name from parameter"""
232
+ base = (
233
+ "delta_rule"
234
+ if order == 1
235
+ else ("delta_product" if order == 2 else f"delta_product_{order}")
236
+ )
237
+ parts = []
238
+ if gate_type != "k":
239
+ parts.append(gate_type)
240
+ if not linear:
241
+ parts.append("gelu")
242
+ if order >= 2 and trick != "derivative":
243
+ parts.append({"rotative": "r", "combined": "c"}.get(trick, trick))
244
+ return base + (("_" + "_".join(parts)) if parts else "")
245
+
246
+
247
+ def render_template(template_path: str, variables: dict) -> str:
248
+ """Load and render a Jinja2 template from any file path."""
249
+ env = Environment(loader=FileSystemLoader(os.path.dirname(template_path)))
250
+ template = env.get_template(os.path.basename(template_path))
251
+ return template.render(**variables)
252
+
253
+
254
+ def write_model_card(output_path: str, content: str):
255
+ """Write the generated content into README.md."""
256
+ os.makedirs(output_path, exist_ok=True)
257
+ readme_path = os.path.join(output_path, "README.md")
258
+ with open(readme_path, "w", encoding="utf-8") as f:
259
+ f.write(content)
260
+
261
+
262
+ def generate_model_card(
263
+ output_path: str,
264
+ config: Union[dict, object],
265
+ template: Optional[
266
+ str
267
+ ], # can be "model_card" OR an absolute/relative path to a .md file
268
+ extra_variables: Optional[Dict] = None,
269
+ ):
270
+ """
271
+ Generate a README.md file from a Jinja2 template and a configuration.
272
+
273
+ - template can be either:
274
+ * a full path to a template file
275
+ * a short name (e.g., "model_card") -> will be looked up inside default_templates_dir
276
+ """
277
+ if template is None:
278
+ template = "model_card_template" # default template name
279
+ # Locate the template
280
+ if os.path.exists(template): # direct file path provided
281
+ template_path = template
282
+ else:
283
+ default_templates_dir = os.path.join(os.path.dirname(__file__), "templates")
284
+ template_path = os.path.join(default_templates_dir, f"{template}.md")
285
+
286
+ if not os.path.exists(template_path):
287
+ raise FileNotFoundError(f"Template not found: {template_path}")
288
+
289
+ variables = {
290
+ "model_id": os.path.basename(output_path),
291
+ "config": config,
292
+ }
293
+ if extra_variables:
294
+ variables.update(extra_variables)
295
+
296
+ content = render_template(template_path, variables)
297
+ write_model_card(output_path, content)
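Since this file mirrors the top-level `configuration_tptt.py`, a brief sketch of constructing the config directly (no network access; `base_model_config` stays `None`), with values that follow from `RECURRENT_MODES` above:

```python
from configuration_tptt import TpttConfig

cfg = TpttConfig(
    base_model_name="allenai/OLMoE-1B-7B-0924",
    operator_mode="delta_product_c",
    mag_weight=0.5,
    max_chunk_size=64,
)
print(cfg.operator_mode)     # "delta_product_c"
print(cfg.recurrent_config)  # {"order": 2, "gate_type": "k", "linear": True, "trick": "combined"}
```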
lora_delta_product_m0.5_gradual_t10/generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.49.0"
4
+ }
lora_delta_product_m0.5_gradual_t10/modeling_tptt.py ADDED
@@ -0,0 +1,1478 @@
1
+ # pylint: disable=too-many-lines, too-many-arguments, too-many-positional-arguments, too-many-instance-attributes, too-many-locals
2
+
3
+ """
4
+ This module implements the TPTT model with linear attention (LiZA) and LoRA support.
5
+ Author : Fabien FURFARO
6
+ TPTT : Transforming Pretrained Transformers into Titans (https://arxiv.org/abs/2506.17671)
7
+ """
8
+
9
+ import logging
10
+ import math
11
+ import os
12
+ from pathlib import Path
13
+ import re
14
+ import shutil
15
+ from functools import partial
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from einops import rearrange
21
+ from huggingface_hub import hf_hub_download, list_repo_files
22
+ from peft import LoraConfig, PeftModel, get_peft_model
23
+ from safetensors import safe_open
24
+ from safetensors.torch import save_file
25
+ from torch import nn
26
+ from torch.utils.checkpoint import checkpoint
27
+ from transformers import AutoConfig, AutoModelForCausalLM, DynamicCache, PreTrainedModel
28
+ from transformers.configuration_utils import PretrainedConfig
29
+
30
+ from .configuration_tptt import TpttConfig
31
+
32
+ logger = logging.getLogger(__name__) # monitoring
33
+
34
+
35
+ class LCache:
36
+ """Cache for storing intermediate states of linear attention layers."""
37
+
38
+ def __init__(self):
39
+ """Stores per-layer intermediate states: {layer_idx: state_dict}"""
40
+ self.inputs_states: Dict[int, Dict[str, torch.Tensor]] = (
41
+ {}
42
+ ) # recurrent states and qkv buffers
43
+
44
+ def __getitem__(self, layer_idx: int) -> Optional[Dict[str, torch.Tensor]]:
45
+ """Retrieve cached state for a given layer, or None if not present"""
46
+ return self.inputs_states.get(layer_idx, None)
47
+
48
+ def update(self, layer_idx: int, **kwargs):
49
+ """Detach all tensors to avoid retaining computation graphs"""
50
+ detached_kwargs = {
51
+ k: v.detach() if isinstance(v, torch.Tensor) else v
52
+ for k, v in kwargs.items()
53
+ }
54
+ # Update or create the state for the specified layer
55
+ if layer_idx in self.inputs_states:
56
+ self.inputs_states[layer_idx].update(detached_kwargs)
57
+ else:
58
+ self.inputs_states[layer_idx] = detached_kwargs
59
+
60
+ def reset(self):
61
+ """Clear all cached states and reset the token counter"""
62
+ self.inputs_states.clear()
63
+
64
+
65
+ class CausalAvgPool1d(nn.Module):
66
+ """Causal sliding window average (uniform, no shape loss along sequence)"""
67
+
68
+ def __init__(
69
+ self, output_size: int, offsets: tuple[int, ...] = (0, 1, 2), mode: str = "replicate"
70
+ ):
71
+ super().__init__()
72
+ self.offsets = offsets
73
+ self.mode = mode
74
+ self.pool = nn.AdaptiveAvgPool1d(output_size=output_size)
75
+
76
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
77
+ """x: [B, S, F] → [B, S, F → output_size]"""
78
+ x_ = x.transpose(1, 2) # [B, F, S]
79
+ idxs = torch.tensor(self.offsets, device=x.device)
80
+ ksize = idxs.max() - idxs.min() + 1
81
+ w = torch.zeros(ksize, device=x.device, dtype=x.dtype)
82
+ w[idxs - idxs.min()] = 1 / len(self.offsets) # Always uniform weights
83
+ kernel = w.repeat(x_.shape[1], 1).reshape(x_.shape[1], 1, ksize)
84
+ pad_left = -idxs.min().item()
85
+ pad_right = (ksize - 1) - pad_left
86
+ x_pad = F.pad(x_, (pad_left, pad_right), mode=self.mode)
87
+ y = F.conv1d(x_pad, kernel, groups=x_.shape[1]) # pylint: disable=not-callable
88
+ return self.pool(y.transpose(1, 2)) # [B, S, F → output_size]
89
+
90
+
91
+ class LinearAttention(nn.Module):
92
+ """
93
+ Linear multi-head attention layer: [B, S, D] -> [B, S, D]
94
+ Projections + gating + efficient linear attention mechanism (TPTT compatible).
95
+ """
96
+
97
+ def __init__(
98
+ self,
99
+ hidden_dim: int,
100
+ num_heads: int,
101
+ head_dim: Optional[int] = None,
102
+ num_key_value_heads: Optional[int] = None,
103
+ num_key_value_groups: Optional[int] = None,
104
+ bias: bool = True,
105
+ dropout: Optional[float] = None,
106
+ linear_precision: torch.dtype = torch.float32,
107
+ padding_side: str = "right",
108
+ shared_attn: bool = False, # shared attention
109
+ layer_idx: int = 0,
110
+ operator_mode: str = "delta_rule",
111
+ recurrent_config: Optional[Dict[str, Any]] = None,
112
+ linear_cache: Optional[LCache] = None,
113
+ max_chunk_size: int = 64,
114
+ bidirectional: bool = False, # not used if causal
115
+ pooling_config: Optional[Dict[str, Any]] = None,
116
+ ):
117
+ super().__init__()
118
+ if pooling_config is None:
119
+ pooling_config = {
120
+ "offsets": (0, 1, 2),
121
+ "mode": "replicate",
122
+ }
123
+ self.hidden_dim = hidden_dim
124
+ self.num_heads = num_heads
125
+ self.head_dim = head_dim or hidden_dim // num_heads
126
+ self.num_key_value_heads = num_key_value_heads or num_heads
127
+ self.num_key_value_groups = num_key_value_groups or (
128
+ num_heads // (num_key_value_heads or num_heads)
129
+ )
130
+ self.scaling = self.head_dim**-0.5
131
+ self.linear_precision = linear_precision
132
+ self.padding_side = padding_side
133
+
134
+ self.shared_attn = shared_attn
135
+
136
+ if not shared_attn:
137
+ self.q_proj = nn.Linear(hidden_dim, num_heads * self.head_dim, bias=bias)
138
+ self.k_proj = nn.Linear(
139
+ hidden_dim, self.num_key_value_heads * self.head_dim, bias=bias
140
+ )
141
+ self.v_proj = nn.Linear(
142
+ hidden_dim, self.num_key_value_heads * self.head_dim, bias=bias
143
+ )
144
+ self.out_proj = nn.Linear(num_heads * self.head_dim, hidden_dim, bias=bias)
145
+
146
+ self.dropout = nn.Dropout(dropout) if dropout is not None else None
147
+
148
+ self.linear_operator = LinearAttentionOp(
149
+ layer_idx=layer_idx,
150
+ operator_mode=operator_mode,
151
+ recurrent_config=recurrent_config,
152
+ max_chunk_size=max_chunk_size,
153
+ linear_cache=linear_cache,
154
+ linear_precision=linear_precision,
155
+ )
156
+ self.bidirectional = bidirectional
157
+ # Causal average pooling for gating
158
+ self.pooling_config = pooling_config
159
+ self.pool_g = CausalAvgPool1d(
160
+ output_size=self.head_dim * self.num_key_value_heads, **pooling_config
161
+ )
162
+
163
+ def forward(
164
+ self,
165
+ x: Union[List[torch.Tensor], torch.Tensor],
166
+ attn_mask: Optional[torch.Tensor] = None,
167
+ out_proj: Optional[nn.Module] = None,
168
+ **kwargs: Any,
169
+ ) -> torch.Tensor:
170
+ """
171
+ Forward pass for linear attention. Input shape: [B, S, D], output [B, S, D].
172
+ """
173
+
174
+ if not self.shared_attn:
175
+ hidden_states = x[0] if isinstance(x, (list, tuple)) else x
176
+ # Projections
177
+ q = self.q_proj(hidden_states)
178
+ k = self.k_proj(hidden_states)
179
+ v = self.v_proj(hidden_states)
180
+ out_proj = self.out_proj
181
+ else:
182
+ # Shared attention <=> no projections here
183
+ q, k, v = x[0], x[1], x[2]
184
+ out_proj = self.out_proj if out_proj is None else out_proj
185
+
186
+ # get dtype and device
187
+ final_dtype, final_device = q.dtype, q.device
188
+ # Masking if needed
189
+ if attn_mask is not None:
190
+ v = apply_linear_attention_mask(attn_mask, v, self.padding_side)
191
+
192
+ # Forget and Write Gating for linear attn (abusive term)
193
+ f_g, w_g = self.pool_g(k), self.pool_g(v)
194
+
195
+ # Reshape for multi-head
196
+ q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
197
+ k = rearrange(k, "b n (h d) -> b h n d", h=self.num_key_value_heads)
198
+ v = rearrange(v, "b n (h d) -> b h n d", h=self.num_key_value_heads)
199
+
200
+ f_g = rearrange(f_g, "b n (h m) -> b h n m", h=self.num_key_value_heads)
201
+ w_g = rearrange(w_g, "b n (h m) -> b h n m", h=self.num_key_value_heads)
202
+
203
+ # Repeat for GQA
204
+ k = k.repeat_interleave(self.num_key_value_groups, dim=1)
205
+ v = v.repeat_interleave(self.num_key_value_groups, dim=1)
206
+
207
+ f_g = f_g.repeat_interleave(self.num_key_value_groups, dim=1)
208
+ w_g = w_g.repeat_interleave(self.num_key_value_groups, dim=1)
209
+
210
+ ## DeltaNet-style: Silu activation and normalization
211
+ q = F.normalize(F.silu(q), p=2, dim=-1, eps=1e-6)
212
+ k = F.normalize(F.silu(k), p=2, dim=-1, eps=1e-6)
213
+
214
+ ## linear stability part
215
+ v = ensure_stability(v * self.scaling, min_val=-1e4, max_val=1e4)
216
+
217
+ # Apply sigmoid to forget and write gates
218
+ f_g = torch.clamp(torch.sigmoid(f_g), min=1e-6, max=1 - 1e-6)
219
+ w_g = torch.clamp(torch.sigmoid(w_g), min=1e-6, max=1 - 1e-6)
220
+
221
+ # Convert to linear_precision (float32) for numerical stability and get model dtype
222
+ q, k, v, f_g, w_g = (
223
+ x.to(self.linear_precision).contiguous() for x in (q, k, v, f_g, w_g)
224
+ )
225
+ g = (f_g, w_g)
226
+
227
+ # Linear Attention Core, output: [B, H, S, d]
228
+ if self.bidirectional: # Work only with uncausal attention
229
+ # Forward direction
230
+ out_forward = self.linear_operator(q, k, v, g, **kwargs)
231
+ # Backward direction: flip the input sequence on the time dimension (dim=2)
232
+ kwargs_bwd = kwargs.copy()
233
+ kwargs_bwd["use_cache"] = False
234
+ out_backward = self.linear_operator(
235
+ torch.flip(q, dims=[2]),
236
+ torch.flip(k, dims=[2]),
237
+ torch.flip(v, dims=[2]),
238
+ tuple(torch.flip(t, dims=[2]) for t in g),
239
+ **kwargs_bwd,
240
+ )
241
+ # Flip the output back to restore proper order
242
+ out_backward = torch.flip(out_backward, dims=[2])
243
+ # Fusion: here, simple addition
244
+ out = out_forward + out_backward
245
+ else:
246
+ out = self.linear_operator(q, k, v, g, **kwargs)
247
+
248
+ # Merge heads and project: [B, H, S, d] -> [B, S, H*d] -> Out proj
249
+ out = rearrange(out, "b h s d -> b s (h d)")
250
+ # Normalize output (RMS norm). Note: bidirectional compatibility
251
+ out = out / out.pow(2).mean(dim=-1, keepdim=True).add(1e-6).sqrt()
252
+ # Ensure dtype and device consistency
253
+ out = out.to(dtype=final_dtype, device=final_device)
254
+ # Apply output projection
255
+ out = out_proj(out) # [B, S, D]
256
+ out = ensure_stability(out, min_val=-1e4, max_val=1e4)
257
+ # Apply dropout if specified
258
+ if self.dropout is not None:
259
+ out = self.dropout(out)
260
+ return out
261
+
262
+
263
+ class LiZAttention(nn.Module):
264
+ """LiZA Linear Attention module, mixing linear and vanilla attention."""
265
+
266
+ def __init__(
267
+ self,
268
+ base_attn: nn.Module,
269
+ layer_idx: int,
270
+ base_config: PretrainedConfig, # Backbone Config
271
+ linear_cache: Optional[LCache] = None,
272
+ operator_mode: str = "delta_rule",
273
+ recurrent_config: Optional[Dict[str, Any]] = None,
274
+ max_self_attn_length: Optional[int] = None, # unnecessary
275
+ base_scale_attn: bool = False,
276
+ mag_weight: float = 0.5,
277
+ cross_gate: bool = False,
278
+ max_chunk_size: int = 64,
279
+ linear_precision: Union[str, torch.dtype] = "float32",
280
+ padding_side: str = "right", # for tokenizer
281
+ disable_linear_attn: bool = False,
282
+ bidirectional: bool = False, # if True, use bidirectional attention
283
+ pooling_config: Optional[Dict[str, Any]] = None,
284
+ ):
285
+ super().__init__()
286
+ if isinstance(linear_precision, str):
287
+ linear_precision = getattr(torch, linear_precision)
288
+ self.linear_precision = linear_precision
289
+ self.base_attn: nn.Module = base_attn
290
+ self.base_config = base_config
291
+ self.layer_idx = layer_idx
292
+ self.max_self_attn_length = max_self_attn_length
293
+ self.base_scale_attn = base_scale_attn
294
+ self.mag_weight = mag_weight
295
+ self.cross_gate = cross_gate
296
+ self.max_chunk_size = max_chunk_size
297
+ self.linear_precision = linear_precision
298
+ self.padding_side = padding_side
299
+ self.disable_linear_attn = disable_linear_attn
300
+
301
+ (
302
+ self.num_heads,
303
+ self.head_dim,
304
+ self.num_key_value_heads,
305
+ self.num_key_value_groups,
306
+ ) = self._get_attention_parameters(base_attn, base_config)
307
+ self.scaling = self.head_dim**-0.5
308
+
309
+ self.linear_attn = LinearAttention(
310
+ layer_idx=layer_idx,
311
+ shared_attn=True,
312
+ operator_mode=operator_mode,
313
+ recurrent_config=recurrent_config,
314
+ hidden_dim=base_config.hidden_size,
315
+ num_heads=self.num_heads,
316
+ head_dim=self.head_dim,
317
+ num_key_value_heads=self.num_key_value_heads,
318
+ num_key_value_groups=self.num_key_value_groups,
319
+ linear_precision=linear_precision,
320
+ linear_cache=linear_cache,
321
+ max_chunk_size=max_chunk_size,
322
+ padding_side=padding_side,
323
+ bidirectional=bidirectional,
324
+ pooling_config=pooling_config,
325
+ )
326
+
327
+ def _get_attention_parameters(
328
+ self, base_attn: nn.Module, base_config: PretrainedConfig
329
+ ) -> Tuple[Optional[int], Optional[int], Optional[int], Optional[int]]:
330
+ """Retrieve the attention parameters from the base attention module."""
331
+ # first order base attention module and second order config
332
+ num_heads = (
333
+ getattr(base_attn, "num_heads", None)
334
+ or getattr(base_attn, "num_q_heads", None)
335
+ or getattr(base_config, "num_heads", None)
336
+ or getattr(base_config, "num_attention_heads", None)
337
+ )
338
+ head_dim = (
339
+ getattr(base_attn, "head_dim", None)
340
+ or getattr(base_attn, "attention_head_size", None)
341
+ or getattr(base_config, "head_dim", None)
342
+ or (
343
+ getattr(base_config, "hidden_size", None) // num_heads
344
+ if num_heads and getattr(base_config, "hidden_size", None)
345
+ else None
346
+ )
347
+ )
348
+ num_key_value_heads = (
349
+ getattr(base_attn, "num_kv_heads", None)
350
+ or getattr(base_attn, "num_k_heads", None)
351
+ or getattr(base_config, "num_key_value_heads", None)
352
+ or num_heads # fallback
353
+ )
354
+ num_key_value_groups = getattr(base_attn, "num_key_value_groups", None) or (
355
+ num_heads // num_key_value_heads if num_heads and num_key_value_heads else 1
356
+ )
357
+ return (
358
+ num_heads,
359
+ head_dim,
360
+ num_key_value_heads,
361
+ num_key_value_groups,
362
+ )
363
+
364
+ def _apply_shared_projections(
365
+ self, hidden_states: torch.Tensor
366
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, nn.Module]:
367
+ base_attn = self.base_attn
368
+ if hasattr(base_attn, "q_proj"):
369
+ # LLama, OLMO and Mistral style
370
+ q = base_attn.q_proj(hidden_states)
371
+ k = base_attn.k_proj(hidden_states)
372
+ v = base_attn.v_proj(hidden_states)
373
+ out_proj = base_attn.o_proj
374
+ elif hasattr(base_attn, "qkv_proj"):
375
+ # OpenELM and GPT-Neo style : QKV fused, split on the last dimension
376
+ qkv = base_attn.qkv_proj(hidden_states)
377
+ q, k, v = split_qkv(base_attn, qkv)
378
+ out_proj = base_attn.out_proj
379
+ elif hasattr(base_attn, "c_attn") and hasattr(base_attn, "c_proj"):
380
+ # GPT-2 style
381
+ qkv = base_attn.c_attn(hidden_states)
382
+ q, k, v = qkv.chunk(3, dim=-1)
383
+ out_proj = base_attn.c_proj
384
+ elif all(hasattr(base_attn, n) for n in ["query", "key", "value"]):
385
+ # BERT - ViT
386
+ q = base_attn.query(hidden_states)
387
+ k = base_attn.key(hidden_states)
388
+ v = base_attn.value(hidden_states)
389
+ out_proj = getattr(base_attn, "dense", None) # or output.dense
390
+ else:
391
+ raise ValueError("Unsupported attention module: cannot find projections.")
392
+ # Ensure stability
393
+ q = ensure_stability(q, min_val=-1e4, max_val=1e4)
394
+ k = ensure_stability(k, min_val=-1e4, max_val=1e4)
395
+ v = ensure_stability(v, min_val=-1e4, max_val=1e4)
396
+ return q, k, v, out_proj
397
+
398
+ def _process_self_attn(
399
+ self,
400
+ hidden_states: torch.Tensor,
401
+ attention_mask: Optional[torch.Tensor],
402
+ kwargs,
403
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[DynamicCache], int]:
404
+ """Process the self-attention part (with truncation)."""
405
+ if self.max_self_attn_length: # Not needed for SWA (nonparam memorize context)
406
+ hidden_states, attention_mask = truncate_attention_mask(
407
+ hidden_states, attention_mask, self.max_self_attn_length
408
+ )
409
+
410
+ if kwargs.get("position_embeddings", None) is not None:
411
+ cos, sin = kwargs["position_embeddings"]
412
+ cos = cos[:, -self.max_self_attn_length :]
413
+ sin = sin[:, -self.max_self_attn_length :]
414
+ kwargs["position_embeddings"] = (cos, sin)
415
+
416
+ if isinstance(kwargs.get("past_key_value", None), DynamicCache):
417
+ # cache management
418
+ if (
419
+ len(kwargs["past_key_value"]) > self.layer_idx
420
+ and self.layer_idx == 0
421
+ ):
422
+ kwargs["past_key_value"].crop(self.max_self_attn_length - 1)
423
+
424
+ # Standard attention (mask and rotation is applied inside)
425
+ base_attn_outputs = self.base_attn(
426
+ hidden_states,
427
+ attention_mask=attention_mask,
428
+ **kwargs,
429
+ )
430
+
431
+ if isinstance(base_attn_outputs, tuple):
432
+ if len(base_attn_outputs) == 3:
433
+ o_base, attn_weights, present_key_value = base_attn_outputs
434
+ expected_attn_mode = 3
435
+ elif len(base_attn_outputs) == 2:
436
+ o_base, attn_weights = base_attn_outputs
437
+ present_key_value, expected_attn_mode = None, 2
438
+ else:
439
+ raise ValueError(
440
+ f"Unexpected number of outputs from base_attn: {len(base_attn_outputs)}"
441
+ )
442
+ else:
443
+ o_base = base_attn_outputs
444
+ attn_weights, present_key_value, expected_attn_mode = None, None, 1
445
+ # Ensure stability
446
+ o_base = ensure_stability(o_base, min_val=-1e4, max_val=1e4)
447
+ return o_base, attn_weights, present_key_value, expected_attn_mode
448
+
449
+ def _prepare_attn_mixin(
450
+ self,
451
+ o_lin: torch.Tensor,
452
+ o_base: torch.Tensor,
453
+ tensor_dtype: torch.dtype,
454
+ eps: float = 1e-5,
455
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
456
+ """Prepare linear attn for mixing with self attn."""
457
+ # Force cast typing, shape : [b n (h d)]
458
+ o_lin = o_lin.to(tensor_dtype)
459
+ o_base = o_base.to(tensor_dtype)
460
+ # feature scaling
461
+ if self.base_scale_attn:
462
+ scaler = o_base.pow(2).mean(dim=-1, keepdim=True).add(eps).sqrt()
463
+ o_lin = scaler * o_lin
464
+ return o_lin, o_base
465
+
466
+ def _apply_mag(
467
+ self, linear_attention: torch.Tensor, softmax_attention: torch.Tensor
468
+ ) -> torch.Tensor:
469
+ """Apply the MAG strategy"""
470
+ # Left-Padding management
471
+ if linear_attention.shape[1] != softmax_attention.shape[1]:
472
+ left_trunc = min(linear_attention.shape[1], softmax_attention.shape[1])
473
+ linear_attention, softmax_attention = (
474
+ linear_attention[:, -left_trunc:],
475
+ softmax_attention[:, -left_trunc:],
476
+ )
477
+ # NAM : Neural Attention Mixer (with graph forcing)
478
+ mag_weight = torch.tensor(
479
+ self.mag_weight,
480
+ dtype=softmax_attention.dtype,
481
+ device=softmax_attention.device,
482
+ )
483
+ softmax_weighted = (1 - mag_weight) * softmax_attention
484
+ linear_weighted = mag_weight * linear_attention
485
+ if self.cross_gate:
486
+ output_attention = (
487
+ softmax_weighted + linear_weighted + softmax_weighted * linear_weighted
488
+ ) # cross-product term (nonlinear interaction)
489
+ else:
490
+ output_attention = softmax_weighted + linear_weighted # classic
491
+
492
+ if torch.allclose(softmax_weighted, output_attention):
493
+ logger.info(
494
+ "[LOG] layer : %s, softmax_weighted and output_attention are close.",
495
+ self.layer_idx,
496
+ )
497
+ # Final output
498
+ return ensure_stability(output_attention, min_val=-1e4, max_val=1e4)
499
+
500
+ def forward(
501
+ self,
502
+ hidden_states: torch.Tensor,
503
+ attention_mask: Optional[torch.Tensor] = None,
504
+ **kwargs,
505
+ ) -> torch.Tensor:
506
+ """Mix linear and self attention forward"""
507
+ device = hidden_states.device
508
+ tensor_dtype = hidden_states.dtype
509
+ self.base_attn.to(device)
510
+
511
+ if self.training:
512
+ kwargs.pop("past_key_value", None)
513
+ kwargs["use_cache"] = False
514
+ elif "use_cache" not in kwargs:
515
+ kwargs.pop("past_key_value", None)
516
+ kwargs["use_cache"] = False
517
+
518
+ kwargs.pop("position_ids", None) # obsolete
519
+
520
+ # Apply shared projections
521
+ q, k, v, out_proj = self._apply_shared_projections(hidden_states)
522
+
523
+ # Apply linear attention to hidden states
524
+ o_lin = self.linear_attn(
525
+ x=[q, k, v], attn_mask=attention_mask, out_proj=out_proj, **kwargs
526
+ )
527
+
528
+ # Process self attn with truncation
529
+ o_base, attn_weights, present_key_value, expected_attn_mode = (
530
+ self._process_self_attn(hidden_states, attention_mask, kwargs)
531
+ )
532
+
533
+ # Prepare output mixing
534
+ o_lin, o_base = self._prepare_attn_mixin(o_lin, o_base, tensor_dtype, eps=1e-5)
535
+
536
+ # Apply Memory as Gate in self-attention (with length management and ablation)
537
+ out = o_base if self.disable_linear_attn else self._apply_mag(o_lin, o_base)
538
+
539
+ # Return output following transformer convention
540
+ if expected_attn_mode == 3:
541
+ return out, attn_weights, present_key_value
542
+ if expected_attn_mode == 2:
543
+ return out, attn_weights
544
+ return out
545
+
546
+
547
+ def load_tptt_safetensors(
548
+ repo_or_path: str,
549
+ model: Union[PreTrainedModel, PeftModel],
550
+ subfolder: Optional[str] = None,
551
+ token: Optional[str] = None,
552
+ ) -> Union[PreTrainedModel, PeftModel]:
553
+ """Load Tptt safetensor from LoRA/PEFT weights and adapt keys if needed."""
554
+ # sharding not supported yet (e.g., -00001-of-00005.safetensors, ...)
555
+ fname = "adapter_model.safetensors"
556
+ # subfolder management
557
+ if subfolder:
558
+ repo_or_path_norm = os.path.normpath(repo_or_path)
559
+ subfolder_norm = os.path.normpath(subfolder)
560
+ if not repo_or_path_norm.endswith(subfolder_norm):
561
+ fname = f"{subfolder}/{fname}" if subfolder else fname
562
+ # Find file path
563
+ if os.path.isdir(repo_or_path):
564
+ path = os.path.join(repo_or_path, fname)
565
+ if not os.path.exists(path):
566
+ return model
567
+ else:
568
+ if fname not in list_repo_files(repo_or_path, token=token):
569
+ return model
570
+ path = hf_hub_download(repo_or_path, fname, token=token)
571
+
572
+ # Load weights from safetensors
573
+ with safe_open(path, framework="pt") as f:
574
+ state_dict = {k: f.get_tensor(k) for k in f.keys()}
575
+
576
+ # Adapt LoRA/Specific keys if needed (add .default if expected by the model)
577
+ def adapt_keys(sd, model):
578
+ model_keys = list(model.state_dict().keys())
579
+ if any(k.startswith("tptt_model.base_model.") for k in model_keys):
580
+ prefix = "tptt_model.base_model."
581
+ elif any(k.startswith("base_model.") for k in model_keys):
582
+ prefix = "base_model."
583
+ else:
584
+ prefix = ""
585
+
586
+ has_base_attn = any(".base_attn." in k for k in model_keys)
587
+
588
+ def adapt_key(k):
589
+ k_ = k if k.startswith(prefix) else prefix + k
590
+ # first, verify and modify base_attn (LiZA)
591
+ if ".base_attn." in k_ and not has_base_attn:
592
+ k_ = k_.replace(".base_attn.", ".")
593
+ # change LoRA if needed
594
+ if (
595
+ k_.endswith("lora_A.weight") or k_.endswith("lora_B.weight")
596
+ ) and k_.replace(".weight", ".default.weight") in model_keys:
597
+ k_ = k_.replace(".weight", ".default.weight")
598
+ return k_
599
+
600
+ return {adapt_key(k): v for k, v in sd.items()}
601
+
602
+ state_dict = adapt_keys(state_dict, model)
603
+
604
+ # Cast tensors to the expected dtype of the model parameters
605
+ model_state_dict = model.state_dict()
606
+ for k, v in state_dict.items():
607
+ if k in model_state_dict:
608
+ expected_dtype = model_state_dict[k].dtype
609
+ if v.dtype != expected_dtype:
610
+ state_dict[k] = v.to(expected_dtype)
611
+
612
+ logger.info("Input LoRA/Specific keys: %s", [k for k in state_dict.keys()])
613
+
614
+ # Load into model
615
+ missing, unexpected = model.load_state_dict(state_dict, strict=False, assign=True)
616
+ missing_lora = [k for k in missing if "lora" in k]
617
+ if missing_lora:
618
+ logger.warning("Missing keys: %s", missing_lora)
619
+ if unexpected:
620
+ logger.warning("Unexpected keys: %s", unexpected)
621
+ return model
622
+
623
+
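For illustration only, a toy run of the key-renaming logic above with hypothetical parameter names (the `.default` suffix is what PEFT inserts for LoRA weights inside a loaded model):

```python
# Hypothetical key names, purely to illustrate the renaming above.
model_keys = [
    "base_model.model.layers.0.self_attn.q_proj.lora_A.default.weight",
]
prefix = "base_model."

def adapt_key(k: str) -> str:
    k = k if k.startswith(prefix) else prefix + k
    if (k.endswith("lora_A.weight") or k.endswith("lora_B.weight")) and \
            k.replace(".weight", ".default.weight") in model_keys:
        k = k.replace(".weight", ".default.weight")
    return k

print(adapt_key("model.layers.0.self_attn.q_proj.lora_A.weight"))
# -> base_model.model.layers.0.self_attn.q_proj.lora_A.default.weight
```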
624
+ def get_tptt_model( # pylint: disable=too-many-arguments, too-many-positional-arguments
625
+ model: nn.Module,
626
+ base_config: PretrainedConfig, # or LlamaConfig, MistralConfig, etc.
627
+ linear_cache: Optional[LCache] = None,
628
+ liza_attention: nn.Module = LiZAttention,
629
+ target_modules_names: Optional[list[str]] = None,
630
+ operator_mode: str = "delta_rule",
631
+ recurrent_config: Optional[Dict[str, Any]] = None,
632
+ base_scale_attn: bool = False,
633
+ mag_weight: float = 0.5,
634
+ cross_gate: bool = False,
635
+ max_chunk_size: int = 64,
636
+ linear_precision: torch.dtype = torch.float32,
637
+ max_self_attn_length: Optional[int] = None, # unnecessary
638
+ padding_side: str = "right", # for tokenizer
639
+ bidirectional: bool = False, # if True, use bidirectional attention
640
+ pooling_config: Optional[Dict[str, Any]] = None,
641
+ **kwargs, # quickfix unexpected arguments
642
+ ) -> Tuple[PreTrainedModel, LCache]:
643
+ """Replace target modules in a model with LiZAttention."""
644
+ if target_modules_names is None:
645
+ target_modules_names = ["attn", "self_attn", "attention"]
646
+ # Find target modules by suffix (e.g., "attn", "attention")
647
+ target_modules_names = [
648
+ name
649
+ for name, _ in model.named_modules()
650
+ if any(name.endswith(suffix) for suffix in target_modules_names)
651
+ and not any(f".{suffix}." in name for suffix in target_modules_names)
652
+ ]
653
+ if not target_modules_names:
654
+ raise ValueError(
655
+ f"Target modules '{target_modules_names}' not found in the model."
656
+ )
657
+ # Prepare recurrent config
658
+ linear_cache = linear_cache or LCache()
659
+ # Inject LiZAttention into the model
660
+ for name, _ in model.named_modules():
661
+ if name in target_modules_names:
662
+ parent = model
663
+ *path, last = name.split(".")
664
+ for p in path:
665
+ parent = getattr(parent, p)
666
+ layer_idx = extract_layer_idx(name)
667
+ setattr(
668
+ parent,
669
+ last,
670
+ liza_attention(
671
+ getattr(parent, last),
672
+ layer_idx=layer_idx,
673
+ base_config=base_config,
674
+ linear_cache=linear_cache,
675
+ operator_mode=operator_mode,
676
+ recurrent_config=recurrent_config,
677
+ max_self_attn_length=max_self_attn_length,
678
+ base_scale_attn=base_scale_attn,
679
+ mag_weight=mag_weight,
680
+ cross_gate=cross_gate,
681
+ max_chunk_size=max_chunk_size,
682
+ linear_precision=linear_precision,
683
+ padding_side=padding_side,
684
+ bidirectional=bidirectional,
685
+ pooling_config=pooling_config,
686
+ ),
687
+ )
688
+ return model, linear_cache
689
+
690
+
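A small sketch of the suffix-based module matching used by `get_tptt_model`, on a toy module tree (the `nn.Linear` here is only a stand-in for a real attention block):

```python
import torch.nn as nn

class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = nn.Linear(8, 8)  # stand-in for an attention module
        self.mlp = nn.Linear(8, 8)

model = nn.ModuleDict({"layers": nn.ModuleList([TinyBlock(), TinyBlock()])})
suffixes = ["attn", "self_attn", "attention"]
targets = [
    name
    for name, _ in model.named_modules()
    if any(name.endswith(s) for s in suffixes)
    and not any(f".{s}." in name for s in suffixes)
]
print(targets)  # ['layers.0.self_attn', 'layers.1.self_attn']
```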
691
+ def save_tptt_safetensors(model, path: str, name: str = "adapter_model.safetensors"):
692
+ """Save trainable LoRA/Specific weights and adapting key names"""
693
+ # 1. Get the full state_dict
694
+ all_sd = model.state_dict()
695
+
696
+ # 2. Identify trainable parameter names (usually only LoRA/PEFT adapters)
697
+ trainable_keys = [
698
+ name for name, param in model.named_parameters() if param.requires_grad
699
+ ] # Also, you can manually select specific keys in model after load
700
+
701
+ # 3. Filter and adapt the keys (Remove custom model encapsulation info)
702
+ to_save = {
703
+ k.replace("tptt_model.", "").replace("base_model.", ""): all_sd[k]
704
+ for k in trainable_keys
705
+ }
706
+
707
+ # 4. Save the filtered adapters to a safetensors file
708
+ if to_save:
709
+ os.makedirs(os.path.dirname(path), exist_ok=True)
710
+ # sharding not supported yet (e.g. : -00001-of-00005.safetensors, ...)
711
+ save_file(to_save, os.path.join(path, name))
712
+
713
+
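A minimal sketch of the "filter trainable parameters, strip wrapper prefixes" step above, on a toy model (no safetensors write here, just the dictionary that would be saved):

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
model[0].weight.requires_grad = False  # pretend this weight is frozen

trainable = {n for n, p in model.named_parameters() if p.requires_grad}
to_save = {
    n.replace("tptt_model.", "").replace("base_model.", ""): p.detach()
    for n, p in model.named_parameters()
    if n in trainable
}
print(sorted(to_save))  # ['0.bias', '1.bias', '1.weight']
```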
714
+ class TpttModel(PreTrainedModel):
715
+ """
716
+ TPTT model wrapper with linear attention (LiZA) and LoRA support.
717
+ Handles only architecture and weights.
718
+ """
719
+
720
+ config_class = TpttConfig
721
+
722
+ def __init__(
723
+ self,
724
+ config: TpttConfig,
725
+ **kwargs,
726
+ ):
727
+ """
728
+ Initialize TpttModel with a given config and backbone.
729
+ Injects LiZA attention modules into the backbone.
730
+ """
731
+ super().__init__(config, **kwargs)
732
+ repo_or_path = getattr(config, "_base_path", None) or config._name_or_path
733
+
734
+ # 1. Load backbone (with subfolder management) :
735
+ kwargs_bb = kwargs.copy()
736
+ if config.base_model_subfolder is not None:
737
+ kwargs_bb["subfolder"] = config.base_model_subfolder
738
+ else:
739
+ kwargs_bb.pop("subfolder", None)
740
+ tptt_model = AutoModelForCausalLM.from_pretrained(
741
+ config.base_model_name, **kwargs_bb
742
+ )
743
+
744
+ # 2. Inject LiZA attention
745
+ self.linear_cache = LCache()
746
+ tptt_model, self.linear_cache = get_tptt_model(
747
+ tptt_model, config, self.linear_cache, **config.to_dict()
748
+ )
749
+
750
+ # 3. Apply LoRA/Specific if present and configured
751
+ if config.lora_config is not None:
752
+ lora_config_obj = LoraConfig(**config.lora_config)
753
+ tptt_model = get_peft_model(tptt_model, lora_config_obj)
754
+ else:
755
+ tptt_model = set_trainable_parameters(tptt_model)
756
+
757
+ # 4. Load safetensor if tptt/peft adaptor in repo
758
+ if repo_or_path:
759
+ tptt_model = load_tptt_safetensors(
760
+ repo_or_path,
761
+ tptt_model,
762
+ subfolder=kwargs.get("subfolder", None),
763
+ token=kwargs.get("token", None),
764
+ )
765
+ self.tptt_model = tptt_model
766
+
767
+ def forward(
768
+ self,
769
+ input_ids: Optional[torch.LongTensor] = None,
770
+ attention_mask: Optional[torch.Tensor] = None,
771
+ labels: Optional[torch.LongTensor] = None,
772
+ **kwargs,
773
+ ):
774
+ """Forward pass. All arguments are passed to the underlying base model."""
775
+ if self.training:
776
+ kwargs["use_cache"] = False
777
+ kwargs.pop("num_items_in_batch", None)
778
+ elif "use_cache" not in kwargs: # evaluation
779
+ kwargs.pop("num_items_in_batch", None)
780
+ kwargs["use_cache"] = False
781
+ return self.tptt_model(
782
+ input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
783
+ )
784
+
785
+ def generate(self, *args, **kwargs):
786
+ """Delegate the generate call to the backbone model, which supports generation"""
787
+ return self.tptt_model.generate(*args, **kwargs)
788
+
789
+ def save_pretrained(self, path: str, **kwargs):
790
+ """Save model weights, config, and source code to the given path."""
791
+ # 0. Save complete tptt config (with or without LoRA)
792
+ super().save_pretrained(path, **kwargs) # pylint: disable=no-member
793
+ self._adjust_save_strategy(path, **kwargs)
794
+ # 1. Save the trainable weights and adapt key names
795
+ save_tptt_safetensors(self, path)
796
+ # 2. Copy Python files for trust_remote_code
797
+ self._copy_source_files(path, **kwargs)
798
+
799
+ def _adjust_save_strategy(self, path: str, **kwargs):
800
+ """Re-adapt/remove the weight safetensor and saved adapter config"""
801
+ if isinstance(self.tptt_model, PeftModel):
802
+ self.tptt_model.save_pretrained(path, **kwargs)
803
+ safetensor_path = os.path.join(path, "model.safetensors")
804
+ if os.path.exists(safetensor_path):
805
+ os.remove(safetensor_path)
806
+ adapter_path = os.path.join(path, "adapter_config.json")
807
+ if os.path.exists(adapter_path):
808
+ os.remove(adapter_path)
809
+
810
+ def _copy_source_files(self, target_path: str, **kwargs):
811
+ """Copy all .py files from package directory for trust_remote_code."""
812
+ src_dir = os.path.dirname(os.path.abspath(__file__))
813
+ dst_dir = (
814
+ f"./{str(Path(target_path).parts[0])}"
815
+ if kwargs.get("subfolder", False)
816
+ else target_path
817
+ )
818
+ for fname in os.listdir(src_dir):
819
+ if fname.endswith(".py"):
820
+ src = os.path.join(src_dir, fname)
821
+ dst = os.path.join(dst_dir, fname)
822
+ shutil.copy2(src, dst)
823
+
824
+ def retie_lm_after_load(self, **kwargs):
825
+ """Re-link lm_head after loading external weights."""
826
+ embed_lm = find_embedding_lm(self.tptt_model)
827
+ if embed_lm is not None and hasattr(self.tptt_model, "lm_head"):
828
+ if self.tptt_model.lm_head is None: # ensure lm_head exists
829
+ self.tptt_model.lm_head = nn.Linear(
830
+ embed_lm.weight.shape[1], embed_lm.weight.shape[0], bias=False
831
+ )
832
+ if kwargs.get("tie_word_embeddings", True):
833
+ self.tptt_model.lm_head.weight = embed_lm.weight # share weights
834
+ logger.info("Weights of lm_head have been shared with embedding.")
835
+ else:
836
+ self.tptt_model.lm_head.weight = nn.Parameter(embed_lm.weight.clone())
837
+ logger.info("Weights of lm_head have been cloned from the embedding.")
838
+
839
+ @classmethod
840
+ def from_pretrained(cls, pretrained_model_name_or_path=None, *model_args, **kwargs):
841
+ """Custom from_pretrained that accepts the standard positional argument"""
842
+ config = kwargs.pop("config", None)
843
+ repo_or_path = (
844
+ pretrained_model_name_or_path
845
+ or kwargs.pop("pretrained_model_name_or_path", None)
846
+ or kwargs.pop("repo_or_path", None)
847
+ or (getattr(config, "_base_path", None) if config else None)
848
+ or (getattr(config, "_name_or_path", None) if config else None)
849
+ )
850
+
851
+ if config is None and repo_or_path is not None:
852
+ config = AutoConfig.from_pretrained(repo_or_path, **kwargs)
853
+ model = cls(config, *model_args, **kwargs)
854
+ model.retie_lm_after_load(**kwargs)
855
+ return model
856
+
857
+
858
+ TpttModel.register_for_auto_class("AutoModelForCausalLM")
859
+
860
+
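A standalone sketch of the weight tying / cloning performed by `retie_lm_after_load`, with toy dimensions:

```python
import torch.nn as nn

embed = nn.Embedding(100, 16)
lm_head = nn.Linear(16, 100, bias=False)

# Tied: lm_head shares the embedding storage (and gradients)
lm_head.weight = embed.weight
print(lm_head.weight.data_ptr() == embed.weight.data_ptr())  # True

# Untied: clone so the two matrices can diverge during training
lm_head.weight = nn.Parameter(embed.weight.clone())
print(lm_head.weight.data_ptr() == embed.weight.data_ptr())  # False
```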
861
+ class LinearAttentionOp(nn.Module):
862
+ """Base class for linear attention operators."""
863
+
864
+ def __init__(
865
+ self,
866
+ layer_idx: int,
867
+ operator_mode: str = "delta_rule",
868
+ recurrent_config: Optional[dict] = None,
869
+ max_chunk_size: int = 64,
870
+ linear_cache: Optional[LCache] = None,
871
+ linear_precision: torch.dtype = torch.float32,
872
+ ):
873
+ super().__init__()
874
+ self.layer_idx = layer_idx
875
+ if recurrent_config is None:
876
+ operator_mode = "delta_rule" # force default operator mode if no config
877
+ recurrent_config = {
878
+ "order": 1,
879
+ "gate_type": "k",
880
+ "linear": True,
881
+ "trick": "derivative",
882
+ }
883
+ self.operator_mode = operator_mode
884
+ self.order = recurrent_config["order"]
885
+ self.gate_type = recurrent_config["gate_type"]
886
+ self.linear = recurrent_config["linear"]
887
+ self.trick = recurrent_config["trick"]
888
+
889
+ self.max_chunk_size = max_chunk_size
890
+ self.linear_cache = linear_cache or LCache()
891
+ self.linear_precision = linear_precision
892
+
893
+ def compute_gate(self, beta: Tuple[torch.Tensor]) -> torch.Tensor:
894
+ """
895
+ Compute the gating tensor according to the gate_type.
896
+ """
897
+ if self.gate_type == "k":
898
+ return torch.clamp(beta[0], min=1e-6, max=1 - 1e-6)
899
+ if self.gate_type == "v":
900
+ return torch.clamp(beta[1], min=1e-6, max=1 - 1e-6)
901
+ if self.gate_type == "kv":
902
+ return torch.clamp(beta[0] * beta[1], min=1e-6, max=1 - 1e-6)
903
+ raise ValueError(f"Unsupported gate_type: {self.gate_type}")
904
+
905
+ def get_cache(self, use_cache: bool) -> Tuple[
906
+ Optional[torch.Tensor],
907
+ Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
908
+ ]:
909
+ """
910
+ Retrieve recurrent state and qkv buffers from the cache.
911
+ """
912
+ if not use_cache:
913
+ return None, None
914
+ last_state = self.linear_cache[self.layer_idx]
915
+ if last_state is not None:
916
+ recurrent_state = last_state.get("recurrent_state", None)
917
+ qkv_buffers = last_state.get("qkv", None)
918
+ else:
919
+ recurrent_state = None
920
+ qkv_buffers = None
921
+ return recurrent_state, qkv_buffers
922
+
923
+ def save_cache(
924
+ self,
925
+ use_cache: bool,
926
+ q: torch.Tensor,
927
+ k: torch.Tensor,
928
+ v: torch.Tensor,
929
+ gate: torch.Tensor,
930
+ state: torch.Tensor,
931
+ ) -> None:
932
+ """
933
+ Save the recurrent state and qkv buffers to the cache.
934
+ """
935
+ if not use_cache:
936
+ return
937
+ if self.order > 1:
938
+ qkv_buffers = (
939
+ q[:, :, -(self.order - 1) :, :],
940
+ k[:, :, -(self.order - 1) :, :],
941
+ v[:, :, -(self.order - 1) :, :],
942
+ gate[:, :, -(self.order - 1) :, :],
943
+ )
944
+ else:
945
+ qkv_buffers = None
946
+ self.linear_cache.update(self.layer_idx, recurrent_state=state, qkv=qkv_buffers)
947
+
948
+ def forward(
949
+ self,
950
+ q: torch.Tensor,
951
+ k: torch.Tensor,
952
+ v: torch.Tensor,
953
+ beta: Union[Tuple[torch.Tensor], torch.Tensor],
954
+ **kwargs,
955
+ ) -> torch.Tensor:
956
+ """
957
+ Forward pass for the attention operator.
958
+ """
959
+ # Ensure linear_precision for numerical stability (float32)
960
+ q, k, v = [x.to(self.linear_precision) for x in (q, k, v)]
961
+ if isinstance(beta, (tuple, list)):
962
+ beta = tuple(b.to(self.linear_precision) for b in beta)
963
+ else:
964
+ beta = beta.to(self.linear_precision)
965
+
966
+ gate = self.compute_gate(beta)
967
+
968
+ # Retrieve cache if needed
969
+ use_cache = kwargs.get("use_cache", False)
970
+ recurrent_state, qkvb = self.get_cache(use_cache)
971
+
972
+ if qkvb is not None and qkvb[0].shape == q.shape:
973
+ q = torch.cat([qkvb[0].to(q.device), q], dim=2).to(self.linear_precision)
974
+ k = torch.cat([qkvb[1].to(q.device), k], dim=2).to(self.linear_precision)
975
+ v = torch.cat([qkvb[2].to(q.device), v], dim=2).to(self.linear_precision)
976
+ gate = torch.cat([qkvb[3].to(q.device), gate], dim=2).to(
977
+ self.linear_precision
978
+ )
979
+
980
+ output, state = self.chunk_delta_product_forward(
981
+ q,
982
+ k,
983
+ v,
984
+ gate,
985
+ self.max_chunk_size,
986
+ n=self.order,
987
+ trick=self.trick,
988
+ linear=self.linear,
989
+ initial_state=recurrent_state,
990
+ use_checkpoint=not (use_cache),
991
+ linear_precision=self.linear_precision,
992
+ )
993
+
994
+ # Save cache if needed
995
+ self.save_cache(use_cache, q, k, v, gate, state)
996
+
997
+ return output
998
+
999
+ @staticmethod
1000
+ def chunk_delta_product_forward(
1001
+ query: torch.Tensor,
1002
+ key: torch.Tensor,
1003
+ value: torch.Tensor,
1004
+ beta_gate: torch.Tensor,
1005
+ chunk_size: int,
1006
+ n: int = 1,
1007
+ trick: str = "derivative",
1008
+ linear: bool = True,
1009
+ initial_state: Optional[torch.Tensor] = None,
1010
+ use_checkpoint: bool = True,
1011
+ linear_precision: torch.dtype = torch.float32,
1012
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1013
+ """
1014
+ Chunkwise parallel implementation https://arxiv.org/abs/2406.06484
1015
+ For each chunk, processes chunk_size * n_orders steps (virtual tokens) in order.
1016
+ """
1017
+
1018
+ # --- Main chunk_delta_product_forward logic ---
1019
+
1020
+ batch_size, num_heads, seq_len, head_dim = query.shape
1021
+ chunk_size = get_valid_chunk_size(seq_len, chunk_size)
1022
+ num_chunks = seq_len // chunk_size
1023
+
1024
+ query_n = query if n == 1 else expand_virtual_tokens(query, n, trick)
1025
+ key_n = key if n == 1 else expand_virtual_tokens(key, n, trick)
1026
+ value_n = value if n == 1 else expand_virtual_tokens(value, n, trick)
1027
+ beta_n = beta_gate if n == 1 else expand_virtual_tokens(beta_gate, n, trick)
1028
+
1029
+ q_chunks = chunk_sequence(query_n, num_chunks, chunk_size * n)
1030
+ k_chunks = chunk_sequence(key_n, num_chunks, chunk_size * n)
1031
+ v_chunks = chunk_sequence(value_n, num_chunks, chunk_size * n)
1032
+ beta_chunks = chunk_sequence(beta_n, num_chunks, chunk_size * n)
1033
+
1034
+ k_beta = k_chunks * beta_chunks
1035
+ v_beta = v_chunks * beta_chunks
1036
+
1037
+ householder = -(k_beta @ k_chunks.transpose(-2, -1)).tril(-1)
1038
+ householder = ensure_stability(householder, min_val=-1e4, max_val=1e4)
1039
+
1040
+ # size : N = chunk_size * n
1041
+ inv_hh = fast_invert_matrix(householder, dtype=linear_precision) # [(...),N,N]
1042
+
1043
+ w = ensure_stability(torch.matmul(inv_hh, k_beta), min_val=-1e4, max_val=1e4)
1044
+ u = ensure_stability(torch.matmul(inv_hh, v_beta), min_val=-1e4, max_val=1e4)
1045
+
1046
+ state_shape = (batch_size, num_heads, n, head_dim, head_dim)
1047
+ if initial_state is not None and initial_state.shape == state_shape:
1048
+ state = initial_state.to(device=query.device, dtype=linear_precision)
1049
+ else:
1050
+ state = torch.full(
1051
+ state_shape,
1052
+ fill_value=1e-6, # stability when the nonlinear activation is used
1053
+ device=query.device,
1054
+ dtype=linear_precision,
1055
+ )
1056
+
1057
+ output, final_state = sequential_delta_product_scan(
1058
+ q_chunks.to(dtype=linear_precision),
1059
+ w.to(dtype=linear_precision),
1060
+ u.to(dtype=linear_precision),
1061
+ n,
1062
+ linear,
1063
+ chunk_size,
1064
+ state.to(dtype=linear_precision),
1065
+ linear_precision=linear_precision,
1066
+ use_checkpoint=use_checkpoint,
1067
+ )
1068
+
1069
+ idx_last_order = torch.arange(chunk_size, device=output.device) * n + (n - 1)
1070
+ output = output[:, :, :, idx_last_order, :] # [B, H, num_chunks, chunk_size, D]
1071
+ output = output.reshape(batch_size, num_heads, seq_len, head_dim)
1072
+
1073
+ return output.to(dtype=linear_precision), final_state.to(dtype=linear_precision)
1074
+
1075
+
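For intuition, a purely sequential reference of the gated delta-rule recurrence that the chunkwise kernel above parallelizes. This is a simplified sketch with an elementwise gate, not the exact kernel:

```python
import torch

def delta_rule_naive(q, k, v, beta):
    """S_t = S_{t-1} + (beta_t * k_t) outer (v_t - k_t^T S_{t-1}); o_t = q_t^T S_t."""
    bsz, heads, seq, dim = q.shape
    state = torch.zeros(bsz, heads, dim, dim)
    outs = []
    for t in range(seq):
        q_t, k_t, v_t, b_t = q[:, :, t], k[:, :, t], v[:, :, t], beta[:, :, t]
        pred = torch.einsum("bhd,bhde->bhe", k_t, state)       # k_t^T S_{t-1}
        state = state + torch.einsum("bhd,bhe->bhde", b_t * k_t, v_t - pred)
        outs.append(torch.einsum("bhd,bhde->bhe", q_t, state))  # q_t^T S_t
    return torch.stack(outs, dim=2), state

q = k = v = torch.randn(1, 2, 8, 4)
beta = torch.rand(1, 2, 8, 4)
out, state = delta_rule_naive(q, k, v, beta)
print(out.shape, state.shape)  # torch.Size([1, 2, 8, 4]) torch.Size([1, 2, 4, 4])
```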
1076
+ def sequential_delta_product_scan(
1077
+ q_chunks: torch.Tensor,
1078
+ w: torch.Tensor,
1079
+ u: torch.Tensor,
1080
+ n_orders: int,
1081
+ linear_activation: bool,
1082
+ current_chunk_size: int,
1083
+ initial_recurrent_state: torch.Tensor,
1084
+ linear_precision: torch.dtype,
1085
+ use_checkpoint: bool,
1086
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1087
+ """
1088
+ DeltaProduct implementation https://arxiv.org/abs/2502.10297
1089
+ Implements the per-token Householder state updates.
1090
+ """
1091
+ batch, head, num_chunks_inner, chunk_n_total, dim = q_chunks.shape
1092
+ output_inner = torch.empty_like(q_chunks)
1093
+ # initial_recurrent_state is H_{last_token_of_prev_chunk, n-1} ([B, H, D, D])
1094
+ h_0_base = initial_recurrent_state[:, :, -1, :, :].clone()
1095
+
1096
+ def process_one_chunk(
1097
+ q_chunk_params: torch.Tensor,
1098
+ w_chunk_params: torch.Tensor,
1099
+ u_chunk_params: torch.Tensor,
1100
+ h_0_base: torch.Tensor,
1101
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1102
+ """
1103
+ Process a single chunk (with per-token state for n_orders > 1).
1104
+ """
1105
+ o_intra_current_chunk = torch.zeros(
1106
+ batch,
1107
+ head,
1108
+ chunk_n_total,
1109
+ dim,
1110
+ device=q_chunk_params.device,
1111
+ dtype=linear_precision,
1112
+ )
1113
+ o_inter_current_chunk = torch.zeros_like(o_intra_current_chunk)
1114
+ current_accumulated_state_per_token = (
1115
+ h_0_base.unsqueeze(2).expand(-1, -1, current_chunk_size, -1, -1).clone()
1116
+ ) # [B, H, current_chunk_size, D, D]
1117
+
1118
+ for step in range(n_orders):
1119
+ idx_virtual_tokens = (
1120
+ torch.arange(current_chunk_size, device=q_chunk_params.device)
1121
+ * n_orders
1122
+ + step
1123
+ )
1124
+ q_s = q_chunk_params[:, :, idx_virtual_tokens, :]
1125
+ w_s = w_chunk_params[:, :, idx_virtual_tokens, :]
1126
+ u_s = u_chunk_params[:, :, idx_virtual_tokens, :]
1127
+
1128
+ state_input_for_this_step = current_accumulated_state_per_token
1129
+
1130
+ ## BLAS/cuBLAS einsum "bhcd,bhcdd->bhcd"
1131
+ k_trans_h_old = (
1132
+ torch.matmul(
1133
+ w_s.unsqueeze(-2),
1134
+ state_input_for_this_step,
1135
+ )
1136
+ .squeeze(-2)
1137
+ .to(dtype=linear_precision)
1138
+ )
1139
+
1140
+ u_val = u_s - k_trans_h_old
1141
+
1142
+ o_inter_current_chunk[:, :, idx_virtual_tokens, :] = (
1143
+ torch.matmul(q_s.unsqueeze(-2), state_input_for_this_step)
1144
+ .squeeze(-2)
1145
+ .to(dtype=linear_precision)
1146
+ )
1147
+
1148
+ ## BLAS/cuBLAS einsum "bhcd,bhcd->bhcd"
1149
+ o_intra_current_chunk[:, :, idx_virtual_tokens, :] = (q_s * u_val).to(
1150
+ dtype=linear_precision
1151
+ )
1152
+
1153
+ outer_product_term = torch.matmul(w_s.unsqueeze(-1), u_val.unsqueeze(-2))
1154
+ new_state_i_per_token = state_input_for_this_step + outer_product_term
1155
+ new_state_i_per_token = ensure_stability(
1156
+ new_state_i_per_token, min_val=-1e4, max_val=1e4
1157
+ )
1158
+ current_accumulated_state_per_token = new_state_i_per_token.to(
1159
+ dtype=linear_precision
1160
+ )
1161
+ # Return all needed for next chunk
1162
+ return (
1163
+ o_intra_current_chunk,
1164
+ o_inter_current_chunk,
1165
+ current_accumulated_state_per_token[:, :, -1, :, :], # new h_0_base
1166
+ )
1167
+
1168
+ for chunk_idx_inner in range(num_chunks_inner):
1169
+ q_chunk_params = q_chunks[:, :, chunk_idx_inner]
1170
+ w_chunk_params = w[:, :, chunk_idx_inner]
1171
+ u_chunk_params = u[:, :, chunk_idx_inner]
1172
+
1173
+ # Checkpointed call if training
1174
+ call = (
1175
+ partial(checkpoint, use_reentrant=False)
1176
+ if use_checkpoint
1177
+ else lambda f, *a: f(*a)
1178
+ )
1179
+ o_intra, o_inter, h_0_base = call(
1180
+ process_one_chunk,
1181
+ q_chunk_params,
1182
+ w_chunk_params,
1183
+ u_chunk_params,
1184
+ h_0_base,
1185
+ )
1186
+ if not linear_activation: # nonlinear activation between chunks
1187
+ h_0_base = unlinear_activation(h_0_base).to(dtype=linear_precision)
1188
+ output_inner[:, :, chunk_idx_inner] = o_intra + o_inter
1189
+
1190
+ return output_inner, h_0_base
1191
+
1192
+
1193
+ def unlinear_activation(x: torch.Tensor, scale: float = 2.0) -> torch.Tensor:
1194
+ """Unlinear activation between chunk"""
1195
+ x_n = x.norm(p=2, dim=-1, keepdim=True) + 1e-6
1196
+ x_gelu = F.gelu(scale * x / x_n, approximate="tanh") # pylint: disable=not-callable
1197
+ return (x / scale) * x_gelu
1198
+
1199
+
1200
+ def chunk_sequence(x: torch.Tensor, num_chunks: int, chunk_size: int) -> torch.Tensor:
1201
+ """Splits [B, H, S, D] to [B, H, num_chunks, chunk_size, D]"""
1202
+ batch_size, num_heads, _, head_dim = x.shape
1203
+ return x.reshape(batch_size, num_heads, num_chunks, chunk_size, head_dim)
1204
+
1205
+
1206
+ def expand_virtual_tokens(
1207
+ x: torch.Tensor, n: int, mode: str = "derivative"
1208
+ ) -> torch.Tensor:
1209
+ """Expand tokens into 'n' virtual tokens using the selected trick."""
1210
+ batch_size, num_heads, seq_len, head_dim = x.shape
1211
+ device, dtype = x.device, x.dtype
1212
+
1213
+ def derivative_expand(x: torch.Tensor) -> torch.Tensor:
1214
+ """Expand tokens using the derivative trick."""
1215
+ x_pad = torch.cat(
1216
+ [
1217
+ torch.zeros(
1218
+ batch_size, num_heads, n - 1, head_dim, device=device, dtype=dtype
1219
+ ),
1220
+ x,
1221
+ ],
1222
+ dim=2,
1223
+ )
1224
+ coeffs = torch.tensor(
1225
+ [(-1) ** k * math.comb(n - 1, k) for k in range(n)],
1226
+ device=device,
1227
+ dtype=dtype,
1228
+ )
1229
+ coeffs /= coeffs.norm(p=1)
1230
+ return (
1231
+ (x_pad.unfold(2, n, 1) * coeffs.view(1, 1, 1, 1, n))
1232
+ .flip(-1)
1233
+ .permute(0, 1, 2, 4, 3)
1234
+ .reshape(batch_size, num_heads, seq_len * n, head_dim)
1235
+ )
1236
+
1237
+ def rotative_expand(x: torch.Tensor) -> torch.Tensor:
1238
+ """Expand tokens using the rotative trick."""
1239
+ d_parity = head_dim // 2
1240
+ angles = torch.arange(n, device=device, dtype=dtype) * (2 * math.pi / n)
1241
+ cos = torch.cos(angles).view(1, 1, 1, n, 1)
1242
+ sin = torch.sin(angles).view(1, 1, 1, n, 1)
1243
+ if head_dim % 2:
1244
+ x_pairs = x[..., :-1].view(batch_size, num_heads, seq_len, d_parity, 2)
1245
+ else:
1246
+ x_pairs = x.view(batch_size, num_heads, seq_len, d_parity, 2)
1247
+ x_pairs = x_pairs.unsqueeze(3).expand(
1248
+ batch_size, num_heads, seq_len, n, d_parity, 2
1249
+ )
1250
+ x0, x1 = x_pairs[..., 0], x_pairs[..., 1]
1251
+ x0r = x0 * cos - x1 * sin
1252
+ x1r = x0 * sin + x1 * cos
1253
+ rot = torch.stack([x0r, x1r], -1).reshape(
1254
+ batch_size, num_heads, seq_len, n, d_parity * 2
1255
+ )
1256
+ if head_dim % 2:
1257
+ last = (
1258
+ x[..., -1]
1259
+ .unsqueeze(-1)
1260
+ .unsqueeze(3)
1261
+ .expand(batch_size, num_heads, seq_len, n, 1)
1262
+ )
1263
+ rot = torch.cat([rot, last], -1)
1264
+ return rot.reshape(batch_size, num_heads, seq_len * n, head_dim)
1265
+
1266
+ if mode == "derivative":
1267
+ return derivative_expand(x)
1268
+ if mode == "rotative":
1269
+ return rotative_expand(x)
1270
+ if mode == "combined":
1271
+ return (derivative_expand(x) + rotative_expand(x)) / 2
1272
+ raise ValueError(f"Unknown mode: {mode}")
1273
+
1274
+
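The "derivative" trick above weights shifted copies of the sequence with alternating, L1-normalized binomial coefficients; a quick check of those weights for n = 3 virtual tokens:

```python
import math
import torch

n = 3  # number of virtual tokens per real token
coeffs = torch.tensor(
    [(-1) ** k * math.comb(n - 1, k) for k in range(n)], dtype=torch.float32
)
coeffs /= coeffs.norm(p=1)
print(coeffs)  # tensor([ 0.2500, -0.5000,  0.2500])
```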
1275
+ def extract_layer_idx(module_name: str) -> int:
1276
+ """Extract the layer index from a module name string."""
1277
+ match = re.search(r"\.(\d+)\.", module_name)
1278
+ if match:
1279
+ return int(match.group(1))
1280
+ return -1
1281
+
1282
+
1283
+ def find_embedding_lm(module: nn.Module) -> Optional[nn.Module]:
1284
+ """Find the embedding weight in a model module."""
1285
+ for _, child in module.named_modules():
1286
+ if hasattr(child, "embed_tokens") and hasattr(child.embed_tokens, "weight"):
1287
+ return child.embed_tokens
1288
+ if hasattr(child, "token_embeddings") and hasattr(
1289
+ child.token_embeddings, "weight"
1290
+ ):
1291
+ return child.token_embeddings
1292
+ return None
1293
+
1294
+
1295
+ def set_trainable_parameters(
1296
+ model: PreTrainedModel, trainable_patterns: List[str] = None
1297
+ ) -> PreTrainedModel:
1298
+ """Freeze model parameters except trainable_patterns."""
1299
+ if trainable_patterns is None:
1300
+ trainable_patterns = [
1301
+ "q_proj",
1302
+ "k_proj",
1303
+ "v_proj",
1304
+ "o_proj",
1305
+ "qkv_proj",
1306
+ "out_proj",
1307
+ "c_attn",
1308
+ "c_proj",
1309
+ "query",
1310
+ "key",
1311
+ "value",
1312
+ ]
1313
+
1314
+ for name, param in model.named_parameters():
1315
+ param.requires_grad = any(pattern in name for pattern in trainable_patterns)
1316
+
1317
+ trainable_layers = [n for n, p in model.named_parameters() if p.requires_grad]
1318
+ logger.info("Trainable parameters after freeze: %s", trainable_layers)
1319
+ return model
1320
+
1321
+
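A toy run of the pattern-based freezing above: only parameters whose names match a projection pattern remain trainable.

```python
import torch.nn as nn

model = nn.ModuleDict({"q_proj": nn.Linear(4, 4), "mlp": nn.Linear(4, 4)})
patterns = ["q_proj", "k_proj", "v_proj", "o_proj"]
for name, param in model.named_parameters():
    param.requires_grad = any(p in name for p in patterns)

print([n for n, p in model.named_parameters() if p.requires_grad])
# ['q_proj.weight', 'q_proj.bias']
```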
1322
+ def ensure_stability(
1323
+ tensor: torch.Tensor, min_val: float = -1e4, max_val: float = 1e4
1324
+ ) -> torch.Tensor:
1325
+ """stability forcing"""
1326
+ dtype = tensor.dtype
1327
+ center = (max_val + min_val) / 2
1328
+ tensor = torch.clamp(tensor, min=min_val, max=max_val)
1329
+ tensor = torch.nan_to_num(tensor, nan=center, posinf=max_val, neginf=min_val)
1330
+ return tensor.to(dtype=dtype)
1331
+
1332
+
1333
+ def apply_linear_attention_mask(
1334
+ attention_mask: torch.Tensor, v: torch.Tensor, padding_side: str = "right"
1335
+ ) -> torch.Tensor:
1336
+ """Extract if padding --> [B,S]"""
1337
+ if attention_mask.dim() == 4 and attention_mask.shape[1] == 1:
1338
+ mask = attention_mask.diagonal(dim1=-2, dim2=-1).squeeze(1)
1339
+ else:
1340
+ mask = attention_mask.squeeze(
1341
+ dim=tuple(
1342
+ i
1343
+ for i in range(1, attention_mask.dim())
1344
+ if attention_mask.shape[i] == 1
1345
+ )
1346
+ )
1347
+ # Ensure cast to the same dtype as v and convert to binary mask
1348
+ if not (
1349
+ mask.dtype == torch.bool
1350
+ or (
1351
+ mask.dtype in [torch.uint8, torch.int32, torch.int64]
1352
+ and mask.max() <= 1
1353
+ and mask.min() >= 0
1354
+ )
1355
+ ):
1356
+ mask = (mask >= 0).to(v.dtype) # [-inf, 0, 0, -inf] --> [0, 1, 1, 0]
1357
+ else:
1358
+ mask = mask.to(v.dtype)
1359
+ # mask is [batch, seq] --> Broadcast to v [batch, seq, (...)]
1360
+ if padding_side == "left":
1361
+ mask = mask[:, -v.shape[-2] :][(...,) + (None,) * (v.dim() - 2)]
1362
+ else: # right padding
1363
+ mask = mask[:, : v.shape[-2]][(...,) + (None,) * (v.dim() - 2)]
1364
+ return v * mask
1365
+
1366
+
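A small sketch of the right-padding case handled above: a `[B, S]` mask broadcast over the feature dimension zeroes out the padded value vectors.

```python
import torch

v = torch.ones(2, 5, 3)  # [batch, seq, features]
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]], dtype=v.dtype)  # [batch, seq], right padding
masked = v * mask[:, : v.shape[-2]][..., None]  # broadcast over the feature dim
print(masked[0, :, 0])  # tensor([1., 1., 1., 0., 0.])
```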
1367
+ def truncate_attention_mask(
1368
+ hidden_states: torch.Tensor, attention_mask: torch.Tensor, max_length: int
1369
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1370
+ """Truncate hidden_states and attention_mask to the last window of size max_length"""
1371
+ seq_dim = 1 # convention: (batch, seq, ...)
1372
+ seq_len = hidden_states.shape[seq_dim]
1373
+ if seq_len > max_length:
1374
+ hidden_states = hidden_states.narrow(seq_dim, seq_len - max_length, max_length)
1375
+ if attention_mask is not None:
1376
+ # mask [batch, seq]
1377
+ if attention_mask.dim() == 2:
1378
+ attention_mask = attention_mask[:, -max_length:]
1379
+ # mask [batch, seq, seq]
1380
+ elif attention_mask.dim() == 3:
1381
+ attention_mask = attention_mask[:, -max_length:, -max_length:]
1382
+ # mask [batch, 1, seq, seq]
1383
+ elif attention_mask.dim() == 4 and attention_mask.shape[1] == 1:
1384
+ attention_mask = attention_mask[:, :, -max_length:, -max_length:]
1385
+ else:
1386
+ raise ValueError(
1387
+ "No dimension in attention_mask matches sequence length of hidden_states."
1388
+ )
1389
+ return hidden_states, attention_mask
1390
+
1391
+
1392
+ def fast_invert_matrix(
1393
+ tri_tensor: torch.Tensor, dtype: torch.dtype = torch.float32
1394
+ ) -> torch.Tensor:
1395
+ """Equivalent to vectorized forward substitution applied to the identity matrix."""
1396
+ tri_tensor = tri_tensor.to(dtype=dtype).clone()
1397
+ chunk_size = tri_tensor.shape[-1]
1398
+
1399
+ for i in range(1, chunk_size):
1400
+ tri_tensor[..., i, :i] = tri_tensor[..., i, :i] + (
1401
+ tri_tensor[..., i, :, None].clone() * tri_tensor[..., :, :i].clone()
1402
+ ).sum(-2)
1403
+
1404
+ tri_tensor = tri_tensor + torch.eye(
1405
+ chunk_size, dtype=dtype, device=tri_tensor.device
1406
+ )
1407
+ return tri_tensor.to(dtype=dtype)
1408
+
1409
+
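A self-contained copy of the inversion routine above, checked against its algebraic meaning: for a strictly lower-triangular input A, the forward substitution returns (I - A)^-1.

```python
import torch

def fast_invert_matrix(tri: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    tri = tri.to(dtype).clone()
    n = tri.shape[-1]
    for i in range(1, n):
        tri[..., i, :i] = tri[..., i, :i] + (
            tri[..., i, :, None].clone() * tri[..., :, :i].clone()
        ).sum(-2)
    return tri + torch.eye(n, dtype=dtype)

A = torch.randn(4, 4).tril(-1)  # strictly lower-triangular
inv = fast_invert_matrix(A)
print(torch.allclose(inv @ (torch.eye(4) - A), torch.eye(4), atol=1e-4))  # True
```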
1410
+ def get_valid_chunk_size(total_l: int, chunk_size: int) -> int:
1411
+ """Return the largest chunk_size <= chunk_size that divides total_l."""
1412
+ for c in range(min(chunk_size, total_l), 0, -1):
1413
+ if total_l % c == 0:
1414
+ return c
1415
+ return 1
1416
+
1417
+
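A quick sanity check of the chunk-size fallback above: the largest divisor of the sequence length that does not exceed the requested chunk size.

```python
def get_valid_chunk_size(total_l: int, chunk_size: int) -> int:
    for c in range(min(chunk_size, total_l), 0, -1):
        if total_l % c == 0:
            return c
    return 1

print(get_valid_chunk_size(96, 64))   # 48
print(get_valid_chunk_size(100, 64))  # 50
print(get_valid_chunk_size(97, 64))   # 1 (prime length falls back to 1)
```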
1418
+ ## RARELY
1419
+ def split_qkv(
1420
+ base_attn: nn.Module, qkv: torch.Tensor
1421
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1422
+ """Split the QKV tensor into separate Q, K, and V tensors."""
1423
+ num_q_heads = getattr(base_attn, "num_q_heads", None)
1424
+ num_k_heads = getattr(base_attn, "num_k_heads", None)
1425
+ num_v_heads = getattr(base_attn, "num_v_heads", None)
1426
+ head_dim = getattr(base_attn, "head_dim", None)
1427
+
1428
+ if num_q_heads is None or num_k_heads is None or num_v_heads is None:
1429
+ raise ValueError(
1430
+ "Base attention must have num_q_heads, num_k_heads, and num_v_heads defined."
1431
+ )
1432
+
1433
+ q_len = num_q_heads * head_dim
1434
+ k_len = num_k_heads * head_dim
1435
+ v_len = num_v_heads * head_dim
1436
+
1437
+ q, k, v = torch.split(qkv, [q_len, k_len, v_len], dim=-1)
1438
+ return q, k, v
1439
+
1440
+
1441
+ ## OPTIONAL
1442
+ def match_dim(x: torch.Tensor, dim: int, target_size: int) -> torch.Tensor:
1443
+ """Match the size of tensor x along dimension dim to target_size by interpolation"""
1444
+ src_size = x.shape[dim]
1445
+ if src_size == target_size:
1446
+ return x
1447
+ x = torch.moveaxis(x, dim, -1)
1448
+ shape = x.shape
1449
+ if src_size < target_size:
1450
+ x = x.reshape(-1, 1, src_size)
1451
+ x = F.interpolate(x, size=target_size, mode="linear", align_corners=False)
1452
+ x = x.reshape(*shape[:-1], target_size)
1453
+ else:
1454
+ eye = torch.eye(target_size, src_size, device=x.device, dtype=x.dtype)
1455
+ x = F.linear(x, eye) # pylint: disable=not-callable
1456
+ x = torch.moveaxis(x, -1, dim)
1457
+ return x
1458
+
1459
+
1460
+ def soft_clamp(
1461
+ x: torch.Tensor, min_val: float = 1e-6, max_val: float = 1 - 1e-6
1462
+ ) -> torch.Tensor:
1463
+ """Differentiable clamping for stability"""
1464
+ dtype = x.dtype
1465
+ scale = (max_val - min_val) / 2
1466
+ center = (max_val + min_val) / 2
1467
+ return (torch.tanh((x - center) / scale) * scale + center).to(dtype=dtype)
1468
+
1469
+
1470
+ def describe(x: torch.Tensor, name="tensor") -> None:
1471
+ """Prints the shape, min, max, mean, and std of a tensor."""
1472
+ stats = (x.min(), x.max(), x.mean(), x.std())
1473
+ print(
1474
+ f"{name} shape: {tuple(x.shape)}, "
1475
+ + f"min: {stats[0]:.4g}, max: {stats[1]:.4g}, "
1476
+ + f"mean: {stats[2]:.4g}, std: {stats[3]:.4g}, "
1477
+ + f"dtype: {x.dtype}, device: {x.device}"
1478
+ )
lora_delta_product_m0.5_gradual_t10/runs/Aug12_18-12-23_aac70857b6d3/events.out.tfevents.1755022349.aac70857b6d3.35.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab0dc0ccba8e545ecf31b507dd7891135557a687fcea16ba612d50adc174e6ff
3
+ size 12411
lora_delta_product_m0.5_gradual_t10/special_tokens_map.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "eos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "pad_token": {
10
+ "content": "<|padding|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ }
16
+ }
lora_delta_product_m0.5_gradual_t10/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
lora_delta_product_m0.5_gradual_t10/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "|||IP_ADDRESS|||",
8
+ "lstrip": false,
9
+ "normalized": true,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": false
13
+ },
14
+ "1": {
15
+ "content": "<|padding|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "50254": {
23
+ "content": " ",
24
+ "lstrip": false,
25
+ "normalized": true,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": false
29
+ },
30
+ "50255": {
31
+ "content": " ",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": false
37
+ },
38
+ "50256": {
39
+ "content": " ",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "50257": {
47
+ "content": " ",
48
+ "lstrip": false,
49
+ "normalized": true,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "50258": {
55
+ "content": " ",
56
+ "lstrip": false,
57
+ "normalized": true,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "50259": {
63
+ "content": " ",
64
+ "lstrip": false,
65
+ "normalized": true,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "50260": {
71
+ "content": " ",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "50261": {
79
+ "content": " ",
80
+ "lstrip": false,
81
+ "normalized": true,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "50262": {
87
+ "content": " ",
88
+ "lstrip": false,
89
+ "normalized": true,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "50263": {
95
+ "content": " ",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "50264": {
103
+ "content": " ",
104
+ "lstrip": false,
105
+ "normalized": true,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": false
109
+ },
110
+ "50265": {
111
+ "content": " ",
112
+ "lstrip": false,
113
+ "normalized": true,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "50266": {
119
+ "content": " ",
120
+ "lstrip": false,
121
+ "normalized": true,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "50267": {
127
+ "content": " ",
128
+ "lstrip": false,
129
+ "normalized": true,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "50268": {
135
+ "content": " ",
136
+ "lstrip": false,
137
+ "normalized": true,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "50269": {
143
+ "content": " ",
144
+ "lstrip": false,
145
+ "normalized": true,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "50270": {
151
+ "content": " ",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "50271": {
159
+ "content": " ",
160
+ "lstrip": false,
161
+ "normalized": true,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "50272": {
167
+ "content": " ",
168
+ "lstrip": false,
169
+ "normalized": true,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "50273": {
175
+ "content": " ",
176
+ "lstrip": false,
177
+ "normalized": true,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "50274": {
183
+ "content": " ",
184
+ "lstrip": false,
185
+ "normalized": true,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "50275": {
191
+ "content": " ",
192
+ "lstrip": false,
193
+ "normalized": true,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "50276": {
199
+ "content": " ",
200
+ "lstrip": false,
201
+ "normalized": true,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ },
206
+ "50277": {
207
+ "content": "|||EMAIL_ADDRESS|||",
208
+ "lstrip": false,
209
+ "normalized": true,
210
+ "rstrip": false,
211
+ "single_word": false,
212
+ "special": false
213
+ },
214
+ "50278": {
215
+ "content": "|||PHONE_NUMBER|||",
216
+ "lstrip": false,
217
+ "normalized": true,
218
+ "rstrip": false,
219
+ "single_word": false,
220
+ "special": false
221
+ },
222
+ "50279": {
223
+ "content": "<|endoftext|>",
224
+ "lstrip": false,
225
+ "normalized": false,
226
+ "rstrip": false,
227
+ "single_word": false,
228
+ "special": true
229
+ }
230
+ },
231
+ "bos_token": null,
232
+ "clean_up_tokenization_spaces": true,
233
+ "eos_token": "<|endoftext|>",
234
+ "extra_special_tokens": {},
235
+ "model_max_length": 1000000000000000019884624838656,
236
+ "pad_token": "<|padding|>",
237
+ "tokenizer_class": "GPTNeoXTokenizer",
238
+ "unk_token": null
239
+ }
modeling_tptt.py ADDED
@@ -0,0 +1,1478 @@
1
+ # pylint: disable=too-many-lines, too-many-arguments, too-many-positional-arguments, too-many-instance-attributes, too-many-locals
2
+
3
+ """
4
+ This module implements the TPTT model with linear attention (LiZA) and LoRA support.
5
+ Author : Fabien FURFARO
6
+ TPTT : Transforming Pretrained Transformers into Titans (https://arxiv.org/abs/2506.17671)
7
+ """
8
+
9
+ import logging
10
+ import math
11
+ import os
12
+ from pathlib import Path
13
+ import re
14
+ import shutil
15
+ from functools import partial
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from einops import rearrange
21
+ from huggingface_hub import hf_hub_download, list_repo_files
22
+ from peft import LoraConfig, PeftModel, get_peft_model
23
+ from safetensors import safe_open
24
+ from safetensors.torch import save_file
25
+ from torch import nn
26
+ from torch.utils.checkpoint import checkpoint
27
+ from transformers import AutoConfig, AutoModelForCausalLM, DynamicCache, PreTrainedModel
28
+ from transformers.configuration_utils import PretrainedConfig
29
+
30
+ from .configuration_tptt import TpttConfig
31
+
32
+ logger = logging.getLogger(__name__) # monitoring
33
+
34
+
35
+ class LCache:
36
+ """Cache for storing intermediate states of linear attention layers."""
37
+
38
+ def __init__(self):
39
+ """Stores per-layer intermediate states: {layer_idx: state_dict}"""
40
+ self.inputs_states: Dict[int, Dict[str, torch.Tensor]] = (
41
+ {}
42
+ ) # recurrent states and qkv buffers
43
+
44
+ def __getitem__(self, layer_idx: int) -> Optional[Dict[str, torch.Tensor]]:
45
+ """Retrieve cached state for a given layer, or None if not present"""
46
+ return self.inputs_states.get(layer_idx, None)
47
+
48
+ def update(self, layer_idx: int, **kwargs):
49
+ """Detach all tensors to avoid retaining computation graphs"""
50
+ detached_kwargs = {
51
+ k: v.detach() if isinstance(v, torch.Tensor) else v
52
+ for k, v in kwargs.items()
53
+ }
54
+ # Update or create the state for the specified layer
55
+ if layer_idx in self.inputs_states:
56
+ self.inputs_states[layer_idx].update(detached_kwargs)
57
+ else:
58
+ self.inputs_states[layer_idx] = detached_kwargs
59
+
60
+ def reset(self):
61
+ """Clear all cached states and reset the token counter"""
62
+ self.inputs_states.clear()
63
+
64
+
65
+ class CausalAvgPool1d(nn.Module):
66
+ """Causal sliding window average (uniform, no shape loss along sequence)"""
67
+
68
+ def __init__(
69
+ self, output_size: int, offsets: tuple[int] = (0, 1, 2), mode: str = "replicate"
70
+ ):
71
+ super().__init__()
72
+ self.offsets = offsets
73
+ self.mode = mode
74
+ self.pool = nn.AdaptiveAvgPool1d(output_size=output_size)
75
+
76
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
77
+ """x: [B, S, F] → [B, S, F → output_size]"""
78
+ x_ = x.transpose(1, 2) # [B, F, S]
79
+ idxs = torch.tensor(self.offsets, device=x.device)
80
+ ksize = idxs.max() - idxs.min() + 1
81
+ w = torch.zeros(ksize, device=x.device, dtype=x.dtype)
82
+ w[idxs - idxs.min()] = 1 / len(self.offsets) # Always uniform weights
83
+ kernel = w.repeat(x_.shape[1], 1).reshape(x_.shape[1], 1, ksize)
84
+ pad_left = -idxs.min().item()
85
+ pad_right = (ksize - 1) - pad_left
86
+ x_pad = F.pad(x_, (pad_left, pad_right), mode=self.mode)
87
+ y = F.conv1d(x_pad, kernel, groups=x_.shape[1]) # pylint: disable=not-callable
88
+ return self.pool(y.transpose(1, 2)) # [B, S, F → output_size]
89
+
90
+
91
+ class LinearAttention(nn.Module):
92
+ """
93
+ Linear multi-head attention layer: [B, S, D] -> [B, S, D]
94
+ Projections + gating + efficient linear attention mechanism (TPTT compatible).
95
+ """
96
+
97
+ def __init__(
98
+ self,
99
+ hidden_dim: int,
100
+ num_heads: int,
101
+ head_dim: Optional[int] = None,
102
+ num_key_value_heads: Optional[int] = None,
103
+ num_key_value_groups: Optional[int] = None,
104
+ bias: bool = True,
105
+ dropout: Optional[float] = None,
106
+ linear_precision: torch.dtype = torch.float32,
107
+ padding_side: str = "right",
108
+ shared_attn: bool = False, # shared attention
109
+ layer_idx: int = 0,
110
+ operator_mode: str = "delta_rule",
111
+ recurrent_config: Optional[Dict[str, Any]] = None,
112
+ linear_cache: Optional[LCache] = None,
113
+ max_chunk_size: int = 64,
114
+ bidirectional: bool = False, # not used if causal
115
+ pooling_config: Optional[Dict[str, Any]] = None,
116
+ ):
117
+ super().__init__()
118
+ if pooling_config is None:
119
+ pooling_config = {
120
+ "offsets": (0, 1, 2),
121
+ "mode": "replicate",
122
+ }
123
+ self.hidden_dim = hidden_dim
124
+ self.num_heads = num_heads
125
+ self.head_dim = head_dim or hidden_dim // num_heads
126
+ self.num_key_value_heads = num_key_value_heads or num_heads
127
+ self.num_key_value_groups = num_key_value_groups or (
128
+ num_heads // (num_key_value_heads or num_heads)
129
+ )
130
+ self.scaling = self.head_dim**-0.5
131
+ self.linear_precision = linear_precision
132
+ self.padding_side = padding_side
133
+
134
+ self.shared_attn = shared_attn
135
+
136
+ if not shared_attn:
137
+ self.q_proj = nn.Linear(hidden_dim, num_heads * self.head_dim, bias=bias)
138
+ self.k_proj = nn.Linear(
139
+ hidden_dim, self.num_key_value_heads * self.head_dim, bias=bias
140
+ )
141
+ self.v_proj = nn.Linear(
142
+ hidden_dim, self.num_key_value_heads * self.head_dim, bias=bias
143
+ )
144
+ self.out_proj = nn.Linear(num_heads * self.head_dim, hidden_dim, bias=bias)
145
+
146
+ self.dropout = nn.Dropout(dropout) if dropout is not None else None
147
+
148
+ self.linear_operator = LinearAttentionOp(
149
+ layer_idx=layer_idx,
150
+ operator_mode=operator_mode,
151
+ recurrent_config=recurrent_config,
152
+ max_chunk_size=max_chunk_size,
153
+ linear_cache=linear_cache,
154
+ linear_precision=linear_precision,
155
+ )
156
+ self.bidirectional = bidirectional
157
+ # Causal average pooling for gating
158
+ self.pooling_config = pooling_config
159
+ self.pool_g = CausalAvgPool1d(
160
+ output_size=self.head_dim * self.num_key_value_heads, **pooling_config
161
+ )
162
+
163
+ def forward(
164
+ self,
165
+ x: Union[List[torch.Tensor], torch.Tensor],
166
+ attn_mask: Optional[torch.Tensor] = None,
167
+ out_proj: Optional[nn.Module] = None,
168
+ **kwargs: Any,
169
+ ) -> torch.Tensor:
170
+ """
171
+ Forward pass for linear attention. Input shape: [B, S, D], output [B, S, D].
172
+ """
173
+
174
+ if not self.shared_attn:
175
+ hidden_states = x[0] if isinstance(x, (list, tuple)) else x
176
+ # Projections
177
+ q = self.q_proj(hidden_states)
178
+ k = self.k_proj(hidden_states)
179
+ v = self.v_proj(hidden_states)
180
+ out_proj = self.out_proj
181
+ else:
182
+ # Shared attention <=> no projections here
183
+ q, k, v = x[0], x[1], x[2]
184
+ out_proj = self.out_proj if out_proj is None else out_proj
185
+
186
+ # get dtype and device
187
+ final_dtype, final_device = q.dtype, q.device
188
+ # Masking if needed
189
+ if attn_mask is not None:
190
+ v = apply_linear_attention_mask(attn_mask, v, self.padding_side)
191
+
192
+ # "Forget" and "write" gating for linear attention (loose terminology)
193
+ f_g, w_g = self.pool_g(k), self.pool_g(v)
194
+
195
+ # Reshape for multi-head
196
+ q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
197
+ k = rearrange(k, "b n (h d) -> b h n d", h=self.num_key_value_heads)
198
+ v = rearrange(v, "b n (h d) -> b h n d", h=self.num_key_value_heads)
199
+
200
+ f_g = rearrange(f_g, "b n (h m) -> b h n m", h=self.num_key_value_heads)
201
+ w_g = rearrange(w_g, "b n (h m) -> b h n m", h=self.num_key_value_heads)
202
+
203
+ # Repeat for GQA
204
+ k = k.repeat_interleave(self.num_key_value_groups, dim=1)
205
+ v = v.repeat_interleave(self.num_key_value_groups, dim=1)
206
+
207
+ f_g = f_g.repeat_interleave(self.num_key_value_groups, dim=1)
208
+ w_g = w_g.repeat_interleave(self.num_key_value_groups, dim=1)
209
+
210
+ ## DeltaNet-style: Silu activation and normalization
211
+ q = F.normalize(F.silu(q), p=2, dim=-1, eps=1e-6)
212
+ k = F.normalize(F.silu(k), p=2, dim=-1, eps=1e-6)
213
+
214
+ ## numerical stability for the linear path
215
+ v = ensure_stability(v * self.scaling, min_val=-1e4, max_val=1e4)
216
+
217
+ # Apply sigmoid to forget and write gates
218
+ f_g = torch.clamp(torch.sigmoid(f_g), min=1e-6, max=1 - 1e-6)
219
+ w_g = torch.clamp(torch.sigmoid(w_g), min=1e-6, max=1 - 1e-6)
220
+
221
+ # Convert to linear_precision (float32) for numerical stability and get model dtype
222
+ q, k, v, f_g, w_g = (
223
+ x.to(self.linear_precision).contiguous() for x in (q, k, v, f_g, w_g)
224
+ )
225
+ g = (f_g, w_g)
226
+
227
+ # Linear Attention Core, output: [B, H, S, d]
228
+ if self.bidirectional: # only valid for non-causal attention
229
+ # Forward direction
230
+ out_forward = self.linear_operator(q, k, v, g, **kwargs)
231
+ # Backward direction: flip the input sequence on the time dimension (dim=2)
232
+ kwargs_bwd = kwargs.copy()
233
+ kwargs_bwd["use_cache"] = False
234
+ out_backward = self.linear_operator(
235
+ torch.flip(q, dims=[2]),
236
+ torch.flip(k, dims=[2]),
237
+ torch.flip(v, dims=[2]),
238
+ tuple(torch.flip(t, dims=[2]) for t in g),
239
+ **kwargs_bwd,
240
+ )
241
+ # Flip the output back to restore proper order
242
+ out_backward = torch.flip(out_backward, dims=[2])
243
+ # Fusion: here, simple addition
244
+ out = out_forward + out_backward
245
+ else:
246
+ out = self.linear_operator(q, k, v, g, **kwargs)
247
+
248
+ # Merge heads and project: [B, H, S, d] -> [B, S, H*d] -> Out proj
249
+ out = rearrange(out, "b h s d -> b s (h d)")
250
+ # Normalize output (RMS norm). Note: bidirectional compatibility
251
+ out = out / out.pow(2).mean(dim=-1, keepdim=True).add(1e-6).sqrt()
252
+ # Ensure dtype and device consistency
253
+ out = out.to(dtype=final_dtype, device=final_device)
254
+ # Apply output projection
255
+ out = out_proj(out) # [B, S, D]
256
+ out = ensure_stability(out, min_val=-1e4, max_val=1e4)
257
+ # Apply dropout if specified
258
+ if self.dropout is not None:
259
+ out = self.dropout(out)
260
+ return out
261
+
262
+
263
+ class LiZAttention(nn.Module):
264
+ """LiZA Linear Attention module, mixing linear and vanilla attention."""
265
+
266
+ def __init__(
267
+ self,
268
+ base_attn: nn.Module,
269
+ layer_idx: int,
270
+ base_config: PretrainedConfig, # Backbone Config
271
+ linear_cache: Optional[LCache] = None,
272
+ operator_mode: str = "delta_rule",
273
+ recurrent_config: Optional[Dict[str, Any]] = None,
274
+ max_self_attn_length: Optional[int] = None, # unnecessary
275
+ base_scale_attn: bool = False,
276
+ mag_weight: float = 0.5,
277
+ cross_gate: bool = False,
278
+ max_chunk_size: int = 64,
279
+ linear_precision: Union[str, torch.dtype] = "float32",
280
+ padding_side: str = "right", # for tokenizer
281
+ disable_linear_attn: bool = False,
282
+ bidirectional: bool = False, # if True, use bidirectional attention
283
+ pooling_config: Optional[Dict[str, Any]] = None,
284
+ ):
285
+ super().__init__()
286
+ if isinstance(linear_precision, str):
287
+ linear_precision = getattr(torch, linear_precision)
288
+ self.linear_precision = linear_precision
289
+ self.base_attn: nn.Module = base_attn
290
+ self.base_config = base_config
291
+ self.layer_idx = layer_idx
292
+ self.max_self_attn_length = max_self_attn_length
293
+ self.base_scale_attn = base_scale_attn
294
+ self.mag_weight = mag_weight
295
+ self.cross_gate = cross_gate
296
+ self.max_chunk_size = max_chunk_size
297
+ self.linear_precision = linear_precision
298
+ self.padding_side = padding_side
299
+ self.disable_linear_attn = disable_linear_attn
300
+
301
+ (
302
+ self.num_heads,
303
+ self.head_dim,
304
+ self.num_key_value_heads,
305
+ self.num_key_value_groups,
306
+ ) = self._get_attention_parameters(base_attn, base_config)
307
+ self.scaling = self.head_dim**-0.5
308
+
309
+ self.linear_attn = LinearAttention(
310
+ layer_idx=layer_idx,
311
+ shared_attn=True,
312
+ operator_mode=operator_mode,
313
+ recurrent_config=recurrent_config,
314
+ hidden_dim=base_config.hidden_size,
315
+ num_heads=self.num_heads,
316
+ head_dim=self.head_dim,
317
+ num_key_value_heads=self.num_key_value_heads,
318
+ num_key_value_groups=self.num_key_value_groups,
319
+ linear_precision=linear_precision,
320
+ linear_cache=linear_cache,
321
+ max_chunk_size=max_chunk_size,
322
+ padding_side=padding_side,
323
+ bidirectional=bidirectional,
324
+ pooling_config=pooling_config,
325
+ )
326
+
327
+ def _get_attention_parameters(
328
+ self, base_attn: nn.Module, base_config: PretrainedConfig
329
+ ) -> Tuple[Optional[int], Optional[int], Optional[int], Optional[int]]:
330
+ """Retrieve the attention parameters from the base attention module."""
331
+ # prefer attributes from the attention module, then fall back to the config
332
+ num_heads = (
333
+ getattr(base_attn, "num_heads", None)
334
+ or getattr(base_attn, "num_q_heads", None)
335
+ or getattr(base_config, "num_heads", None)
336
+ or getattr(base_config, "num_attention_heads", None)
337
+ )
338
+ head_dim = (
339
+ getattr(base_attn, "head_dim", None)
340
+ or getattr(base_attn, "attention_head_size", None)
341
+ or getattr(base_config, "head_dim", None)
342
+ or (
343
+ getattr(base_config, "hidden_size", None) // num_heads
344
+ if num_heads and getattr(base_config, "hidden_size", None)
345
+ else None
346
+ )
347
+ )
348
+ num_key_value_heads = (
349
+ getattr(base_attn, "num_kv_heads", None)
350
+ or getattr(base_attn, "num_k_heads", None)
351
+ or getattr(base_config, "num_key_value_heads", None)
352
+ or num_heads # fallback
353
+ )
354
+ num_key_value_groups = getattr(base_attn, "num_key_value_groups", None) or (
355
+ num_heads // num_key_value_heads if num_heads and num_key_value_heads else 1
356
+ )
357
+ return (
358
+ num_heads,
359
+ head_dim,
360
+ num_key_value_heads,
361
+ num_key_value_groups,
362
+ )
363
+
364
+ def _apply_shared_projections(
365
+ self, hidden_states: torch.Tensor
366
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, nn.Module]:
367
+ base_attn = self.base_attn
368
+ if hasattr(base_attn, "q_proj"):
369
+ # Llama, OLMo and Mistral style
370
+ q = base_attn.q_proj(hidden_states)
371
+ k = base_attn.k_proj(hidden_states)
372
+ v = base_attn.v_proj(hidden_states)
373
+ out_proj = base_attn.o_proj
374
+ elif hasattr(base_attn, "qkv_proj"):
375
+ # OpenELM and GPT-Neo style: QKV fused, split on the last dimension
376
+ qkv = base_attn.qkv_proj(hidden_states)
377
+ q, k, v = split_qkv(base_attn, qkv)
378
+ out_proj = base_attn.out_proj
379
+ elif hasattr(base_attn, "c_attn") and hasattr(base_attn, "c_proj"):
380
+ # GPT-2 style
381
+ qkv = base_attn.c_attn(hidden_states)
382
+ q, k, v = qkv.chunk(3, dim=-1)
383
+ out_proj = base_attn.c_proj
384
+ elif all(hasattr(base_attn, n) for n in ["query", "key", "value"]):
385
+ # BERT - ViT
386
+ q = base_attn.query(hidden_states)
387
+ k = base_attn.key(hidden_states)
388
+ v = base_attn.value(hidden_states)
389
+ out_proj = getattr(base_attn, "dense", None) # or output.dense
390
+ else:
391
+ raise ValueError("Unsupported attention module: cannot find projections.")
392
+ # Ensure stability
393
+ q = ensure_stability(q, min_val=-1e4, max_val=1e4)
394
+ k = ensure_stability(k, min_val=-1e4, max_val=1e4)
395
+ v = ensure_stability(v, min_val=-1e4, max_val=1e4)
396
+ return q, k, v, out_proj
397
+
398
+ def _process_self_attn(
399
+ self,
400
+ hidden_states: torch.Tensor,
401
+ attention_mask: Optional[torch.Tensor],
402
+ kwargs,
403
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[DynamicCache], int]:
404
+ """Process the self-attention part (with truncation)."""
405
+ if self.max_self_attn_length: # Not needed for SWA (non-parametric context memorization)
406
+ hidden_states, attention_mask = truncate_attention_mask(
407
+ hidden_states, attention_mask, self.max_self_attn_length
408
+ )
409
+
410
+ if kwargs.get("position_embeddings", None) is not None:
411
+ cos, sin = kwargs["position_embeddings"]
412
+ cos = cos[:, -self.max_self_attn_length :]
413
+ sin = sin[:, -self.max_self_attn_length :]
414
+ kwargs["position_embeddings"] = (cos, sin)
415
+
416
+ if isinstance(kwargs.get("past_key_value", None), DynamicCache):
417
+ # cache management
418
+ if (
419
+ len(kwargs["past_key_value"]) > self.layer_idx
420
+ and self.layer_idx == 0
421
+ ):
422
+ kwargs["past_key_value"].crop(self.max_self_attn_length - 1)
423
+
424
+ # Standard attention (mask and rotation is applied inside)
425
+ base_attn_outputs = self.base_attn(
426
+ hidden_states,
427
+ attention_mask=attention_mask,
428
+ **kwargs,
429
+ )
430
+
431
+ if isinstance(base_attn_outputs, tuple):
432
+ if len(base_attn_outputs) == 3:
433
+ o_base, attn_weights, present_key_value = base_attn_outputs
434
+ expected_attn_mode = 3
435
+ elif len(base_attn_outputs) == 2:
436
+ o_base, attn_weights = base_attn_outputs
437
+ present_key_value, expected_attn_mode = None, 2
438
+ else:
439
+ raise ValueError(
440
+ f"Unexpected number of outputs from base_attn: {len(base_attn_outputs)}"
441
+ )
442
+ else:
443
+ o_base = base_attn_outputs
444
+ attn_weights, present_key_value, expected_attn_mode = None, None, 1
445
+ # Ensure stability
446
+ o_base = ensure_stability(o_base, min_val=-1e4, max_val=1e4)
447
+ return o_base, attn_weights, present_key_value, expected_attn_mode
448
+
449
+ def _prepare_attn_mixin(
450
+ self,
451
+ o_lin: torch.Tensor,
452
+ o_base: torch.Tensor,
453
+ tensor_dtype: torch.dtype,
454
+ eps: float = 1e-5,
455
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
456
+ """Prepare linear attn for mixing with self attn."""
457
+ # Force cast typing, shape : [b n (h d)]
458
+ o_lin = o_lin.to(tensor_dtype)
459
+ o_base = o_base.to(tensor_dtype)
460
+ # feature scaling
461
+ if self.base_scale_attn:
462
+ scaler = o_base.pow(2).mean(dim=-1, keepdim=True).add(eps).sqrt()
463
+ o_lin = scaler * o_lin
464
+ return o_lin, o_base
465
+
466
+ def _apply_mag(
467
+ self, linear_attention: torch.Tensor, softmax_attention: torch.Tensor
468
+ ) -> torch.Tensor:
469
+ """Apply the MAG strategy"""
470
+ # Left-Padding management
471
+ if linear_attention.shape[1] != softmax_attention.shape[1]:
472
+ left_trunc = min(linear_attention.shape[1], softmax_attention.shape[1])
473
+ linear_attention, softmax_attention = (
474
+ linear_attention[:, -left_trunc:],
475
+ softmax_attention[:, -left_trunc:],
476
+ )
477
+ # NAM: Neural Attention Mixer (with graph forcing)
478
+ mag_weight = torch.tensor(
479
+ self.mag_weight,
480
+ dtype=softmax_attention.dtype,
481
+ device=softmax_attention.device,
482
+ )
483
+ softmax_weighted = (1 - mag_weight) * softmax_attention
484
+ linear_weighted = mag_weight * linear_attention
485
+ if self.cross_gate:
486
+ output_attention = (
487
+ softmax_weighted + linear_weighted + softmax_weighted * linear_weighted
488
+ ) # cross-gated product (non-linear interaction)
489
+ else:
490
+ output_attention = softmax_weighted + linear_weighted # classic
491
+
492
+ if torch.allclose(softmax_weighted, output_attention):
493
+ logger.info(
494
+ "[LOG] layer : %s, softmax_weighted and output_attention are close.",
495
+ self.layer_idx,
496
+ )
497
+ # Final output
498
+ return ensure_stability(output_attention, min_val=-1e4, max_val=1e4)
499
+
500
+ def forward(
501
+ self,
502
+ hidden_states: torch.Tensor,
503
+ attention_mask: Optional[torch.Tensor] = None,
504
+ **kwargs,
505
+ ) -> torch.Tensor:
506
+ """Mix linear and self attention forward"""
507
+ device = hidden_states.device
508
+ tensor_dtype = hidden_states.dtype
509
+ self.base_attn.to(device)
510
+
511
+ if self.training:
512
+ kwargs.pop("past_key_value", None)
513
+ kwargs["use_cache"] = False
514
+ elif "use_cache" not in kwargs:
515
+ kwargs.pop("past_key_value", None)
516
+ kwargs["use_cache"] = False
517
+
518
+ kwargs.pop("position_ids", None) # obsolete
519
+
520
+ # Apply shared projections
521
+ q, k, v, out_proj = self._apply_shared_projections(hidden_states)
522
+
523
+ # Apply linear attention to hidden states
524
+ o_lin = self.linear_attn(
525
+ x=[q, k, v], attn_mask=attention_mask, out_proj=out_proj, **kwargs
526
+ )
527
+
528
+ # Process self attn with truncation
529
+ o_base, attn_weights, present_key_value, expected_attn_mode = (
530
+ self._process_self_attn(hidden_states, attention_mask, kwargs)
531
+ )
532
+
533
+ # Prepare output mixing
534
+ o_lin, o_base = self._prepare_attn_mixin(o_lin, o_base, tensor_dtype, eps=1e-5)
535
+
536
+ # Apply Memory as Gate in self-attention (with length management and ablation)
537
+ out = o_base if self.disable_linear_attn else self._apply_mag(o_lin, o_base)
538
+
539
+ # Return output following transformer convention
540
+ if expected_attn_mode == 3:
541
+ return out, attn_weights, present_key_value
542
+ if expected_attn_mode == 2:
543
+ return out, attn_weights
544
+ return out
545
+
546
+
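# A minimal sketch of the MAG mixing performed by LiZAttention._apply_mag:
# a weighted blend of the softmax and linear attention outputs, plus an
# optional multiplicative term when cross_gate=True. The shapes and the 0.5
# weight below are illustrative placeholders, not values read from a model.
import torch

softmax_out = torch.randn(2, 16, 512)  # stand-in for o_base: [batch, seq, hidden]
linear_out = torch.randn(2, 16, 512)   # stand-in for o_lin
mag_weight, cross_gate = 0.5, False

softmax_weighted = (1 - mag_weight) * softmax_out
linear_weighted = mag_weight * linear_out
mixed = softmax_weighted + linear_weighted
if cross_gate:
    mixed = mixed + softmax_weighted * linear_weighted  # non-linear interaction
print(mixed.shape)  # torch.Size([2, 16, 512])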
547
+ def load_tptt_safetensors(
548
+ repo_or_path: str,
549
+ model: Union[PreTrainedModel, PeftModel],
550
+ subfolder: Optional[str] = None,
551
+ token: Optional[str] = None,
552
+ ) -> Union[PreTrainedModel, PeftModel]:
553
+ """Load Tptt safetensor from LoRA/PEFT weights and adapt keys if needed."""
554
+ # sharding not supported yet (e.g. : -00001-of-00005.safetensors, ...)
555
+ fname = "adapter_model.safetensors"
556
+ # subfolder management
557
+ if subfolder:
558
+ repo_or_path_norm = os.path.normpath(repo_or_path)
559
+ subfolder_norm = os.path.normpath(subfolder)
560
+ if not repo_or_path_norm.endswith(subfolder_norm):
561
+ fname = f"{subfolder}/{fname}" if subfolder else fname
562
+ # Find file path
563
+ if os.path.isdir(repo_or_path):
564
+ path = os.path.join(repo_or_path, fname)
565
+ if not os.path.exists(path):
566
+ return model
567
+ else:
568
+ if fname not in list_repo_files(repo_or_path, token=token):
569
+ return model
570
+ path = hf_hub_download(repo_or_path, fname, token=token)
571
+
572
+ # Load weights from safetensors
573
+ with safe_open(path, framework="pt") as f:
574
+ state_dict = {k: f.get_tensor(k) for k in f.keys()}
575
+
576
+ # Adapt LoRA/Specific keys if needed (add .default if expected by the model)
577
+ def adapt_keys(sd, model):
578
+ model_keys = list(model.state_dict().keys())
579
+ if any(k.startswith("tptt_model.base_model.") for k in model_keys):
580
+ prefix = "tptt_model.base_model."
581
+ elif any(k.startswith("base_model.") for k in model_keys):
582
+ prefix = "base_model."
583
+ else:
584
+ prefix = ""
585
+
586
+ has_base_attn = any(".base_attn." in k for k in model_keys)
587
+
588
+ def adapt_key(k):
589
+ k_ = k if k.startswith(prefix) else prefix + k
590
+ # first, verify and modify base_attn (LiZA)
591
+ if ".base_attn." in k_ and not has_base_attn:
592
+ k_ = k_.replace(".base_attn.", ".")
593
+ # change LoRA if needed
594
+ if (
595
+ k_.endswith("lora_A.weight") or k_.endswith("lora_B.weight")
596
+ ) and k_.replace(".weight", ".default.weight") in model_keys:
597
+ k_ = k_.replace(".weight", ".default.weight")
598
+ return k_
599
+
600
+ return {adapt_key(k): v for k, v in sd.items()}
601
+
602
+ state_dict = adapt_keys(state_dict, model)
603
+
604
+ # Cast tensors to the expected dtype of the model parameters
605
+ model_state_dict = model.state_dict()
606
+ for k, v in state_dict.items():
607
+ if k in model_state_dict:
608
+ expected_dtype = model_state_dict[k].dtype
609
+ if v.dtype != expected_dtype:
610
+ state_dict[k] = v.to(expected_dtype)
611
+
612
+ logger.info("Input LoRA/Specific keys: %s", [k for k in state_dict.keys()])
613
+
614
+ # Load into model
615
+ missing, unexpected = model.load_state_dict(state_dict, strict=False, assign=True)
616
+ missing_lora = [k for k in missing if "lora" in k]
617
+ if missing_lora:
618
+ logger.warning("Missing keys: %s", missing_lora)
619
+ if unexpected:
620
+ logger.warning("Unexpected keys: %s", unexpected)
621
+ return model
622
+
623
+
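# A simplified sketch of the key adaptation done inside load_tptt_safetensors:
# flat adapter keys are re-prefixed for the wrapped model and LoRA weights get
# the ".default" suffix PEFT expects. The real adapt_keys checks the target
# model's state_dict before renaming; the key below is a hypothetical example.
def demo_adapt_key(key: str, prefix: str = "base_model.") -> str:
    key = key if key.startswith(prefix) else prefix + key
    if key.endswith(("lora_A.weight", "lora_B.weight")):
        key = key.replace(".weight", ".default.weight")
    return key

print(demo_adapt_key("model.layers.0.self_attn.q_proj.lora_A.weight"))
# -> base_model.model.layers.0.self_attn.q_proj.lora_A.default.weight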
624
+ def get_tptt_model( # pylint: disable=too-many-arguments, too-many-positional-arguments
625
+ model: nn.Module,
626
+ base_config: PretrainedConfig, # e.g. LlamaConfig, MistralConfig, etc.
627
+ linear_cache: Optional[LCache] = None,
628
+ liza_attention: nn.Module = LiZAttention,
629
+ target_modules_names: Optional[list[str]] = None,
630
+ operator_mode: str = "delta_rule",
631
+ recurrent_config: Optional[Dict[str, Any]] = None,
632
+ base_scale_attn: bool = False,
633
+ mag_weight: float = 0.5,
634
+ cross_gate: bool = False,
635
+ max_chunk_size: int = 64,
636
+ linear_precision: torch.dtype = torch.float32,
637
+ max_self_attn_length: Optional[int] = None, # unnecessary
638
+ padding_side: str = "right", # for tokenizer
639
+ bidirectional: bool = False, # if True, use bidirectional attention
640
+ pooling_config: Optional[Dict[str, Any]] = None,
641
+ **kwargs, # quickfix unexpected arguments
642
+ ) -> Tuple[PreTrainedModel, LCache]:
643
+ """Replace target modules in a model with LiZAttention."""
644
+ if target_modules_names is None:
645
+ target_modules_names = ["attn", "self_attn", "attention"]
646
+ # Find target modules by suffix (e.g., "attn", "attention")
647
+ target_modules_names = [
648
+ name
649
+ for name, _ in model.named_modules()
650
+ if any(name.endswith(suffix) for suffix in target_modules_names)
651
+ and not any(f".{suffix}." in name for suffix in target_modules_names)
652
+ ]
653
+ if not target_modules_names:
654
+ raise ValueError(
655
+ f"Target modules '{target_modules_names}' not found in the model."
656
+ )
657
+ # Prepare recurrent config
658
+ linear_cache = linear_cache or LCache()
659
+ # Inject LiZAttention into the model
660
+ for name, _ in model.named_modules():
661
+ if name in target_modules_names:
662
+ parent = model
663
+ *path, last = name.split(".")
664
+ for p in path:
665
+ parent = getattr(parent, p)
666
+ layer_idx = extract_layer_idx(name)
667
+ setattr(
668
+ parent,
669
+ last,
670
+ liza_attention(
671
+ getattr(parent, last),
672
+ layer_idx=layer_idx,
673
+ base_config=base_config,
674
+ linear_cache=linear_cache,
675
+ operator_mode=operator_mode,
676
+ recurrent_config=recurrent_config,
677
+ max_self_attn_length=max_self_attn_length,
678
+ base_scale_attn=base_scale_attn,
679
+ mag_weight=mag_weight,
680
+ cross_gate=cross_gate,
681
+ max_chunk_size=max_chunk_size,
682
+ linear_precision=linear_precision,
683
+ padding_side=padding_side,
684
+ bidirectional=bidirectional,
685
+ pooling_config=pooling_config,
686
+ ),
687
+ )
688
+ return model, linear_cache
689
+
690
+
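# A minimal usage sketch of get_tptt_model, assuming the small "gpt2" checkpoint
# is available; the hyperparameters shown are illustrative, and whether a given
# backbone is fully supported depends on its attention module layout.
from transformers import AutoConfig, AutoModelForCausalLM

backbone = AutoModelForCausalLM.from_pretrained("gpt2")
gpt2_config = AutoConfig.from_pretrained("gpt2")
wrapped, cache = get_tptt_model(
    backbone,
    base_config=gpt2_config,
    operator_mode="delta_rule",
    mag_weight=0.5,
    max_chunk_size=64,
)
print([n for n, m in wrapped.named_modules() if isinstance(m, LiZAttention)][:2])
# e.g. ['transformer.h.0.attn', 'transformer.h.1.attn']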
691
+ def save_tptt_safetensors(model, path: str, name: str = "adapter_model.safetensors"):
692
+ """Save trainable LoRA/Specific weights and adapting key names"""
693
+ # 1. Get the full state_dict
694
+ all_sd = model.state_dict()
695
+
696
+ # 2. Identify trainable parameter names (usually only LoRA/PEFT adapters)
697
+ trainable_keys = [
698
+ name for name, param in model.named_parameters() if param.requires_grad
699
+ ] # Also, you can manually select specific keys in model after load
700
+
701
+ # 3. Filter and adapt the keys (Remove custom model encapsulation info)
702
+ to_save = {
703
+ k.replace("tptt_model.", "").replace("base_model.", ""): all_sd[k]
704
+ for k in trainable_keys
705
+ }
706
+
707
+ # 4. Save the filtered adapters to a safetensors file
708
+ if to_save:
709
+ os.makedirs(os.path.dirname(path), exist_ok=True)
710
+ # sharding not supported yet (e.g. : -00001-of-00005.safetensors, ...)
711
+ save_file(to_save, os.path.join(path, name))
712
+
713
+
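# A small sketch of the key stripping applied before writing adapters, mirroring
# the replace() calls above; the key is a hypothetical PEFT-wrapped parameter name.
key = "tptt_model.base_model.model.layers.0.self_attn.q_proj.lora_A.default.weight"
print(key.replace("tptt_model.", "").replace("base_model.", ""))
# -> model.layers.0.self_attn.q_proj.lora_A.default.weight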
714
+ class TpttModel(PreTrainedModel):
715
+ """
716
+ TPTT model wrapper with linear attention (LiZA) and LoRA support.
717
+ Handles only architecture and weights.
718
+ """
719
+
720
+ config_class = TpttConfig
721
+
722
+ def __init__(
723
+ self,
724
+ config: TpttConfig,
725
+ **kwargs,
726
+ ):
727
+ """
728
+ Initialize TpttModel with a given config and backbone.
729
+ Injects LiZA attention modules into the backbone.
730
+ """
731
+ super().__init__(config, **kwargs)
732
+ repo_or_path = getattr(config, "_base_path", None) or config._name_or_path
733
+
734
+ # 1. Load backbone (with subfolder management) :
735
+ kwargs_bb = kwargs.copy()
736
+ if config.base_model_subfolder is not None:
737
+ kwargs_bb["subfolder"] = config.base_model_subfolder
738
+ else:
739
+ kwargs_bb.pop("subfolder", None)
740
+ tptt_model = AutoModelForCausalLM.from_pretrained(
741
+ config.base_model_name, **kwargs_bb
742
+ )
743
+
744
+ # 2. Inject LiZA attention
745
+ self.linear_cache = LCache()
746
+ tptt_model, self.linear_cache = get_tptt_model(
747
+ tptt_model, config, self.linear_cache, **config.to_dict()
748
+ )
749
+
750
+ # 3. Apply LoRA/Specific if present and configured
751
+ if config.lora_config is not None:
752
+ lora_config_obj = LoraConfig(**config.lora_config)
753
+ tptt_model = get_peft_model(tptt_model, lora_config_obj)
754
+ else:
755
+ tptt_model = set_trainable_parameters(tptt_model)
756
+
757
+ # 4. Load safetensor if tptt/peft adaptor in repo
758
+ if repo_or_path:
759
+ tptt_model = load_tptt_safetensors(
760
+ repo_or_path,
761
+ tptt_model,
762
+ subfolder=kwargs.get("subfolder", None),
763
+ token=kwargs.get("token", None),
764
+ )
765
+ self.tptt_model = tptt_model
766
+
767
+ def forward(
768
+ self,
769
+ input_ids: Optional[torch.LongTensor] = None,
770
+ attention_mask: Optional[torch.Tensor] = None,
771
+ labels: Optional[torch.LongTensor] = None,
772
+ **kwargs,
773
+ ):
774
+ """Forward pass. All arguments are passed to the underlying base model."""
775
+ if self.training:
776
+ kwargs["use_cache"] = False
777
+ kwargs.pop("num_items_in_batch", None)
778
+ elif "use_cache" not in kwargs: # evaluation
779
+ kwargs.pop("num_items_in_batch", None)
780
+ kwargs["use_cache"] = False
781
+ return self.tptt_model(
782
+ input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
783
+ )
784
+
785
+ def generate(self, *args, **kwargs):
786
+ """Delegate the generate call to the backbone model, which supports generation"""
787
+ return self.tptt_model.generate(*args, **kwargs)
788
+
789
+ def save_pretrained(self, path: str, **kwargs):
790
+ """Save model weights, config, and source code to the given path."""
791
+ # 0. Save complete tptt config (with or without LoRA)
792
+ super().save_pretrained(path, **kwargs) # pylint: disable=no-member
793
+ self._adjust_save_strategy(path, **kwargs)
794
+ # 1. Save trainable weights and adapt keys
795
+ save_tptt_safetensors(self, path)
796
+ # 2. Copy Python files for trust_remote_code
797
+ self._copy_source_files(path, **kwargs)
798
+
799
+ def _adjust_save_strategy(self, path: str, **kwargs):
800
+ """Re-adapt/remove the weight safetensor and saved adapter config"""
801
+ if isinstance(self.tptt_model, PeftModel):
802
+ self.tptt_model.save_pretrained(path, **kwargs)
803
+ safetensor_path = os.path.join(path, "model.safetensors")
804
+ if os.path.exists(safetensor_path):
805
+ os.remove(safetensor_path)
806
+ adapter_path = os.path.join(path, "adapter_config.json")
807
+ if os.path.exists(adapter_path):
808
+ os.remove(adapter_path)
809
+
810
+ def _copy_source_files(self, target_path: str, **kwargs):
811
+ """Copy all .py files from package directory for trust_remote_code."""
812
+ src_dir = os.path.dirname(os.path.abspath(__file__))
813
+ dst_dir = (
814
+ f"./{str(Path(target_path).parts[0])}"
815
+ if kwargs.get("subfolder", False)
816
+ else target_path
817
+ )
818
+ for fname in os.listdir(src_dir):
819
+ if fname.endswith(".py"):
820
+ src = os.path.join(src_dir, fname)
821
+ dst = os.path.join(dst_dir, fname)
822
+ shutil.copy2(src, dst)
823
+
824
+ def retie_lm_after_load(self, **kwargs):
825
+ """Re-link lm_head after loading external weights."""
826
+ embed_lm = find_embedding_lm(self.tptt_model)
827
+ if embed_lm is not None and hasattr(self.tptt_model, "lm_head"):
828
+ if self.tptt_model.lm_head is None: # ensure lm_head exists
829
+ self.tptt_model.lm_head = nn.Linear(
830
+ embed_lm.weight.shape[1], embed_lm.weight.shape[0], bias=False
831
+ )
832
+ if kwargs.get("tie_word_embeddings", True):
833
+ self.tptt_model.lm_head.weight = embed_lm.weight # share weights
834
+ logger.info("Weights of lm_head have been shared with embedding.")
835
+ else:
836
+ self.tptt_model.lm_head.weight = nn.Parameter(embed_lm.weight.clone())
837
+ logger.info("Weights of lm_head have been cloned from the embedding.")
838
+
839
+ @classmethod
840
+ def from_pretrained(cls, pretrained_model_name_or_path=None, *model_args, **kwargs):
841
+ """Custom from_pretrained that accepts the standard positional argument"""
842
+ config = kwargs.pop("config", None)
843
+ repo_or_path = (
844
+ pretrained_model_name_or_path
845
+ or kwargs.pop("pretrained_model_name_or_path", None)
846
+ or kwargs.pop("repo_or_path", None)
847
+ or (getattr(config, "_base_path", None) if config else None)
848
+ or (getattr(config, "_name_or_path", None) if config else None)
849
+ )
850
+
851
+ if config is None and repo_or_path is not None:
852
+ config = AutoConfig.from_pretrained(repo_or_path, **kwargs)
853
+ model = cls(config, *model_args, **kwargs)
854
+ model.retie_lm_after_load(**kwargs)
855
+ return model
856
+
857
+
858
+ TpttModel.register_for_auto_class("AutoModelForCausalLM")
859
+
860
+
861
+ class LinearAttentionOp(nn.Module):
862
+ """Base class for linear attention operators."""
863
+
864
+ def __init__(
865
+ self,
866
+ layer_idx: int,
867
+ operator_mode: str = "delta_rule",
868
+ recurrent_config: Optional[dict] = None,
869
+ max_chunk_size: int = 64,
870
+ linear_cache: Optional[LCache] = None,
871
+ linear_precision: torch.dtype = torch.float32,
872
+ ):
873
+ super().__init__()
874
+ self.layer_idx = layer_idx
875
+ if recurrent_config is None:
876
+ operator_mode = "delta_rule" # force default operator mode if no config
877
+ recurrent_config = {
878
+ "order": 1,
879
+ "gate_type": "k",
880
+ "linear": True,
881
+ "trick": "derivative",
882
+ }
883
+ self.operator_mode = operator_mode
884
+ self.order = recurrent_config["order"]
885
+ self.gate_type = recurrent_config["gate_type"]
886
+ self.linear = recurrent_config["linear"]
887
+ self.trick = recurrent_config["trick"]
888
+
889
+ self.max_chunk_size = max_chunk_size
890
+ self.linear_cache = linear_cache or LCache()
891
+ self.linear_precision = linear_precision
892
+
893
+ def compute_gate(self, beta: Tuple[torch.Tensor]) -> torch.Tensor:
894
+ """
895
+ Compute the gating tensor according to the gate_type.
896
+ """
897
+ if self.gate_type == "k":
898
+ return torch.clamp(beta[0], min=1e-6, max=1 - 1e-6)
899
+ if self.gate_type == "v":
900
+ return torch.clamp(beta[1], min=1e-6, max=1 - 1e-6)
901
+ if self.gate_type == "kv":
902
+ return torch.clamp(beta[0] * beta[1], min=1e-6, max=1 - 1e-6)
903
+ raise ValueError(f"Unsupported gate_type: {self.gate_type}")
904
+
905
+ def get_cache(self, use_cache: bool) -> Tuple[
906
+ Optional[torch.Tensor],
907
+ Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
908
+ ]:
909
+ """
910
+ Retrieve recurrent state and qkv buffers from the cache.
911
+ """
912
+ if not use_cache:
913
+ return None, None
914
+ last_state = self.linear_cache[self.layer_idx]
915
+ if last_state is not None:
916
+ recurrent_state = last_state.get("recurrent_state", None)
917
+ qkv_buffers = last_state.get("qkv", None)
918
+ else:
919
+ recurrent_state = None
920
+ qkv_buffers = None
921
+ return recurrent_state, qkv_buffers
922
+
923
+ def save_cache(
924
+ self,
925
+ use_cache: bool,
926
+ q: torch.Tensor,
927
+ k: torch.Tensor,
928
+ v: torch.Tensor,
929
+ gate: torch.Tensor,
930
+ state: torch.Tensor,
931
+ ) -> None:
932
+ """
933
+ Save the recurrent state and qkv buffers to the cache.
934
+ """
935
+ if not use_cache:
936
+ return
937
+ if self.order > 1:
938
+ qkv_buffers = (
939
+ q[:, :, -(self.order - 1) :, :],
940
+ k[:, :, -(self.order - 1) :, :],
941
+ v[:, :, -(self.order - 1) :, :],
942
+ gate[:, :, -(self.order - 1) :, :],
943
+ )
944
+ else:
945
+ qkv_buffers = None
946
+ self.linear_cache.update(self.layer_idx, recurrent_state=state, qkv=qkv_buffers)
947
+
948
+ def forward(
949
+ self,
950
+ q: torch.Tensor,
951
+ k: torch.Tensor,
952
+ v: torch.Tensor,
953
+ beta: Union[Tuple[torch.Tensor], torch.Tensor],
954
+ **kwargs,
955
+ ) -> torch.Tensor:
956
+ """
957
+ Forward pass for the attention operator.
958
+ """
959
+ # Ensure linear_precision for numerical stability (float32)
960
+ q, k, v = [x.to(self.linear_precision) for x in (q, k, v)]
961
+ if isinstance(beta, (tuple, list)):
962
+ beta = tuple(b.to(self.linear_precision) for b in beta)
963
+ else:
964
+ beta = beta.to(self.linear_precision)
965
+
966
+ gate = self.compute_gate(beta)
967
+
968
+ # Retrieve cache if needed
969
+ use_cache = kwargs.get("use_cache", False)
970
+ recurrent_state, qkvb = self.get_cache(use_cache)
971
+
972
+ if qkvb is not None and qkvb[0].shape == q.shape:
973
+ q = torch.cat([qkvb[0].to(q.device), q], dim=2).to(self.linear_precision)
974
+ k = torch.cat([qkvb[1].to(q.device), k], dim=2).to(self.linear_precision)
975
+ v = torch.cat([qkvb[2].to(q.device), v], dim=2).to(self.linear_precision)
976
+ gate = torch.cat([qkvb[3].to(q.device), gate], dim=2).to(
977
+ self.linear_precision
978
+ )
979
+
980
+ output, state = self.chunk_delta_product_forward(
981
+ q,
982
+ k,
983
+ v,
984
+ gate,
985
+ self.max_chunk_size,
986
+ n=self.order,
987
+ trick=self.trick,
988
+ linear=self.linear,
989
+ initial_state=recurrent_state,
990
+ use_checkpoint=not (use_cache),
991
+ linear_precision=self.linear_precision,
992
+ )
993
+
994
+ # Save cache if needed
995
+ self.save_cache(use_cache, q, k, v, gate, state)
996
+
997
+ return output
998
+
999
+ @staticmethod
1000
+ def chunk_delta_product_forward(
1001
+ query: torch.Tensor,
1002
+ key: torch.Tensor,
1003
+ value: torch.Tensor,
1004
+ beta_gate: torch.Tensor,
1005
+ chunk_size: int,
1006
+ n: int = 1,
1007
+ trick: str = "derivative",
1008
+ linear: bool = True,
1009
+ initial_state: Optional[torch.Tensor] = None,
1010
+ use_checkpoint: bool = True,
1011
+ linear_precision: torch.dtype = torch.float32,
1012
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1013
+ """
1014
+ Chunkwise parallel implementation https://arxiv.org/abs/2406.06484
1015
+ For each chunk, processes chunk_size * n_orders steps (virtual tokens) in order.
1016
+ """
1017
+
1018
+ # --- Main chunk_delta_product_forward logic ---
1019
+
1020
+ batch_size, num_heads, seq_len, head_dim = query.shape
1021
+ chunk_size = get_valid_chunk_size(seq_len, chunk_size)
1022
+ num_chunks = seq_len // chunk_size
1023
+
1024
+ query_n = query if n == 1 else expand_virtual_tokens(query, n, trick)
1025
+ key_n = key if n == 1 else expand_virtual_tokens(key, n, trick)
1026
+ value_n = value if n == 1 else expand_virtual_tokens(value, n, trick)
1027
+ beta_n = beta_gate if n == 1 else expand_virtual_tokens(beta_gate, n, trick)
1028
+
1029
+ q_chunks = chunk_sequence(query_n, num_chunks, chunk_size * n)
1030
+ k_chunks = chunk_sequence(key_n, num_chunks, chunk_size * n)
1031
+ v_chunks = chunk_sequence(value_n, num_chunks, chunk_size * n)
1032
+ beta_chunks = chunk_sequence(beta_n, num_chunks, chunk_size * n)
1033
+
1034
+ k_beta = k_chunks * beta_chunks
1035
+ v_beta = v_chunks * beta_chunks
1036
+
1037
+ householder = -(k_beta @ k_chunks.transpose(-2, -1)).tril(-1)
1038
+ householder = ensure_stability(householder, min_val=-1e4, max_val=1e4)
1039
+
1040
+ # size : N = chunk_size * n
1041
+ inv_hh = fast_invert_matrix(householder, dtype=linear_precision) # [(...),N,N]
1042
+
1043
+ w = ensure_stability(torch.matmul(inv_hh, k_beta), min_val=-1e4, max_val=1e4)
1044
+ u = ensure_stability(torch.matmul(inv_hh, v_beta), min_val=-1e4, max_val=1e4)
1045
+
1046
+ state_shape = (batch_size, num_heads, n, head_dim, head_dim)
1047
+ if initial_state is not None and initial_state.shape == state_shape:
1048
+ state = initial_state.to(device=query.device, dtype=linear_precision)
1049
+ else:
1050
+ state = torch.full(
1051
+ state_shape,
1052
+ fill_value=1e-6, # stability with non-linear activation
1053
+ device=query.device,
1054
+ dtype=linear_precision,
1055
+ )
1056
+
1057
+ output, final_state = sequential_delta_product_scan(
1058
+ q_chunks.to(dtype=linear_precision),
1059
+ w.to(dtype=linear_precision),
1060
+ u.to(dtype=linear_precision),
1061
+ n,
1062
+ linear,
1063
+ chunk_size,
1064
+ state.to(dtype=linear_precision),
1065
+ linear_precision=linear_precision,
1066
+ use_checkpoint=use_checkpoint,
1067
+ )
1068
+
1069
+ idx_last_order = torch.arange(chunk_size, device=output.device) * n + (n - 1)
1070
+ output = output[:, :, :, idx_last_order, :] # [B, H, num_chunks, chunk_size, D]
1071
+ output = output.reshape(batch_size, num_heads, seq_len, head_dim)
1072
+
1073
+ return output.to(dtype=linear_precision), final_state.to(dtype=linear_precision)
1074
+
1075
+
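# A naive per-token reference for the delta-rule recurrence that
# chunk_delta_product_forward evaluates chunkwise (https://arxiv.org/abs/2406.06484).
# Written for clarity rather than speed; shapes follow the [B, H, S, D] convention
# above, and the per-feature beta gate mirrors the k_beta = k * beta weighting.
import torch

def naive_delta_rule(q, k, v, beta):
    bsz, heads, seq, dim = q.shape
    state = torch.zeros(bsz, heads, dim, dim, dtype=q.dtype, device=q.device)
    outputs = []
    for t in range(seq):
        q_t, k_t, v_t, b_t = q[:, :, t], k[:, :, t], v[:, :, t], beta[:, :, t]
        v_old = torch.einsum("bhd,bhde->bhe", k_t, state)  # value currently stored under k_t
        state = state + torch.einsum("bhd,bhe->bhde", b_t * k_t, v_t - v_old)
        outputs.append(torch.einsum("bhd,bhde->bhe", q_t, state))
    return torch.stack(outputs, dim=2)

q = torch.randn(1, 2, 8, 4)
k = torch.randn(1, 2, 8, 4)
v = torch.randn(1, 2, 8, 4)
beta = torch.rand(1, 2, 8, 4)
print(naive_delta_rule(q, k, v, beta).shape)  # torch.Size([1, 2, 8, 4])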
1076
+ def sequential_delta_product_scan(
1077
+ q_chunks: torch.Tensor,
1078
+ w: torch.Tensor,
1079
+ u: torch.Tensor,
1080
+ n_orders: int,
1081
+ linear_activation: bool,
1082
+ current_chunk_size: int,
1083
+ initial_recurrent_state: torch.Tensor,
1084
+ linear_precision: torch.dtype,
1085
+ use_checkpoint: bool,
1086
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1087
+ """
1088
+ DeltaProduct implementation https://arxiv.org/abs/2502.10297
1089
+ Implements the per-token Householder state updates.
1090
+ """
1091
+ batch, head, num_chunks_inner, chunk_n_total, dim = q_chunks.shape
1092
+ output_inner = torch.empty_like(q_chunks)
1093
+ # initial_recurrent_state: [B, H, n, D, D]; keep the last-order state as h_0_base ([B, H, D, D])
1094
+ h_0_base = initial_recurrent_state[:, :, -1, :, :].clone()
1095
+
1096
+ def process_one_chunk(
1097
+ q_chunk_params: torch.Tensor,
1098
+ w_chunk_params: torch.Tensor,
1099
+ u_chunk_params: torch.Tensor,
1100
+ h_0_base: torch.Tensor,
1101
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1102
+ """
1103
+ Process a single chunk (with per-token state for n_orders > 1).
1104
+ """
1105
+ o_intra_current_chunk = torch.zeros(
1106
+ batch,
1107
+ head,
1108
+ chunk_n_total,
1109
+ dim,
1110
+ device=q_chunk_params.device,
1111
+ dtype=linear_precision,
1112
+ )
1113
+ o_inter_current_chunk = torch.zeros_like(o_intra_current_chunk)
1114
+ current_accumulated_state_per_token = (
1115
+ h_0_base.unsqueeze(2).expand(-1, -1, current_chunk_size, -1, -1).clone()
1116
+ ) # [B, H, current_chunk_size, D, D]
1117
+
1118
+ for step in range(n_orders):
1119
+ idx_virtual_tokens = (
1120
+ torch.arange(current_chunk_size, device=q_chunk_params.device)
1121
+ * n_orders
1122
+ + step
1123
+ )
1124
+ q_s = q_chunk_params[:, :, idx_virtual_tokens, :]
1125
+ w_s = w_chunk_params[:, :, idx_virtual_tokens, :]
1126
+ u_s = u_chunk_params[:, :, idx_virtual_tokens, :]
1127
+
1128
+ state_input_for_this_step = current_accumulated_state_per_token
1129
+
1130
+ ## BLAS/cuBLAS einsum "bhcd,bhcdd->bhcd"
1131
+ k_trans_h_old = (
1132
+ torch.matmul(
1133
+ w_s.unsqueeze(-2),
1134
+ state_input_for_this_step,
1135
+ )
1136
+ .squeeze(-2)
1137
+ .to(dtype=linear_precision)
1138
+ )
1139
+
1140
+ u_val = u_s - k_trans_h_old
1141
+
1142
+ o_inter_current_chunk[:, :, idx_virtual_tokens, :] = (
1143
+ torch.matmul(q_s.unsqueeze(-2), state_input_for_this_step)
1144
+ .squeeze(-2)
1145
+ .to(dtype=linear_precision)
1146
+ )
1147
+
1148
+ ## BLAS/cuBLAS einsum "bhcd,bhcd->bhcd"
1149
+ o_intra_current_chunk[:, :, idx_virtual_tokens, :] = (q_s * u_val).to(
1150
+ dtype=linear_precision
1151
+ )
1152
+
1153
+ outer_product_term = torch.matmul(w_s.unsqueeze(-1), u_val.unsqueeze(-2))
1154
+ new_state_i_per_token = state_input_for_this_step + outer_product_term
1155
+ new_state_i_per_token = ensure_stability(
1156
+ new_state_i_per_token, min_val=-1e4, max_val=1e4
1157
+ )
1158
+ current_accumulated_state_per_token = new_state_i_per_token.to(
1159
+ dtype=linear_precision
1160
+ )
1161
+ # Return all needed for next chunk
1162
+ return (
1163
+ o_intra_current_chunk,
1164
+ o_inter_current_chunk,
1165
+ current_accumulated_state_per_token[:, :, -1, :, :], # new h_0_base
1166
+ )
1167
+
1168
+ for chunk_idx_inner in range(num_chunks_inner):
1169
+ q_chunk_params = q_chunks[:, :, chunk_idx_inner]
1170
+ w_chunk_params = w[:, :, chunk_idx_inner]
1171
+ u_chunk_params = u[:, :, chunk_idx_inner]
1172
+
1173
+ # Checkpointed call if training
1174
+ call = (
1175
+ partial(checkpoint, use_reentrant=False)
1176
+ if use_checkpoint
1177
+ else lambda f, *a: f(*a)
1178
+ )
1179
+ o_intra, o_inter, h_0_base = call(
1180
+ process_one_chunk,
1181
+ q_chunk_params,
1182
+ w_chunk_params,
1183
+ u_chunk_params,
1184
+ h_0_base,
1185
+ )
1186
+ if not linear_activation: # non-linear activation between chunks
1187
+ h_0_base = unlinear_activation(h_0_base).to(dtype=linear_precision)
1188
+ output_inner[:, :, chunk_idx_inner] = o_intra + o_inter
1189
+
1190
+ return output_inner, h_0_base
1191
+
1192
+
1193
+ def unlinear_activation(x: torch.Tensor, scale: float = 2.0) -> torch.Tensor:
1194
+ """Unlinear activation between chunk"""
1195
+ x_n = x.norm(p=2, dim=-1, keepdim=True) + 1e-6
1196
+ x_gelu = F.gelu(scale * x / x_n, approximate="tanh") # pylint: disable=not-callable
1197
+ return (x / scale) * x_gelu
1198
+
1199
+
1200
+ def chunk_sequence(x: torch.Tensor, num_chunks: int, chunk_size: int) -> torch.Tensor:
1201
+ """Splits [B, H, S, D] to [B, H, num_chunks, chunk_size, D]"""
1202
+ batch_size, num_heads, _, head_dim = x.shape
1203
+ return x.reshape(batch_size, num_heads, num_chunks, chunk_size, head_dim)
1204
+
1205
+
1206
+ def expand_virtual_tokens(
1207
+ x: torch.Tensor, n: int, mode: str = "derivative"
1208
+ ) -> torch.Tensor:
1209
+ """Expand tokens into 'n' virtual tokens using the selected trick."""
1210
+ batch_size, num_heads, seq_len, head_dim = x.shape
1211
+ device, dtype = x.device, x.dtype
1212
+
1213
+ def derivative_expand(x: torch.Tensor) -> torch.Tensor:
1214
+ """Expand tokens using the derivative trick."""
1215
+ x_pad = torch.cat(
1216
+ [
1217
+ torch.zeros(
1218
+ batch_size, num_heads, n - 1, head_dim, device=device, dtype=dtype
1219
+ ),
1220
+ x,
1221
+ ],
1222
+ dim=2,
1223
+ )
1224
+ coeffs = torch.tensor(
1225
+ [(-1) ** k * math.comb(n - 1, k) for k in range(n)],
1226
+ device=device,
1227
+ dtype=dtype,
1228
+ )
1229
+ coeffs /= coeffs.norm(p=1)
1230
+ return (
1231
+ (x_pad.unfold(2, n, 1) * coeffs.view(1, 1, 1, 1, n))
1232
+ .flip(-1)
1233
+ .permute(0, 1, 2, 4, 3)
1234
+ .reshape(batch_size, num_heads, seq_len * n, head_dim)
1235
+ )
1236
+
1237
+ def rotative_expand(x: torch.Tensor) -> torch.Tensor:
1238
+ """Expand tokens using the rotative trick."""
1239
+ d_parity = head_dim // 2
1240
+ angles = torch.arange(n, device=device, dtype=dtype) * (2 * math.pi / n)
1241
+ cos = torch.cos(angles).view(1, 1, 1, n, 1)
1242
+ sin = torch.sin(angles).view(1, 1, 1, n, 1)
1243
+ if head_dim % 2:
1244
+ x_pairs = x[..., :-1].view(batch_size, num_heads, seq_len, d_parity, 2)
1245
+ else:
1246
+ x_pairs = x.view(batch_size, num_heads, seq_len, d_parity, 2)
1247
+ x_pairs = x_pairs.unsqueeze(3).expand(
1248
+ batch_size, num_heads, seq_len, n, d_parity, 2
1249
+ )
1250
+ x0, x1 = x_pairs[..., 0], x_pairs[..., 1]
1251
+ x0r = x0 * cos - x1 * sin
1252
+ x1r = x0 * sin + x1 * cos
1253
+ rot = torch.stack([x0r, x1r], -1).reshape(
1254
+ batch_size, num_heads, seq_len, n, d_parity * 2
1255
+ )
1256
+ if head_dim % 2:
1257
+ last = (
1258
+ x[..., -1]
1259
+ .unsqueeze(-1)
1260
+ .unsqueeze(3)
1261
+ .expand(batch_size, num_heads, seq_len, n, 1)
1262
+ )
1263
+ rot = torch.cat([rot, last], -1)
1264
+ return rot.reshape(batch_size, num_heads, seq_len * n, head_dim)
1265
+
1266
+ if mode == "derivative":
1267
+ return derivative_expand(x)
1268
+ if mode == "rotative":
1269
+ return rotative_expand(x)
1270
+ if mode == "combined":
1271
+ return (derivative_expand(x) + rotative_expand(x)) / 2
1272
+ raise ValueError(f"Unknown mode: {mode}")
1273
+
1274
+
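# A quick sketch of the coefficients generated by the "derivative" trick in
# expand_virtual_tokens: alternating binomial weights, L1-normalized, so each
# virtual token is a finite-difference-style combination of recent tokens.
import math
import torch

for n in (2, 3):
    coeffs = torch.tensor(
        [(-1) ** k * math.comb(n - 1, k) for k in range(n)], dtype=torch.float32
    )
    print(n, (coeffs / coeffs.norm(p=1)).tolist())
# 2 [0.5, -0.5]
# 3 [0.25, -0.5, 0.25]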
1275
+ def extract_layer_idx(module_name: str) -> int:
1276
+ """Extract the layer index from a module name string."""
1277
+ match = re.search(r"\.(\d+)\.", module_name)
1278
+ if match:
1279
+ return int(match.group(1))
1280
+ return -1
1281
+
1282
+
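# Quick illustration of extract_layer_idx on typical Hugging Face module paths
# (the module names below are generic examples, not tied to a specific backbone).
print(extract_layer_idx("model.layers.7.self_attn"))  # 7
print(extract_layer_idx("lm_head"))                   # -1 (no ".<idx>." group)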
1283
+ def find_embedding_lm(module: nn.Module) -> Optional[nn.Module]:
1284
+ """Find the embedding weight in a model module."""
1285
+ for _, child in module.named_modules():
1286
+ if hasattr(child, "embed_tokens") and hasattr(child.embed_tokens, "weight"):
1287
+ return child.embed_tokens
1288
+ if hasattr(child, "token_embeddings") and hasattr(
1289
+ child.token_embeddings, "weight"
1290
+ ):
1291
+ return child.token_embeddings
1292
+ return None
1293
+
1294
+
1295
+ def set_trainable_parameters(
1296
+ model: PreTrainedModel, trainable_patterns: List[str] = None
1297
+ ) -> PreTrainedModel:
1298
+ """Freeze model parameters except trainable_patterns."""
1299
+ if trainable_patterns is None:
1300
+ trainable_patterns = [
1301
+ "q_proj",
1302
+ "k_proj",
1303
+ "v_proj",
1304
+ "o_proj",
1305
+ "qkv_proj",
1306
+ "out_proj",
1307
+ "c_attn",
1308
+ "c_proj",
1309
+ "query",
1310
+ "key",
1311
+ "value",
1312
+ ]
1313
+
1314
+ for name, param in model.named_parameters():
1315
+ param.requires_grad = any(pattern in name for pattern in trainable_patterns)
1316
+
1317
+ trainable_layers = [n for n, p in model.named_parameters() if p.requires_grad]
1318
+ logger.info("Trainable parameters after freeze: %s", trainable_layers)
1319
+ return model
1320
+
1321
+
1322
+ def ensure_stability(
1323
+ tensor: torch.Tensor, min_val: float = -1e4, max_val: float = 1e4
1324
+ ) -> torch.Tensor:
1325
+ """stability forcing"""
1326
+ dtype = tensor.dtype
1327
+ center = (max_val + min_val) / 2
1328
+ tensor = torch.clamp(tensor, min=min_val, max=max_val)
1329
+ tensor = torch.nan_to_num(tensor, nan=center, posinf=max_val, neginf=min_val)
1330
+ return tensor.to(dtype=dtype)
1331
+
1332
+
1333
+ def apply_linear_attention_mask(
1334
+ attention_mask: torch.Tensor, v: torch.Tensor, padding_side: str = "right"
1335
+ ) -> torch.Tensor:
1336
+ """Extract if padding --> [B,S]"""
1337
+ if attention_mask.dim() == 4 and attention_mask.shape[1] == 1:
1338
+ mask = attention_mask.diagonal(dim1=-2, dim2=-1).squeeze(1)
1339
+ else:
1340
+ mask = attention_mask.squeeze(
1341
+ dim=tuple(
1342
+ i
1343
+ for i in range(1, attention_mask.dim())
1344
+ if attention_mask.shape[i] == 1
1345
+ )
1346
+ )
1347
+ # Ensure cast to the same dtype as v and convert to binary mask
1348
+ if not (
1349
+ mask.dtype == torch.bool
1350
+ or (
1351
+ mask.dtype in [torch.uint8, torch.int32, torch.int64]
1352
+ and mask.max() <= 1
1353
+ and mask.min() >= 0
1354
+ )
1355
+ ):
1356
+ mask = (mask >= 0).to(v.dtype) # [-inf, 0, 0, -inf] --> [0, 1, 1, 0]
1357
+ else:
1358
+ mask = mask.to(v.dtype)
1359
+ # mask is [batch, seq] --> Broadcast to v [batch, seq, (...)]
1360
+ if padding_side == "left":
1361
+ mask = mask[:, -v.shape[-2] :][(...,) + (None,) * (v.dim() - 2)]
1362
+ else: # right padding
1363
+ mask = mask[:, : v.shape[-2]][(...,) + (None,) * (v.dim() - 2)]
1364
+ return v * mask
1365
+
1366
+
1367
+ def truncate_attention_mask(
1368
+ hidden_states: torch.Tensor, attention_mask: torch.Tensor, max_length: int
1369
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1370
+ """Truncate hidden_states and attention_mask to the last window of size max_length"""
1371
+ seq_dim = 1 # convention: (batch, seq, ...)
1372
+ seq_len = hidden_states.shape[seq_dim]
1373
+ if seq_len > max_length:
1374
+ hidden_states = hidden_states.narrow(seq_dim, seq_len - max_length, max_length)
1375
+ if attention_mask is not None:
1376
+ # mask [batch, seq]
1377
+ if attention_mask.dim() == 2:
1378
+ attention_mask = attention_mask[:, -max_length:]
1379
+ # mask [batch, seq, seq]
1380
+ elif attention_mask.dim() == 3:
1381
+ attention_mask = attention_mask[:, -max_length:, -max_length:]
1382
+ # mask [batch, 1, seq, seq]
1383
+ elif attention_mask.dim() == 4 and attention_mask.shape[1] == 1:
1384
+ attention_mask = attention_mask[:, :, -max_length:, -max_length:]
1385
+ else:
1386
+ raise ValueError(
1387
+ "No dimension in attention_mask matches sequence length of hidden_states."
1388
+ )
1389
+ return hidden_states, attention_mask
1390
+
1391
+
1392
+ def fast_invert_matrix(
1393
+ tri_tensor: torch.Tensor, dtype: torch.dtype = torch.float32
1394
+ ) -> torch.Tensor:
1395
+ """Equivalent to vectorized forward substitution applied to the identity matrix."""
1396
+ tri_tensor = tri_tensor.to(dtype=dtype).clone()
1397
+ chunk_size = tri_tensor.shape[-1]
1398
+
1399
+ for i in range(1, chunk_size):
1400
+ tri_tensor[..., i, :i] = tri_tensor[..., i, :i] + (
1401
+ tri_tensor[..., i, :, None].clone() * tri_tensor[..., :, :i].clone()
1402
+ ).sum(-2)
1403
+
1404
+ tri_tensor = tri_tensor + torch.eye(
1405
+ chunk_size, dtype=dtype, device=tri_tensor.device
1406
+ )
1407
+ return tri_tensor.to(dtype=dtype)
1408
+
1409
+
1410
+ def get_valid_chunk_size(total_l: int, chunk_size: int) -> int:
1411
+ """Return the largest chunk_size <= chunk_size that divides total_l."""
1412
+ for c in range(min(chunk_size, total_l), 0, -1):
1413
+ if total_l % c == 0:
1414
+ return c
1415
+ return 1
1416
+
1417
+
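# Quick illustration of the chunk-size fallback: the largest divisor of the
# sequence length not exceeding the requested chunk size, so no padding is needed.
print(get_valid_chunk_size(128, 64))  # 64
print(get_valid_chunk_size(100, 64))  # 50
print(get_valid_chunk_size(97, 64))   # 1 (prime length)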
1418
+ ## RARELY USED
1419
+ def split_qkv(
1420
+ base_attn: nn.Module, qkv: torch.Tensor
1421
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1422
+ """Split the QKV tensor into separate Q, K, and V tensors."""
1423
+ num_q_heads = getattr(base_attn, "num_q_heads", None)
1424
+ num_k_heads = getattr(base_attn, "num_k_heads", None)
1425
+ num_v_heads = getattr(base_attn, "num_v_heads", None)
1426
+ head_dim = getattr(base_attn, "head_dim", None)
1427
+
1428
+ if num_q_heads is None or num_k_heads is None or num_v_heads is None:
1429
+ raise ValueError(
1430
+ "Base attention must have num_q_heads, num_k_heads, and num_v_heads defined."
1431
+ )
1432
+
1433
+ q_len = num_q_heads * head_dim
1434
+ k_len = num_k_heads * head_dim
1435
+ v_len = num_v_heads * head_dim
1436
+
1437
+ q, k, v = torch.split(qkv, [q_len, k_len, v_len], dim=-1)
1438
+ return q, k, v
1439
+
1440
+
1441
+ ## OPTIONAL
1442
+ def match_dim(x: torch.Tensor, dim: int, target_size: int) -> torch.Tensor:
1443
+ """Match the size of tensor x along dimension dim to target_size by interpolation"""
1444
+ src_size = x.shape[dim]
1445
+ if src_size == target_size:
1446
+ return x
1447
+ x = torch.moveaxis(x, dim, -1)
1448
+ shape = x.shape
1449
+ if src_size < target_size:
1450
+ x = x.reshape(-1, 1, src_size)
1451
+ x = F.interpolate(x, size=target_size, mode="linear", align_corners=False)
1452
+ x = x.reshape(*shape[:-1], target_size)
1453
+ else:
1454
+ eye = torch.eye(target_size, src_size, device=x.device, dtype=x.dtype)
1455
+ x = F.linear(x, eye) # pylint: disable=not-callable
1456
+ x = torch.moveaxis(x, -1, dim)
1457
+ return x
1458
+
1459
+
1460
+ def soft_clamp(
1461
+ x: torch.Tensor, min_val: float = 1e-6, max_val: float = 1 - 1e-6
1462
+ ) -> torch.Tensor:
1463
+ """Differentiable clamping for stability"""
1464
+ dtype = x.dtype
1465
+ scale = (max_val - min_val) / 2
1466
+ center = (max_val + min_val) / 2
1467
+ return (torch.tanh((x - center) / scale) * scale + center).to(dtype=dtype)
1468
+
1469
+
1470
+ def describe(x: torch.Tensor, name="tensor") -> None:
1471
+ """Prints the shape, min, max, mean, and std of a tensor."""
1472
+ stats = (x.min(), x.max(), x.mean(), x.std())
1473
+ print(
1474
+ f"{name} shape: {tuple(x.shape)}, "
1475
+ + f"min: {stats[0]:.4g}, max: {stats[1]:.4g}, "
1476
+ + f"mean: {stats[2]:.4g}, std: {stats[3]:.4g}, "
1477
+ + f"dtype: {x.dtype}, device: {x.device}"
1478
+ )
train_tptt.py ADDED
@@ -0,0 +1,133 @@
1
+ # pylint: disable=too-many-arguments, too-many-positional-arguments
2
+
3
+ """
4
+ Author : Fabien FURFARO
5
+ """
6
+
7
+ from typing import Optional, Union
8
+
9
+ from transformers import PreTrainedModel, TrainerCallback
10
+
11
+ from .modeling_tptt import LiZAttention
12
+
13
+
14
+ class LiZACallback(TrainerCallback):
15
+ """
16
+ TrainerCallback to schedule mag_weight or enable/disable linear attention during training.
17
+
18
+ Modes:
19
+ - "gradual": linear interpolation from initial_weight to final_weight.
20
+ - "cyclic": alternate between values in weight_list at each step.
21
+ - "switch": alternately enable/disable linear attention at each step.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ model: PreTrainedModel,
27
+ mode: str = "gradual",
28
+ initial_weight: float = 0.0,
29
+ final_weight: float = 0.5,
30
+ transition_step: Union[int, tuple, list] = 100,
31
+ weight_list: Optional[list] = None,
32
+ switch_period: int = 1, # period for switching
33
+ ):
34
+ self.model = model
35
+ self.mode = mode
36
+
37
+ # Ensure initial_weight is a float scalar, not tuple/list
38
+ if isinstance(initial_weight, (tuple, list)):
39
+ initial_weight = initial_weight[0]
40
+ if isinstance(final_weight, (tuple, list)):
41
+ final_weight = final_weight[0]
42
+ self.initial_weight = float(initial_weight)
43
+ self.final_weight = float(final_weight)
44
+
45
+ # Ensure transition_step is an int scalar, not tuple/list
46
+ self.transition_step = ensure_int(transition_step)
47
+
48
+ # For cyclic mode: ensure all weights are float scalars
49
+ if weight_list is not None:
50
+ self.weight_list = [
51
+ float(w[0]) if isinstance(w, (tuple, list)) else float(w)
52
+ for w in weight_list
53
+ ]
54
+ else:
55
+ self.weight_list = [self.initial_weight, self.final_weight]
56
+
57
+ # For switch_alternate mode
58
+ self.switch_period = int(switch_period)
59
+
60
+ def on_step_end(self, args, state, control, **kwargs):
61
+ current_step = state.global_step
62
+ transition_step = self.transition_step
63
+
64
+ # Ensure current_step and transition_step are plain ints
65
+ current_step = ensure_int(current_step)
66
+ transition_step = ensure_int(transition_step)
67
+
68
+ # Select mag_weight or enable/disable linear attention according to mode
69
+ if self.mode == "gradual":
70
+ if current_step <= transition_step:
71
+ weight = self.initial_weight + (
72
+ self.final_weight - self.initial_weight
73
+ ) * (current_step / transition_step)
74
+ else:
75
+ weight = self.final_weight
76
+ for _, module in self.model.named_modules():
77
+ if isinstance(module, LiZAttention):
78
+ module.mag_weight = weight
79
+
80
+ elif self.mode == "cyclic":
81
+ idx = current_step % len(self.weight_list)
82
+ weight = self.weight_list[idx]
83
+ for _, module in self.model.named_modules():
84
+ if isinstance(module, LiZAttention):
85
+ module.mag_weight = weight
86
+
87
+ elif self.mode == "switch":
88
+ # Alternately enable/disable linear attention every switch_period steps
89
+ disable = (current_step // self.switch_period) % 2 == 0
90
+ for _, module in self.model.named_modules():
91
+ if isinstance(module, LiZAttention):
92
+ module.disable_linear_attn = disable
93
+
94
+ else:
95
+ raise ValueError(f"Unknown mode: {self.mode}")
96
+
97
+ def on_log(self, args, state, control, logs=None, **kwargs):
98
+ mag_weight = None
99
+ disable_linear_attn = None
100
+ # Log the current mag_weight and disable_linear_attn
101
+ for _, module in self.model.named_modules():
102
+ if isinstance(module, LiZAttention):
103
+ mag_weight = getattr(module, "mag_weight", None)
104
+ disable_linear_attn = getattr(module, "disable_linear_attn", None)
105
+ break
106
+ if mag_weight is not None and logs is not None:
107
+ logs["mag_weight"] = float(mag_weight)
108
+ if disable_linear_attn is not None and logs is not None:
109
+ logs["disable_linear_attn"] = not bool(disable_linear_attn)
110
+
111
+
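# A minimal sketch of the "gradual" schedule computed in on_step_end: linear
# interpolation from initial_weight to final_weight until transition_step,
# then constant. The values below are illustrative, not training defaults.
initial_weight, final_weight, transition_step = 0.0, 0.5, 100
for step in (0, 25, 50, 100, 150):
    if step <= transition_step:
        weight = initial_weight + (final_weight - initial_weight) * (step / transition_step)
    else:
        weight = final_weight
    print(step, weight)
# 0 0.0, 25 0.125, 50 0.25, 100 0.5, 150 0.5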
112
+ def ensure_int(value: Union[int, tuple, list]) -> int:
113
+ """Ensure the value is a plain integer."""
114
+ if isinstance(value, (tuple, list)):
115
+ value = int(value[0])
116
+ if hasattr(value, "item"):
117
+ value = int(value.item())
118
+ return value
119
+
120
+
121
+ class SaveBestModelCallback(TrainerCallback):
122
+ """TrainerCallback to save the best model based on evaluation loss."""
123
+
124
+ def __init__(self):
125
+ self.best_metric = float("inf")
126
+
127
+ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
128
+ if metrics is not None and "eval_loss" in metrics:
129
+ if metrics["eval_loss"] < self.best_metric:
130
+ self.best_metric = metrics["eval_loss"]
131
+ control.should_save = True # Trigger save
132
+ else:
133
+ control.should_save = False # Skip save
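# An end-to-end training sketch under stated assumptions: get_tptt_model is
# importable from the packaged modeling_tptt module, the small "gpt2" checkpoint
# is available, and the synthetic dataset below stands in for a real corpus.
# Hyperparameters are illustrative only.
import torch
from torch.utils.data import Dataset
from transformers import AutoConfig, AutoModelForCausalLM, Trainer, TrainingArguments

from modeling_tptt import get_tptt_model  # assumed import path


class RandomTokenDataset(Dataset):
    """Tiny synthetic causal-LM dataset: labels are the inputs themselves."""

    def __init__(self, vocab_size: int, seq_len: int = 32, size: int = 16):
        self.data = torch.randint(0, vocab_size, (size, seq_len))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        ids = self.data[idx]
        return {"input_ids": ids, "attention_mask": torch.ones_like(ids), "labels": ids}


backbone = AutoModelForCausalLM.from_pretrained("gpt2")
gpt2_config = AutoConfig.from_pretrained("gpt2")
model, _ = get_tptt_model(backbone, base_config=gpt2_config, operator_mode="delta_rule")

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="tptt_demo", max_steps=10, per_device_train_batch_size=2, report_to=[]
    ),
    train_dataset=RandomTokenDataset(vocab_size=gpt2_config.vocab_size),
    callbacks=[
        LiZACallback(model, mode="gradual", initial_weight=0.0, final_weight=0.5, transition_step=10),
        SaveBestModelCallback(),
    ],
)
trainer.train()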