ffurfaro commited on
Commit
fe0fd34
·
verified ·
1 Parent(s): 949f866

Upload model + init tptt code

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ license: apache-2.0
4
+ library_name: transformers
5
+ tags:
6
+ - tptt
7
+ - peft
8
+ - trust_remote_code
9
+ ---
10
+
11
+ # Titans-Qwen2.5-1.5B
12
+
13
+ Titanesque version of `Qwen/Qwen2.5-1.5B` with parallel linearized attention (TPTT 😊) and PEFT.
14
+
15
+ ## Model Details
16
+
17
+ - **Architecture:** TpttModel
18
+ - **Base model:** Qwen/Qwen2.5-1.5B
19
+ - **LiZA config:** operator=delta_rule, mag=0.5
20
+ - **LoRA config:** r=8, alpha=16, dropout=0.05
21
+ - **torch_dtype:** bfloat16
22
+
23
+ ## Usage
24
+
25
+
26
+ ```python
27
+ from transformers import AutoModelForCausalLM, AutoTokenizer
28
+
29
+ model = AutoModelForCausalLM.from_pretrained(
30
+ "ffurfaro/Titans-Qwen2.5-1.5B",
31
+ trust_remote_code=True
32
+ )
33
+ tokenizer = AutoTokenizer.from_pretrained("ffurfaro/Titans-Qwen2.5-1.5B")
34
+
35
+ prompt = "Your prompt here"
36
+ inputs = tokenizer(prompt, return_tensors="pt")
37
+ outputs = model.generate(**inputs, max_new_tokens=100)
38
+ print(tokenizer.decode(outputs, skip_special_tokens=True))
39
+
40
+ ```
41
+
42
+ ## Training
43
+
44
+ - **Dataset:** yahma/alpaca-cleaned
45
+ - **Platform:** Kaggle
46
+ - **Hardware:** NVIDIA 2xT4
47
+ - **Batch size:** 2
48
+ - **Epochs:** 1.0
49
+ - **Learning rate (final):** 0.0
50
+ - **Loss (final):** 1.4369
51
+ - **Training runtime:** 81.526 sec
52
+ - **Samples per second:** 1.227
53
+ - **Steps per second:** 0.307
54
+ - **Total FLOPs:** 201603022848000.0
55
+ - **Gradient norm (final):** 0.8238493204116821
56
+
57
+ ## Evaluation
58
+
59
+ - **Metrics:** Training loss only (no eval yet, table soon : PiQA, ARC, Hella, Wino, GSM8K, MMLU)
60
+ - **Results:** Final training loss: 1.4369
61
+
62
+
63
+ ## Citation & Contact
64
+
65
+ If you use TPTT in your academic work, please cite [Furfaro](https://huggingface.co/ffurfaro). For questions or support, please open an issue on the [GitHub repository](https://github.com/fabienfrfr/tptt) or contact the maintainer.
66
+
67
+
68
+ ---
__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __version__ = "0.1.0"
2
+
3
+ from .configuration_tptt import TpttConfig, generate_model_card
4
+ from .modeling_tptt import (
5
+ AttentionOperator,
6
+ LCache,
7
+ LiZAttention,
8
+ TpttModel,
9
+ get_tptt_model,
10
+ )
11
+ from .pipeline_tptt import TpttPipeline
12
+ from .train_tptt import AdjustMaGWeightCallback
13
+
14
+ __all__ = [
15
+ "TpttConfig",
16
+ "TpttModel",
17
+ "TpttPipeline",
18
+ "get_tptt_model",
19
+ "AdjustMaGWeightCallback",
20
+ "LCache",
21
+ "AttentionOperator",
22
+ "LiZAttention",
23
+ "generate_model_card",
24
+ ]
adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5da9d8c6ef9e66a127464a1d1859a4ffd915e3012467c46bc201d48f8d0d8ab9
3
+ size 8747944
added_tokens.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</tool_call>": 151658,
3
+ "<tool_call>": 151657,
4
+ "<|box_end|>": 151649,
5
+ "<|box_start|>": 151648,
6
+ "<|endoftext|>": 151643,
7
+ "<|file_sep|>": 151664,
8
+ "<|fim_middle|>": 151660,
9
+ "<|fim_pad|>": 151662,
10
+ "<|fim_prefix|>": 151659,
11
+ "<|fim_suffix|>": 151661,
12
+ "<|im_end|>": 151645,
13
+ "<|im_start|>": 151644,
14
+ "<|image_pad|>": 151655,
15
+ "<|object_ref_end|>": 151647,
16
+ "<|object_ref_start|>": 151646,
17
+ "<|quad_end|>": 151651,
18
+ "<|quad_start|>": 151650,
19
+ "<|repo_name|>": 151663,
20
+ "<|video_pad|>": 151656,
21
+ "<|vision_end|>": 151653,
22
+ "<|vision_pad|>": 151654,
23
+ "<|vision_start|>": 151652
24
+ }
config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "TpttModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_tptt.TpttConfig",
7
+ "AutoModelForCausalLM": "modeling_tptt.TpttModel"
8
+ },
9
+ "base_model_name": "Qwen/Qwen2.5-1.5B",
10
+ "lora_config": {
11
+ "alpha_pattern": {},
12
+ "auto_mapping": null,
13
+ "base_model_name_or_path": null,
14
+ "bias": "none",
15
+ "eva_config": null,
16
+ "exclude_modules": null,
17
+ "fan_in_fan_out": false,
18
+ "inference_mode": false,
19
+ "init_lora_weights": true,
20
+ "layer_replication": null,
21
+ "layers_pattern": null,
22
+ "layers_to_transform": null,
23
+ "loftq_config": {},
24
+ "lora_alpha": 16,
25
+ "lora_bias": false,
26
+ "lora_dropout": 0.05,
27
+ "megatron_config": null,
28
+ "megatron_core": "megatron.core",
29
+ "modules_to_save": null,
30
+ "peft_type": "LORA",
31
+ "r": 8,
32
+ "rank_pattern": {},
33
+ "revision": null,
34
+ "target_modules": [
35
+ "v_proj",
36
+ "o_proj",
37
+ "q_proj",
38
+ "k_proj"
39
+ ],
40
+ "task_type": "CAUSAL_LM",
41
+ "use_dora": false,
42
+ "use_rslora": false
43
+ },
44
+ "mag_weight": 0.5,
45
+ "max_chunk_size": 32,
46
+ "max_self_attn_length": 2048,
47
+ "model_type": "tptt",
48
+ "operator_mode": "delta_rule",
49
+ "target_modules_names": [
50
+ "attn",
51
+ "self_attn",
52
+ "attention"
53
+ ],
54
+ "torch_dtype": "bfloat16",
55
+ "transformers_version": "4.49.0"
56
+ }
configuration_tptt.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author : Fabien FURFARO
3
+ """
4
+
5
+ import os
6
+ import re
7
+ from typing import List, Optional, Union
8
+
9
+ from transformers import AutoConfig, PretrainedConfig
10
+
11
+
12
+ def convert_sets_to_lists(obj):
13
+ if isinstance(obj, set):
14
+ return list(obj)
15
+ elif isinstance(obj, dict):
16
+ return {k: convert_sets_to_lists(v) for k, v in obj.items()}
17
+ elif isinstance(obj, (list, tuple)):
18
+ return [convert_sets_to_lists(x) for x in obj]
19
+ else:
20
+ return obj
21
+
22
+
23
+ class TpttConfig(PretrainedConfig):
24
+ """
25
+ Configuration class for the TPTT model.
26
+ This class merges the backbone config (e.g., Llama) with custom TPTT parameters,
27
+ """
28
+
29
+ model_type = "tptt"
30
+ auto_map = {
31
+ "AutoModelForCausalLM": "modeling_tptt.TpttModel",
32
+ "AutoConfig": "configuration_tptt.TpttConfig",
33
+ }
34
+ architectures = ["TpttModel"]
35
+
36
+ def __init__(
37
+ self,
38
+ base_model_config: Optional[Union[str, dict, PretrainedConfig]] = None,
39
+ base_model_name: str = "meta-llama/Llama-3.2-1B",
40
+ name_or_path: Optional[str] = None,
41
+ target_modules_names: Optional[List[str]] = None,
42
+ operator_mode: str = "delta_rule",
43
+ max_self_attn_length: int = 4096,
44
+ mag_weight: float = 0.5,
45
+ max_chunk_size: int = 64,
46
+ lora_config: Optional[dict] = None, # only serialized accepted
47
+ **kwargs,
48
+ ):
49
+
50
+ if base_model_config is not None:
51
+ if isinstance(base_model_config, str):
52
+ # Load config from Hugging Face Hub or a local path
53
+ base_model_config = AutoConfig.from_pretrained(
54
+ base_model_config
55
+ ).to_dict()
56
+ elif isinstance(base_model_config, PretrainedConfig):
57
+ base_model_config = base_model_config.to_dict()
58
+ # Merge all backbone fields into this config
59
+ for k, v in base_model_config.items():
60
+ setattr(self, k, v)
61
+
62
+ self.base_model_name = base_model_name
63
+ self._name_or_path = (
64
+ name_or_path
65
+ if name_or_path is not None
66
+ else "Titans-" + base_model_name.split("/", 1)[1]
67
+ )
68
+
69
+ self.target_modules_names = target_modules_names or [
70
+ "attn",
71
+ "self_attn",
72
+ "attention",
73
+ ]
74
+ self.operator_mode = operator_mode
75
+ self.mag_weight = mag_weight
76
+ self.max_chunk_size = max_chunk_size
77
+ self.max_self_attn_length = max_self_attn_length
78
+
79
+ self.lora_config = lora_config
80
+ if lora_config is not None:
81
+ if hasattr(self.lora_config.get("peft_type"), "value"):
82
+ self.lora_config["peft_type"] = self.lora_config["peft_type"].value
83
+ self.lora_config = convert_sets_to_lists(self.lora_config)
84
+
85
+ super().__init__(**kwargs)
86
+ # Copy class attributes to instance for serialization (save dict)
87
+ self.model_type = self.__class__.model_type
88
+ self.auto_map = self.__class__.auto_map
89
+ self.architectures = self.__class__.architectures
90
+
91
+
92
+ TpttConfig.register_for_auto_class()
93
+
94
+
95
+ def extract_template_variables(template):
96
+ return set(re.findall(r"\{([^{}]+)\}", template))
97
+
98
+
99
+ def generate_model_card(path: str, config, **kwargs):
100
+ """Generate model card from template and training metadata."""
101
+ template_path = os.path.join(os.path.dirname(__file__), "model_card_template.md")
102
+ with open(template_path, "r", encoding="utf-8") as f:
103
+ template = f.read()
104
+
105
+ # Flatten config
106
+ def flatten_config(config):
107
+ result = {}
108
+ if hasattr(config, "__dict__"):
109
+ config = config.__dict__
110
+ for k, v in config.items():
111
+ if isinstance(v, dict):
112
+ for subk, subv in v.items():
113
+ result[f"{k}_{subk}"] = subv
114
+ else:
115
+ result[k] = v
116
+ return result
117
+
118
+ variables = flatten_config(config)
119
+ variables.update(kwargs)
120
+ variables["model_id"] = os.path.basename(path)
121
+
122
+ # Extract variables from template
123
+ template_vars = extract_template_variables(template)
124
+
125
+ # Add default values for missing variables
126
+ for var in template_vars:
127
+ if var not in variables:
128
+ variables[var] = "N/A"
129
+
130
+ # Handle list conversion (optional but useful)
131
+ for k, v in variables.items():
132
+ if isinstance(v, list):
133
+ variables[k] = ", ".join(map(str, v))
134
+
135
+ model_card_content = template.format(**variables)
136
+ with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
137
+ f.write(model_card_content)
generation_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.49.0"
4
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4baae0716ce1cff5226e6d2078a6bb2560b18d9fd6e47f1ca4caef1469f8dd75
3
+ size 3096231760
modeling_tptt.py ADDED
@@ -0,0 +1,778 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module implements the TPTT model with linear attention (LiZA) and LoRA support.
3
+ Author : Fabien FURFARO
4
+ """
5
+
6
+ import os
7
+ import re
8
+ import shutil
9
+ from typing import Dict, List, Optional
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from einops import rearrange
14
+ from huggingface_hub import hf_hub_download, list_repo_files
15
+ from peft import LoraConfig, get_peft_model
16
+ from safetensors import safe_open
17
+ from torch import nn
18
+ from transformers import AutoModelForCausalLM, DynamicCache, PreTrainedModel
19
+ from transformers.configuration_utils import PretrainedConfig
20
+
21
+ from .configuration_tptt import TpttConfig
22
+
23
+
24
+ def import_fla_ops():
25
+ """flash linear attention"""
26
+ if torch.cuda.is_available():
27
+ try:
28
+ from fla.ops.gla import fused_chunk_gla, fused_recurrent_gla
29
+
30
+ return fused_chunk_gla, fused_recurrent_gla
31
+ except ImportError:
32
+ return None, None
33
+ return None, None
34
+
35
+
36
+ fused_chunk_gla, fused_recurrent_gla = import_fla_ops() # TODO: add all ops
37
+
38
+
39
+ class LCache:
40
+ """
41
+ Cache for storing intermediate states of linear attention layers.
42
+ Supports a sliding window if max_length is set.
43
+ """
44
+
45
+ def __init__(self):
46
+ """
47
+ Initialize the cache.
48
+
49
+ Args:
50
+ max_length (Optional[int]): Maximum number of tokens to keep per layer (if set).
51
+ """
52
+ self.states: List[Dict[str, torch.Tensor]] = []
53
+ self.seen_tokens = 0
54
+
55
+ def __getitem__(self, layer_idx: int) -> Optional[Dict[str, torch.Tensor]]:
56
+ """
57
+ Retrieve the state for the given layer index, if it exists.
58
+ """
59
+ if layer_idx < len(self.states):
60
+ return self.states[layer_idx]
61
+ return None
62
+
63
+ def update(self, layer_idx: int, **kwargs):
64
+ """
65
+ Update the cache for a given layer.
66
+ If max_length is set, keep only the last max_length tokens in any sequence state.
67
+ """
68
+ detached_kwargs = {}
69
+ for key, value in kwargs.items():
70
+ if isinstance(value, torch.Tensor):
71
+ value = value.detach()
72
+ detached_kwargs[key] = value
73
+
74
+ if len(self.states) <= layer_idx:
75
+ self.states.append(detached_kwargs)
76
+ else:
77
+ self.states[layer_idx].update(detached_kwargs)
78
+
79
+ def reset(self):
80
+ """
81
+ Reset the cache and token counter.
82
+ """
83
+ self.states.clear()
84
+ self.seen_tokens = 0
85
+
86
+
87
+ class LiZAttention(nn.Module):
88
+ """LiZA Linear Attention module, mixing linear and vanilla attention."""
89
+
90
+ def __init__(
91
+ self,
92
+ base_attn: nn.Module,
93
+ layer_idx: int,
94
+ base_config, # Backbone Config
95
+ linear_cache: LCache = None,
96
+ operator_mode: str = "delta_rule",
97
+ max_self_attn_length: int = 2048,
98
+ mag_weight: float = 0.5,
99
+ max_chunk_size: int = 64,
100
+ ):
101
+ super().__init__()
102
+ self.base_attn = base_attn
103
+ self.base_config = base_config
104
+ self.layer_idx = layer_idx
105
+ self.max_self_attn_length = max_self_attn_length
106
+ self.mag_weight = mag_weight
107
+ self.max_chunk_size = max_chunk_size
108
+ self.linear_cache = linear_cache or LCache()
109
+ (
110
+ self.num_heads,
111
+ self.head_dim,
112
+ self.num_key_value_heads,
113
+ self.num_key_value_groups,
114
+ ) = self._get_attention_parameters(base_attn, base_config)
115
+ self.operator = get_attention_operator(operator_mode)
116
+ self.pool_g = nn.AdaptiveAvgPool1d(
117
+ output_size=self.head_dim * self.num_key_value_heads
118
+ )
119
+
120
+ def _get_attention_parameters(self, base_attn, base_config):
121
+ """Retrieve the attention parameters from the base attention module."""
122
+ # first order base attention module and second order config
123
+ num_heads = (
124
+ getattr(base_attn, "num_heads", None)
125
+ or getattr(base_attn, "num_q_heads", None)
126
+ or getattr(base_config, "num_heads", None)
127
+ or getattr(base_config, "num_attention_heads", None)
128
+ )
129
+ head_dim = getattr(base_attn, "head_dim", None) or getattr(
130
+ base_config, "head_dim", None
131
+ )
132
+ num_key_value_heads = (
133
+ getattr(base_attn, "num_kv_heads", None)
134
+ or getattr(base_attn, "num_k_heads", None)
135
+ or getattr(base_config, "num_key_value_heads", None)
136
+ or num_heads # fallback
137
+ )
138
+ num_key_value_groups = getattr(base_attn, "num_key_value_groups", None) or (
139
+ num_heads // num_key_value_heads if num_heads and num_key_value_heads else 1
140
+ )
141
+ return (
142
+ num_heads,
143
+ head_dim,
144
+ num_key_value_heads,
145
+ num_key_value_groups,
146
+ )
147
+
148
+ def _apply_projections(self, hidden_states):
149
+ base_attn = self.base_attn
150
+ if hasattr(base_attn, "q_proj"):
151
+ # LLama, OLMO and Mistral style
152
+ q = base_attn.q_proj(hidden_states)
153
+ k = base_attn.k_proj(hidden_states)
154
+ v = base_attn.v_proj(hidden_states)
155
+ out_proj = base_attn.o_proj
156
+ elif hasattr(base_attn, "qkv_proj"):
157
+ # OpenELM and GPT-Neo style : QKV fused, split on the last dimension
158
+ qkv = base_attn.qkv_proj(hidden_states)
159
+ q, k, v = split_qkv(base_attn, qkv)
160
+ out_proj = base_attn.out_proj
161
+ elif hasattr(base_attn, "c_attn") and hasattr(base_attn, "c_proj"):
162
+ # GPT-2 style
163
+ qkv = base_attn.c_attn(hidden_states)
164
+ q, k, v = qkv.chunk(3, dim=-1)
165
+ out_proj = base_attn.c_proj
166
+ else:
167
+ raise ValueError("Unsupported attention module: cannot find projections.")
168
+ # Ensure stability
169
+ q = torch.clamp(q, min=-1e4, max=1e4)
170
+ k = torch.clamp(k, min=-1e4, max=1e4)
171
+ v = torch.clamp(v, min=-1e4, max=1e4)
172
+ return q, k, v, out_proj
173
+
174
+ def _prepare_attn_input(self, q, k, v, gate_norm):
175
+ # Gating for linear attn
176
+ g = self.pool_g(k)
177
+
178
+ # Reshape for multi-head
179
+ q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
180
+ k = rearrange(k, "b n (h d) -> b h n d", h=self.num_key_value_heads)
181
+ v = rearrange(v, "b n (h d) -> b h n d", h=self.num_key_value_heads)
182
+ g = rearrange(g, "b n (h m) -> b h n m", h=self.num_key_value_heads)
183
+
184
+ # Repeat for GQA
185
+ k = repeat_kv(k, self.num_key_value_groups)
186
+ v = repeat_kv(v, self.num_key_value_groups)
187
+ g = repeat_kv(g, self.num_key_value_groups)
188
+
189
+ ## linear part
190
+ q = torch.clamp(F.softmax(q, dim=-1), min=1e-6, max=1 - 1e-6)
191
+ k = torch.clamp(F.softmax(k, dim=-1), min=1e-6, max=1 - 1e-6)
192
+
193
+ g = F.logsigmoid(g) / gate_norm
194
+ g = torch.clamp(g, min=-gate_norm, max=gate_norm)
195
+
196
+ # Convert to float32 for numerical stability and get model dtype
197
+ q, k, v, g = (x.to(torch.float32).contiguous() for x in (q, k, v, g))
198
+
199
+ return q, k, v, g
200
+
201
+ def _process_linear_attn(self, q, k, v, g, out_proj, tensor_dtype, kwargs):
202
+ # Retrieve recurrent state from cache (inference only)
203
+ if kwargs["use_cache"]:
204
+ last_state = self.linear_cache[self.layer_idx]
205
+ recurrent_state = (
206
+ last_state["recurrent_state"]
207
+ if last_state is not None and "recurrent_state" in last_state
208
+ else None
209
+ )
210
+ else:
211
+ recurrent_state = None
212
+
213
+ # Linear attention
214
+ o_lin, recurrent_state = self.operator(
215
+ q,
216
+ k,
217
+ v,
218
+ beta=g,
219
+ chunk_size=self.max_chunk_size,
220
+ recurrent_state=recurrent_state,
221
+ )
222
+ o_lin = rearrange(o_lin, "b h n d -> b n (h d)").to(tensor_dtype)
223
+ o_lin = out_proj(o_lin)
224
+ # Ensure stability (o_lin = soft_clamp(o_lin) ?)
225
+ o_lin = torch.clamp(o_lin, min=-1e4, max=1e4)
226
+
227
+ # Save recurrent state
228
+ if kwargs["use_cache"]:
229
+ self.linear_cache.update(self.layer_idx, recurrent_state=recurrent_state)
230
+ return o_lin
231
+
232
+ def _process_self_attn(self, hidden_states, attention_mask, kwargs):
233
+ # If cache_implementation="static" -> truncated attention
234
+ hidden_states, attention_mask = truncate_attention_mask(
235
+ hidden_states, attention_mask, self.max_self_attn_length
236
+ )
237
+
238
+ if kwargs.get("position_embeddings", None) is not None:
239
+ cos, sin = kwargs["position_embeddings"]
240
+ cos = cos[:, -self.max_self_attn_length :]
241
+ sin = sin[:, -self.max_self_attn_length :]
242
+ kwargs["position_embeddings"] = (cos, sin)
243
+
244
+ if isinstance(kwargs.get("past_key_value", None), DynamicCache):
245
+ # cache management
246
+ if len(kwargs["past_key_value"]) > self.layer_idx and self.layer_idx == 0:
247
+ kwargs["past_key_value"].crop(self.max_self_attn_length - 1)
248
+
249
+ # Standard attention (mask and rotation is applied inside)
250
+ base_attn_outputs = self.base_attn(
251
+ hidden_states,
252
+ attention_mask=attention_mask,
253
+ **kwargs,
254
+ )
255
+
256
+ if isinstance(base_attn_outputs, tuple):
257
+ if len(base_attn_outputs) == 3:
258
+ o_base, attn_weights, present_key_value = base_attn_outputs
259
+ expected_attn_mode = 3
260
+ elif len(base_attn_outputs) == 2:
261
+ o_base, attn_weights = base_attn_outputs
262
+ present_key_value, expected_attn_mode = None, 2
263
+ else:
264
+ raise ValueError(
265
+ f"Unexpected number of outputs from base_attn: {len(base_attn_outputs)}"
266
+ )
267
+ else:
268
+ o_base = base_attn_outputs
269
+ attn_weights, present_key_value, expected_attn_mode = None, None, 1
270
+ # Ensure stability
271
+ o_base = torch.clamp(o_base, min=-1e4, max=1e4)
272
+ return o_base, attn_weights, present_key_value, expected_attn_mode
273
+
274
+ def forward(
275
+ self,
276
+ hidden_states: torch.Tensor,
277
+ attention_mask: Optional[torch.Tensor] = None,
278
+ **kwargs,
279
+ ):
280
+ device = hidden_states.device
281
+ tensor_dtype = hidden_states.dtype
282
+ self.base_attn.to(device)
283
+
284
+ if self.training:
285
+ kwargs.pop("past_key_value", None)
286
+ kwargs["use_cache"] = False
287
+ else:
288
+ # Force evaluation
289
+ kwargs["use_cache"] = True
290
+
291
+ kwargs.pop("position_ids", None) # obsolete
292
+
293
+ # Apply projections to hidden states
294
+ q, k, v, out_proj = self._apply_projections(hidden_states)
295
+
296
+ # Manage attention mask (with padding)
297
+ if attention_mask is not None:
298
+ # attention_mask -> [batch, seq], v: [batch, seq, ...]
299
+ v = apply_linear_attention_mask(attention_mask, v)
300
+
301
+ # Prepare inputs tensor for linear attn
302
+ gate_norm = kwargs.get("gate_logit_normalizer", 16)
303
+ q, k, v, g = self._prepare_attn_input(q, k, v, gate_norm)
304
+
305
+ # Process linear attn from mask
306
+ o_lin = self._process_linear_attn(q, k, v, g, out_proj, tensor_dtype, kwargs)
307
+
308
+ # Process self attn with truncation
309
+ o_base, attn_weights, present_key_value, expected_attn_mode = (
310
+ self._process_self_attn(hidden_states, attention_mask, kwargs)
311
+ )
312
+
313
+ # Force cast typing
314
+ o_lin = o_lin.to(tensor_dtype)
315
+ o_base = o_base.to(tensor_dtype)
316
+
317
+ # Apply Memory as Gate in self-attention (with max length management)
318
+ if o_lin.shape[1] > o_base.shape[1]:
319
+ o_padding = torch.zeros_like(o_lin).to(tensor_dtype)
320
+ o_padding[:, -o_base.shape[1] :] = o_base
321
+ o_base = o_padding # Left PAD mask
322
+ elif o_lin.shape[1] != o_base.shape[1]: # Abnormality
323
+ left_trunc = min(o_lin.shape[1], o_base.shape[1])
324
+ o_lin, o_base = o_lin[:, -left_trunc:], o_base[:, -left_trunc:]
325
+ out = self.mag_weight * o_lin + (1 - self.mag_weight) * o_base
326
+ # Ensure stability
327
+ out = torch.clamp(out, min=-1e4, max=1e4)
328
+
329
+ # Return output following transformer convention
330
+ if expected_attn_mode == 3:
331
+ return out, attn_weights, present_key_value
332
+ elif expected_attn_mode == 2:
333
+ return out, attn_weights
334
+ else:
335
+ return out
336
+
337
+
338
+ def get_tptt_model( # pylint: disable=too-many-arguments, too-many-positional-arguments
339
+ model: nn.Module,
340
+ base_config: PretrainedConfig, # ou LlamaConfig, MistralConfig, etc.
341
+ liza_attention: LiZAttention,
342
+ target_modules: list,
343
+ linear_cache: LCache = None,
344
+ operator_mode: str = "delta_rule",
345
+ mag_weight: float = 0.5,
346
+ max_chunk_size: int = 64,
347
+ max_self_attn_length: int = 2048,
348
+ ):
349
+ """Replace target modules in a model with LiZAttention."""
350
+ linear_cache = linear_cache or LCache()
351
+ # Inject LiZAttention into the model
352
+ for name, _ in model.named_modules():
353
+ if name in target_modules:
354
+ parent = model
355
+ *path, last = name.split(".")
356
+ for p in path:
357
+ parent = getattr(parent, p)
358
+ layer_idx = extract_layer_idx(name)
359
+ setattr(
360
+ parent,
361
+ last,
362
+ liza_attention(
363
+ getattr(parent, last),
364
+ layer_idx=layer_idx,
365
+ base_config=base_config,
366
+ linear_cache=linear_cache,
367
+ operator_mode=operator_mode,
368
+ max_self_attn_length=max_self_attn_length,
369
+ mag_weight=mag_weight,
370
+ max_chunk_size=max_chunk_size,
371
+ ),
372
+ )
373
+ return model, linear_cache
374
+
375
+
376
+ class TpttModel(PreTrainedModel):
377
+ """
378
+ TPTT model wrapper with linear attention (LiZA) and LoRA support.
379
+ Handles only architecture and weights.
380
+ """
381
+
382
+ config_class = TpttConfig
383
+
384
+ def __init__(
385
+ self,
386
+ config: TpttConfig,
387
+ **kwargs,
388
+ ):
389
+ """
390
+ Initialize TpttModel with a given config and backbone.
391
+ Injects LiZA attention modules into the backbone.
392
+ """
393
+ super().__init__(config, **kwargs)
394
+ repo_or_path = getattr(config, "_base_path", None) or config._name_or_path
395
+
396
+ # 1. Load backbone
397
+ backbone = AutoModelForCausalLM.from_pretrained(
398
+ config.base_model_name, **kwargs
399
+ )
400
+
401
+ # 2. Inject LiZA attention
402
+ self.linear_cache = LCache()
403
+ self.backbone, self.linear_cache = self.inject_liza_attention(
404
+ backbone, config, self.linear_cache
405
+ )
406
+ # 3. Apply LoRA if present of configured
407
+ if config.lora_config is not None:
408
+ lora_config_obj = LoraConfig(**config.lora_config)
409
+ self.backbone = get_peft_model(self.backbone, lora_config_obj)
410
+ if repo_or_path:
411
+ self.load_peft_safetensors(
412
+ repo_or_path, token=kwargs.get("token", None)
413
+ )
414
+
415
+ def load_peft_safetensors(self, src, token=None):
416
+ # src: local dir or repo_id
417
+ fname = "adapter_model.safetensors"
418
+ if os.path.isdir(src):
419
+ path = os.path.join(src, fname)
420
+ if not os.path.exists(path):
421
+ return
422
+ else:
423
+ if fname not in list_repo_files(src, token=token):
424
+ return
425
+ path = hf_hub_download(src, fname, token=token)
426
+ with safe_open(path, framework="pt") as f:
427
+ self.backbone.load_state_dict(
428
+ {k: f.get_tensor(k) for k in f.keys()}, strict=False
429
+ )
430
+
431
+ @staticmethod
432
+ def inject_liza_attention(
433
+ backbone,
434
+ config,
435
+ linear_cache,
436
+ ):
437
+ """
438
+ Inject LiZAttention into the specified target modules of the base model.
439
+ """
440
+ # Find target modules by suffix (e.g., "attn", "attention")
441
+ target_modules = [
442
+ name
443
+ for name, _ in backbone.named_modules()
444
+ if any(name.endswith(suffix) for suffix in config.target_modules_names)
445
+ ]
446
+ if not target_modules:
447
+ raise ValueError(
448
+ f"Target modules '{config.target_modules_names}' not found in the model."
449
+ )
450
+ # Inject LiZAttention (external function, not shown here)
451
+ return get_tptt_model(
452
+ backbone,
453
+ base_config=backbone.config,
454
+ liza_attention=LiZAttention,
455
+ target_modules=target_modules,
456
+ linear_cache=linear_cache,
457
+ operator_mode=config.operator_mode,
458
+ max_self_attn_length=config.max_self_attn_length,
459
+ mag_weight=config.mag_weight,
460
+ max_chunk_size=config.max_chunk_size,
461
+ )
462
+
463
+ def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
464
+ """
465
+ Forward pass. All arguments are passed to the underlying base model.
466
+ """
467
+ if self.training:
468
+ kwargs["use_cache"] = False
469
+ kwargs.pop("num_items_in_batch", None)
470
+ else:
471
+ kwargs["use_cache"] = True
472
+ return self.backbone(
473
+ input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
474
+ )
475
+
476
+ def generate(self, *args, **kwargs):
477
+ # Delegate the generate call to the backbone model, which supports generation
478
+ return self.backbone.generate(*args, **kwargs)
479
+
480
+ def save_pretrained(self, path: str, **kwargs):
481
+ """Save model weights, config, and source code to the given path."""
482
+ super().save_pretrained(path, **kwargs)
483
+
484
+ # 1. Save PEFT weights and clean adapter config
485
+ self._save_peft_weights(path, **kwargs)
486
+ # 2. Copy Python files for trust_remote_code
487
+ self._copy_source_files(path)
488
+
489
+ def _save_peft_weights(self, path: str, **kwargs):
490
+ """Save PEFT weights and remove redundant adapter config."""
491
+ self.backbone.save_pretrained(path, **kwargs)
492
+ adapter_config_path = os.path.join(path, "adapter_config.json")
493
+ if os.path.exists(adapter_config_path):
494
+ os.remove(adapter_config_path)
495
+
496
+ def _copy_source_files(self, path: str):
497
+ """Copy all .py files from package directory for trust_remote_code."""
498
+ src_dir = os.path.dirname(os.path.abspath(__file__))
499
+ for fname in os.listdir(src_dir):
500
+ if fname.endswith(".py"):
501
+ src = os.path.join(src_dir, fname)
502
+ dst = os.path.join(path, fname)
503
+ shutil.copy2(src, dst)
504
+
505
+
506
+ TpttModel.register_for_auto_class("AutoModelForCausalLM")
507
+
508
+
509
+ class AttentionOperator(nn.Module):
510
+ """Base class for linear attention operators."""
511
+
512
+ def __init__(self, mode="delta_rule"):
513
+ super().__init__()
514
+ self.mode = mode
515
+
516
+ def forward(self, q, k, v, **options):
517
+ """Forward pass for the attention operator."""
518
+ beta = options.get("beta", None)
519
+ chunk_size = options.get("chunk_size", 64)
520
+ scale = options.get("scale", 1)
521
+ recurrent_state = options.get("recurrent_state", None)
522
+
523
+ if self.mode == "delta_rule":
524
+ return self.chunk_delta_rule_forward(
525
+ q, k, v, beta, chunk_size, initial_state=recurrent_state
526
+ )
527
+ if self.mode == "gla":
528
+ return self.gla_forward(q, k, v, beta, scale, initial_state=recurrent_state)
529
+ raise ValueError(f"Unknown operator mode: {self.mode}")
530
+
531
+ @staticmethod
532
+ def chunk_delta_rule_forward(
533
+ query, key, value, beta, chunk_size, initial_state=None
534
+ ):
535
+ """
536
+ Implementation of https://arxiv.org/abs/2406.06484
537
+ query, key, value, beta: [batch, num_heads, seq_len, head_dim]
538
+ chunk_size: int
539
+ initial_state: [batch, num_heads, head_dim, head_dim] or None
540
+ """
541
+ batch_size, num_heads, seq_len, head_dim = query.shape
542
+ chunk_size = get_valid_chunk_size(seq_len, chunk_size)
543
+ num_chunks = seq_len // chunk_size
544
+
545
+ # Reshape for chunking: [batch, num_heads, num_chunks, chunk_size, head_dim]
546
+ q_chunks = query.reshape(
547
+ batch_size, num_heads, num_chunks, chunk_size, head_dim
548
+ )
549
+ k_chunks = key.reshape(batch_size, num_heads, num_chunks, chunk_size, head_dim)
550
+ v_chunks = value.reshape(
551
+ batch_size, num_heads, num_chunks, chunk_size, head_dim
552
+ )
553
+ beta_chunks = beta.reshape(
554
+ batch_size, num_heads, num_chunks, chunk_size, head_dim
555
+ )
556
+
557
+ # Output buffer
558
+ output = torch.empty_like(q_chunks)
559
+ # State: [batch, num_heads, head_dim, head_dim]
560
+ if initial_state is not None:
561
+ state = initial_state
562
+ else:
563
+ state = torch.zeros(
564
+ batch_size,
565
+ num_heads,
566
+ head_dim,
567
+ head_dim,
568
+ device=query.device,
569
+ dtype=query.dtype,
570
+ )
571
+
572
+ def process_chunk(q, k, v, b, state):
573
+ """
574
+ q, k, v, b: [batch, num_heads, chunk_size, head_dim]
575
+ state: [batch, num_heads, head_dim, head_dim]
576
+ Returns: (output_chunk, new_state)
577
+ """
578
+ # Clamp to avoid numerical instabilities (not in paper)
579
+ k = torch.clamp(k, min=-1e4, max=1e4)
580
+ v = torch.clamp(v, min=-1e4, max=1e4)
581
+ b = torch.clamp(b, min=1e-6, max=1e4)
582
+ q = torch.clamp(q, min=-1e4, max=1e4)
583
+
584
+ # Eq. (10): β_t * k_t and β_t * v_t
585
+ k_beta = k * b
586
+ v_beta = v * b
587
+
588
+ # Eq. (11): Lower-triangular matrix T (with -KβK^T off-diagonal, 1 on diagonal)
589
+ # T = I - tril(KβK^T, -1)
590
+ t_matrix = -(k_beta @ k.transpose(-2, -1)).tril(-1)
591
+ t_matrix = torch.clamp(t_matrix, min=-1e4, max=1e4)
592
+ t_matrix = t_matrix + torch.eye(
593
+ q.shape[-2], device=q.device, dtype=q.dtype
594
+ ).unsqueeze(0).unsqueeze(0)
595
+
596
+ # Eq. (11): W = T Kβ, U = T Vβ
597
+ w_matrix = t_matrix @ k_beta
598
+ w_matrix = torch.clamp(w_matrix, min=-1e4, max=1e4)
599
+
600
+ u_matrix = t_matrix @ v_beta
601
+ u_matrix = torch.clamp(u_matrix, min=-1e4, max=1e4)
602
+
603
+ # Eq. (12): u_i = U - W S (S = state)
604
+ u_i = u_matrix - torch.matmul(w_matrix, state)
605
+
606
+ # Eq. (12): inter-chunk output: q S
607
+ o_inter = torch.matmul(q, state)
608
+
609
+ # Eq. (12): intra-chunk attention: tril(q K^T)
610
+ a_i = (q @ k.transpose(-2, -1)).tril()
611
+
612
+ # Eq. (12): intra-chunk output: a_i u_i
613
+ o_intra = torch.matmul(a_i, u_i)
614
+
615
+ # Eq. (12): state update: S_new = S + K^T u_i
616
+ new_state = state + torch.matmul(k.transpose(-2, -1), u_i)
617
+ new_state = torch.clamp(new_state, min=-1e4, max=1e4)
618
+
619
+ # Eq. (12): output = intra + inter
620
+ return o_intra + o_inter, new_state
621
+
622
+ for chunk_idx in range(num_chunks):
623
+ q = q_chunks[:, :, chunk_idx]
624
+ k = k_chunks[:, :, chunk_idx]
625
+ v = v_chunks[:, :, chunk_idx]
626
+ b = beta_chunks[:, :, chunk_idx]
627
+
628
+ chunk_out, state = process_chunk(q, k, v, b, state)
629
+ output[:, :, chunk_idx] = chunk_out
630
+
631
+ # Reshape back to [batch, num_heads, seq_len, head_dim]
632
+ output = output.reshape(batch_size, num_heads, seq_len, head_dim)
633
+ return output, state
634
+
635
+ @staticmethod
636
+ def gla_forward(q, k, v, beta, scale, initial_state=None):
637
+ """Forward pass for GLA attention operator."""
638
+ if fused_chunk_gla is None or fused_recurrent_gla is None:
639
+ raise RuntimeError("GLA kernels are not available: CUDA required.")
640
+ if q.shape[-2] > 1:
641
+ # Training or sequence length > 1
642
+ return fused_chunk_gla(
643
+ q,
644
+ k,
645
+ v,
646
+ beta,
647
+ scale=scale,
648
+ initial_state=initial_state,
649
+ output_final_state=True,
650
+ )
651
+ return fused_recurrent_gla(
652
+ q,
653
+ k,
654
+ v,
655
+ beta,
656
+ scale=scale,
657
+ initial_state=initial_state,
658
+ output_final_state=True,
659
+ )
660
+
661
+
662
+ def get_attention_operator(mode):
663
+ """Factory for AttentionOperator."""
664
+ return AttentionOperator(mode=mode)
665
+
666
+
667
+ def extract_layer_idx(module_name: str) -> int:
668
+ """
669
+ Extract the layer index from a module name string.
670
+ """
671
+ match = re.search(r"\.(\d+)\.", module_name)
672
+ if match:
673
+ return int(match.group(1))
674
+ return -1
675
+
676
+
677
+ def soft_clamp(x, min_val=-1e4, max_val=1e4):
678
+ """Differentiable clamping for stability"""
679
+ scale = (max_val - min_val) / 2
680
+ center = (max_val + min_val) / 2
681
+ return torch.tanh((x - center) / scale) * scale + center
682
+
683
+
684
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
685
+ """Repeat key/value heads for grouped query attention (GQA)."""
686
+ return x.repeat_interleave(n_rep, dim=1)
687
+
688
+
689
+ def split_qkv(base_attn, qkv):
690
+ """Split the QKV tensor into separate Q, K, and V tensors."""
691
+ num_q_heads = getattr(base_attn, "num_q_heads", None)
692
+ num_k_heads = getattr(base_attn, "num_k_heads", None)
693
+ num_v_heads = getattr(base_attn, "num_v_heads", None)
694
+ head_dim = getattr(base_attn, "head_dim", None)
695
+
696
+ q_len = num_q_heads * head_dim
697
+ k_len = num_k_heads * head_dim
698
+ v_len = num_v_heads * head_dim
699
+
700
+ q, k, v = torch.split(qkv, [q_len, k_len, v_len], dim=-1)
701
+ return q, k, v
702
+
703
+
704
+ def apply_linear_attention_mask(attention_mask, v):
705
+ # extract (if) padding mask
706
+ if attention_mask.dim() == 4 and attention_mask.shape[1] == 1:
707
+ # [batch, 1, seq, seq] -> [batch, seq]
708
+ mask = attention_mask.diagonal(dim1=-2, dim2=-1).squeeze(1)
709
+ else:
710
+ # Squeeze all singleton dims except batch (dim=0)
711
+ mask = attention_mask.squeeze(
712
+ dim=tuple(
713
+ i
714
+ for i in range(1, attention_mask.dim())
715
+ if attention_mask.shape[i] == 1
716
+ )
717
+ )
718
+ # handle left padding : mask is [batch, seq] --> Broadcast to v [batch, seq, (...)]
719
+ mask = mask[:, -v.shape[-2] :][(...,) + (None,) * (v.dim() - 2)]
720
+ return v * mask
721
+
722
+
723
+ def truncate_attention_mask(hidden_states, attention_mask, max_length):
724
+ """
725
+ Truncate hidden_states and attention_mask to the last window of size max_length,
726
+ matching the sequence dimension of hidden_states.
727
+ """
728
+ seq_dim = 1 # convention: (batch, seq, ...)
729
+ seq_len = hidden_states.shape[seq_dim]
730
+ if seq_len > max_length:
731
+ hidden_states = hidden_states.narrow(seq_dim, seq_len - max_length, max_length)
732
+ if attention_mask is not None:
733
+ # mask [batch, seq]
734
+ if attention_mask.dim() == 2:
735
+ attention_mask = attention_mask[:, -max_length:]
736
+ # mask [batch, seq, seq]
737
+ elif attention_mask.dim() == 3:
738
+ attention_mask = attention_mask[:, -max_length:, -max_length:]
739
+ # mask [batch, 1, seq, seq]
740
+ elif attention_mask.dim() == 4 and attention_mask.shape[1] == 1:
741
+ attention_mask = attention_mask[:, :, -max_length:, -max_length:]
742
+ else:
743
+ raise ValueError(
744
+ "No dimension in attention_mask matches sequence length of hidden_states."
745
+ )
746
+ return hidden_states, attention_mask
747
+
748
+
749
+ def get_valid_chunk_size(total_l: int, chunk_size: int) -> int:
750
+ """
751
+ Return the largest chunk_size <= chunk_size that divides total_l.
752
+ If no chunk_size > 1 fits, return 1.
753
+ """
754
+ for c in range(min(chunk_size, total_l), 0, -1):
755
+ if total_l % c == 0:
756
+ return c
757
+ return 1
758
+
759
+
760
+ def match_dim(x: torch.Tensor, dim: int, target_size: int) -> torch.Tensor:
761
+ """
762
+ Match the size of tensor x along dimension dim to target_size by interpolation
763
+ or projection.
764
+ """
765
+ src_size = x.shape[dim]
766
+ if src_size == target_size:
767
+ return x
768
+ x = torch.moveaxis(x, dim, -1)
769
+ shape = x.shape
770
+ if src_size < target_size:
771
+ x = x.reshape(-1, 1, src_size)
772
+ x = F.interpolate(x, size=target_size, mode="linear", align_corners=False)
773
+ x = x.reshape(*shape[:-1], target_size)
774
+ else:
775
+ eye = torch.eye(target_size, src_size, device=x.device, dtype=x.dtype)
776
+ x = F.linear(x, eye) # pylint: disable=not-callable
777
+ x = torch.moveaxis(x, -1, dim)
778
+ return x
pipeline_tptt.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import Pipeline
3
+
4
+
5
+ class TpttPipeline(Pipeline):
6
+ """Pipeline for TPTT model inference."""
7
+
8
+ def __init__(self, model, tokenizer, device=None, **kwargs):
9
+ """
10
+ Initialize TpttPipeline.
11
+ """
12
+ super().__init__(model=model, tokenizer=tokenizer, device=device, **kwargs)
13
+
14
+ def _sanitize_parameters(self, **kwargs):
15
+ # No special parameter handling for now
16
+ preprocess_kwargs = {}
17
+ forward_kwargs = {}
18
+ postprocess_kwargs = {}
19
+ return preprocess_kwargs, forward_kwargs, postprocess_kwargs
20
+
21
+ def preprocess(self, prompt):
22
+ # Tokenize the input prompt
23
+ return self.tokenizer(prompt, return_tensors="pt", truncation=False)
24
+
25
+ def _forward(self, model_inputs, **forward_params):
26
+ # Move tensors to the correct device
27
+ model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}
28
+ # Use generate for text generation
29
+ with torch.no_grad():
30
+ output = self.model.generate(
31
+ **model_inputs,
32
+ max_new_tokens=forward_params.get("max_new_tokens", 50),
33
+ do_sample=forward_params.get("do_sample", False),
34
+ # cache_implementation=forward_params.get("cache_implementation", "static"),
35
+ )
36
+ return {"generated_ids": output}
37
+
38
+ def postprocess(self, model_outputs):
39
+ # Decode the generated ids into text
40
+ generated_ids = model_outputs["generated_ids"]
41
+ return [
42
+ {"generated_text": self.tokenizer.decode(ids, skip_special_tokens=True)}
43
+ for ids in generated_ids
44
+ ]
runs/Jun09_19-12-59_bb326d838c91/events.out.tfevents.1749496381.bb326d838c91.35.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3ef1c7c8b79507c60f72d6a4983b9c857e9aa9fdc08432b69f01ec64d20697e
3
+ size 7060
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:540b7fbf60b80e8293593a86960df91d2263723d69107ffc1afc89a7c08cda12
3
+ size 11422162
tokenizer_config.json ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ }
181
+ },
182
+ "additional_special_tokens": [
183
+ "<|im_start|>",
184
+ "<|im_end|>",
185
+ "<|object_ref_start|>",
186
+ "<|object_ref_end|>",
187
+ "<|box_start|>",
188
+ "<|box_end|>",
189
+ "<|quad_start|>",
190
+ "<|quad_end|>",
191
+ "<|vision_start|>",
192
+ "<|vision_end|>",
193
+ "<|vision_pad|>",
194
+ "<|image_pad|>",
195
+ "<|video_pad|>"
196
+ ],
197
+ "bos_token": null,
198
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
199
+ "clean_up_tokenization_spaces": false,
200
+ "eos_token": "<|endoftext|>",
201
+ "errors": "replace",
202
+ "extra_special_tokens": {},
203
+ "model_max_length": 131072,
204
+ "pad_token": "<|endoftext|>",
205
+ "split_special_tokens": false,
206
+ "tokenizer_class": "Qwen2Tokenizer",
207
+ "unk_token": null
208
+ }
train_tptt.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author : Fabien FURFARO
3
+ """
4
+
5
+ from transformers import TrainerCallback
6
+
7
+ from .modeling_tptt import LiZAttention
8
+
9
+
10
+ class AdjustMaGWeightCallback(TrainerCallback):
11
+ """TrainerCallback to schedule mag_weight during training."""
12
+
13
+ def __init__(
14
+ self, model, initial_weight=0.01, final_weight=0.5, transition_step=500
15
+ ):
16
+ self.model = model
17
+ # Ensure weights are always float scalars, not tuples/lists
18
+ if isinstance(initial_weight, (tuple, list)):
19
+ initial_weight = initial_weight[0]
20
+ if isinstance(final_weight, (tuple, list)):
21
+ final_weight = final_weight[0]
22
+ self.initial_weight = float(initial_weight)
23
+ self.final_weight = float(final_weight)
24
+
25
+ if isinstance(transition_step, (tuple, list)):
26
+ transition_step = transition_step[0]
27
+ self.transition_step = int(transition_step)
28
+
29
+ def on_step_end(self, args, state, control, **kwargs):
30
+ current_step = state.global_step
31
+ transition_step = self.transition_step
32
+
33
+ # Ensure both are plain ints (not tuple, list, tensor, numpy, etc.)
34
+ if isinstance(current_step, (tuple, list)):
35
+ current_step = current_step[0]
36
+ if hasattr(current_step, "item"):
37
+ current_step = int(current_step.item())
38
+ else:
39
+ current_step = int(current_step)
40
+
41
+ if isinstance(transition_step, (tuple, list)):
42
+ transition_step = transition_step[0]
43
+ if hasattr(transition_step, "item"):
44
+ transition_step = int(transition_step.item())
45
+ else:
46
+ transition_step = int(transition_step)
47
+
48
+ if current_step < transition_step:
49
+ weight = self.initial_weight + (self.final_weight - self.initial_weight) * (
50
+ current_step / transition_step
51
+ )
52
+ for _, module in self.model.named_modules():
53
+ if isinstance(module, LiZAttention):
54
+ module.mag_weight = weight
55
+
56
+ def on_log(self, args, state, control, logs=None, **kwargs):
57
+ mag_weight = None
58
+ for _, module in self.model.named_modules():
59
+ if isinstance(module, LiZAttention):
60
+ mag_weight = getattr(module, "mag_weight", None)
61
+ break
62
+ if mag_weight is not None and logs is not None:
63
+ logs["mag_weight"] = float(mag_weight)
vocab.json ADDED
The diff for this file is too large to render. See raw diff