akhauriyash committed
Commit bb3a05c · 1 Parent(s): abd888e

Add final Butler model

__pycache__/modeling_llama_butler.cpython-310.pyc ADDED
Binary file (3.39 kB).

__pycache__/modify_llama.cpython-310.pyc ADDED
Binary file (13.3 kB).

__pycache__/predictor.cpython-310.pyc ADDED
Binary file (14.6 kB).

__pycache__/utils.cpython-310.pyc ADDED
Binary file (33.9 kB).

config.json ADDED
@@ -0,0 +1,52 @@
+ {
+   "architectures": [
+     "LlamaButlerForCausalLM"
+   ],
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "attn_reduce_factor": 8,
+   "auto_map": {
+     "AutoConfig": "modeling_llama_butler:LlamaButlerConfig",
+     "AutoModel": "modeling_llama_butler:LlamaButlerForCausalLM",
+     "AutoModelForCausalLM": "modeling_llama_butler:LlamaButlerForCausalLM"
+   },
+   "bos_token_id": 128000,
+   "dDash": 16,
+   "eos_token_id": 128001,
+   "eval_llm_mode": "ExpPred",
+   "flash_attn": false,
+   "head_attn_reduce_factor": 2,
+   "head_dim": 64,
+   "hidden_act": "silu",
+   "hidden_size": 2048,
+   "initializer_range": 0.02,
+   "intdim": 512,
+   "intermediate_size": 8192,
+   "lookahead": 0,
+   "max_position_embeddings": 131072,
+   "min_sparse_index": 8,
+   "mlp_bias": false,
+   "model_type": "llama_butler",
+   "num_attention_heads": 32,
+   "num_hidden_layers": 16,
+   "num_key_value_heads": 8,
+   "pretraining_tp": 1,
+   "producer_frequency": 16,
+   "rms_norm_eps": 1e-05,
+   "rope_scaling": {
+     "factor": 32.0,
+     "high_freq_factor": 4.0,
+     "low_freq_factor": 1.0,
+     "original_max_position_embeddings": 8192,
+     "rope_type": "llama3"
+   },
+   "rope_theta": 500000.0,
+   "sliding_window": 128,
+   "tie_word_embeddings": true,
+   "token_sparse_method": "fixed_50pc",
+   "torch_dtype": "float32",
+   "train_headpredictor": false,
+   "transformers_version": "4.48.3",
+   "use_cache": true,
+   "vocab_size": 128256
+ }
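
The `auto_map` entries above point the Auto loaders at the custom classes in `modeling_llama_butler.py`, so the exported folder can be loaded without importing those modules by hand. A minimal loading sketch, assuming the current directory `"."` is the saved folder (matching `OUTPUT_DIR` in `conversion.py`) and taking the tokenizer from the base checkpoint since it is not saved here; the prompt string is only an illustration:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # trust_remote_code resolves the auto_map entries to LlamaButlerConfig / LlamaButlerForCausalLM
    model = AutoModelForCausalLM.from_pretrained(".", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

    inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
    out = model.generate(**inputs, max_new_tokens=16)
    print(tokenizer.decode(out[0], skip_special_tokens=True))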
conversion.py ADDED
@@ -0,0 +1,81 @@
+ from transformers import LlamaForCausalLM, LlamaConfig, AutoTokenizer
+ import torch
+ import os
+
+ question = "A $y$-intercept is a point on the graph that lies on the $y$-axis, so $x = 0$. Hence, the number $y$-intercepts corresponds to the number of real solutions of the quadratic equation $y^2 - 4y - 1 = 0$. The discriminant of this quadratic equation is $(-4)^2 + 4 \cdot 1 \cdot (-1) = 20$, which is positive, so the quadratic has two distinct real roots. Therefore, the number of $y$-intercepts is $\boxed{2}$. \n \n [asy] \n size(150); \n real ticklen=3; \n real tickspace=2; \n \n real ticklength=0.1cm; \n real axisarrowsize=0.14cm; \n pen axispen=black+1.3bp; \n real vectorarrowsize=0.2cm; \n real tickdown=-0.5; \n real tickdownlength=-0.15inch; \n real tickdownbase=0.3; \n real wholetickdown=tickdown; \n void rr_cartesian_axes(real xleft, real xright, real ybottom, real ytop, real xstep=1, real ystep=1, bool \n \n useticks=false, bool complexplane=false, bool usegrid=true) { \n \n import graph; \n \n real i; \n \n if(complexplane) { \n \n label('$\textnormal{Re}$',(xright,0),SE); \n \n label('$\textnormal{Im}$',(0,ytop),NW); \n \n } else { \n \n label('$x$',(xright+0.4,-0.5)); \n \n label('$y$',(-0.5,ytop+0.2)); \n \n } \n \n ylimits(ybottom,ytop); \n \n xlimits( xleft, xright); \n \n real[] TicksArrx,TicksArry; \n \n for(i=xleft+xstep; i<xright; i+=xstep) { \n \n if(abs(i) >0.1) { \n \n TicksArrx.push(i); \n \n } \n \n } \n \n for(i=ybottom+ystep; i<ytop; i+=ystep) { \n \n if(abs(i) >0.1) { \n \n TicksArry.push(i); \n \n } \n \n } \n \n if(usegrid) {"
+
+ def get_producer_layers(model):
+     """
+     Traverses the model to find the producer layer (layer_idx=0).
+     """
+     producer_modules = []
+     for module in model.modules():
+         if module.__class__.__name__.endswith("AttentionExperimental") and module.layer_idx == 0:
+             producer_modules.append(module)
+     return producer_modules
+
+ # 1) Load the base model from HF
+ base_model_name = "meta-llama/Llama-3.2-1B"
+ base_model = LlamaForCausalLM.from_pretrained(base_model_name, device_map="auto")
+ tokenizer = AutoTokenizer.from_pretrained(base_model_name)
+ inputs = tokenizer(question, return_tensors="pt")
+ inputs = {k: v.to(base_model.device) for k, v in inputs.items()}
+ question_length = inputs['attention_mask'].shape[1]
+
+ with torch.no_grad():
+     base_output_ids = base_model.generate(
+         **inputs,
+         max_new_tokens=200,
+         do_sample=True,
+         top_p=0.95,
+         temperature=0.7,
+     )
+ base_output_text = tokenizer.decode(base_output_ids[0][question_length:], skip_special_tokens=True)
+
+
+ from modeling_llama_butler import LlamaButlerConfig, LlamaButlerForCausalLM
+ butler_config = LlamaButlerConfig.from_pretrained('config.json')
+
+
+ butler_model = LlamaButlerForCausalLM(butler_config)
+ butler_model.load_state_dict(base_model.state_dict(), strict=False)
+
+ predictor_load_path = "/home/ya255/projects/TokenButler/expt_model/TrainTokenButler_42_finetune_None_None_500_llama_meta-llama_Llama-3.2-1B_L3_1B_2k.csv_L3_1B_2k_False_False_2000_False_redpajama_1024_1_1_20_0.001_1024/16_False_4_1000_ExpPred_fixed_40pc_True_False_0_None_False_False_4_8_2_16_512_False_False_True_16_0.37500000000000006__best.pt"
+
+ model_producer_layers = get_producer_layers(butler_model)
+ producer_layer_weights = torch.load(predictor_load_path)
+ for idx, producer_layer_weight in enumerate(producer_layer_weights):
+     try:
+         model_producer_layers[idx].load_state_dict(producer_layer_weight, strict=False)
+     except Exception as e:
+         print(f"Error loading producer layer {idx}: {e}")
+         print("\n\nContinuing... !! Bad Perf If Unintentional !!\n\n")
+
+
+ butler_model.to(base_model.device)
+ butler_model.eval()
+
+ with torch.no_grad():
+     butler_output_ids = butler_model.generate(
+         **inputs,
+         max_new_tokens=200,
+         do_sample=True,
+         top_p=0.95,
+         temperature=0.7,
+     )
+
+ butler_output_text = tokenizer.decode(butler_output_ids[0][question_length:], skip_special_tokens=True)
+
+ print("\n=== Base Model Output (Newlines Removed For Brevity) ===\n")
+ print(base_output_text.replace("\n", ""))
+ print("\n")
+ print("=== Butler Model Output (Newlines Removed For Brevity) ===\n")
+ print(butler_output_text.replace("\n", ""))
+ print("\n")
+
+ OUTPUT_DIR = "."
+ print(f"\nSaving final merged model to: {OUTPUT_DIR}")
+ butler_model.save_pretrained(OUTPUT_DIR, safe_serialization=False)
+
+ # tokenizer.save_pretrained(OUTPUT_DIR)
+ print("\nAll done! The folder should now have `pytorch_model.bin` and the updated `config.json`.\n")
generation_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 128000,
+   "eos_token_id": 128001,
+   "transformers_version": "4.48.3"
+ }
modeling_llama_butler.py ADDED
@@ -0,0 +1,131 @@
+ import torch
+ import torch.nn as nn
+ from typing import Dict
+ from transformers import LlamaForCausalLM, LlamaConfig
+ from transformers.generation.utils import GenerationConfig
+
+ # Import your custom patching logic & custom modules:
+ from modify_llama import convert_kvcache_experimental
+ from predictor import TokenImportancePredictorAttentive, HeadImportancePredictor, PredictorDynamicCache
+ from modify_llama import LlamaAttentionExperimental
+
+
+
+ # ---------------------------------------------------------------------
+ # 1) Custom Config subclass
+ # ---------------------------------------------------------------------
+ class LlamaButlerConfig(LlamaConfig):
+     """
+     Extends HF's LlamaConfig to hold optional extra parameters for the "Butler" logic.
+     You can store your custom attributes here, so they can be serialized in config.json.
+     """
+
+     model_type = "llama_butler"
+
+     def __init__(
+         self,
+         eval_llm_mode="ExpPred",
+         token_sparse_method="fixed_50pc",
+         producer_frequency=8,
+         dDash=16,
+         attn_reduce_factor=4,
+         head_attn_reduce_factor=4,
+         intdim=256,
+         flash_attn=False,
+         train_headpredictor=False,
+         min_sparse_index=5,
+         lookahead=0,
+         sliding_window=None,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.eval_llm_mode = eval_llm_mode
+         self.token_sparse_method = token_sparse_method
+         self.producer_frequency = producer_frequency
+         self.dDash = dDash
+         self.attn_reduce_factor = attn_reduce_factor
+         self.head_attn_reduce_factor = head_attn_reduce_factor
+         self.intdim = intdim
+         self.flash_attn = flash_attn
+         self.train_headpredictor = train_headpredictor
+         self.min_sparse_index = min_sparse_index
+         self.lookahead = lookahead
+         self.sliding_window = sliding_window
+
+
+ # ---------------------------------------------------------------------
+ # 2) The main Butler model class
+ # ---------------------------------------------------------------------
+ class LlamaButlerForCausalLM(LlamaForCausalLM):
+     """
+     A subclass of HF's LlamaForCausalLM that:
+       - Patches each LlamaAttention to your LlamaAttentionExperimental
+       - Sets specialized attributes (eval_llm_mode, etc.)
+       - Overrides _prepare_cache_for_generation to inject PredictorDynamicCache
+     """
+
+     # Let HF auto-detect this config class from config.json:
+     config_class = LlamaButlerConfig
+
+     def __init__(self, config: LlamaButlerConfig):
+         super().__init__(config)
+         """
+         HF's LlamaForCausalLM initializes:
+           self.model = LlamaModel(config)
+           self.lm_head = nn.Linear(...)
+         """
+
+         # 1) Patch the underlying LlamaModel to replace LlamaAttention with LlamaAttentionExperimental
+         self.model = convert_kvcache_experimental(
+             self.model,
+             config,
+             config.producer_frequency
+         )
+
+         # 2) Optionally, set per-module attributes so each LlamaAttentionExperimental knows about them:
+         for module in self.model.modules():
+             if module.__class__.__name__.endswith("AttentionExperimental"):
+                 # Set these from your config. Or you can hardcode them if you prefer.
+                 module.eval_llm_mode = config.eval_llm_mode
+                 module.token_sparse_method = config.token_sparse_method
+                 module.set_token_sparsity()  # e.g. sets module.sparse_aggression
+
+                 module.producer_frequency = config.producer_frequency
+                 module.dDash = config.dDash
+                 module.attn_reduce_factor = config.attn_reduce_factor
+                 module.head_attn_reduce_factor = config.head_attn_reduce_factor
+                 module.intdim = config.intdim
+                 module.flash_attn = config.flash_attn
+                 module.train_headpredictor = config.train_headpredictor
+                 module.min_sparse_index = config.min_sparse_index
+                 module.lookahead = config.lookahead
+                 module.sliding_window = config.sliding_window
+                 module.num_layers_pred = config.producer_frequency  # example usage
+
+                 # If this is a "producer layer" (mod.layer_idx % freq == 0), run update_predictor():
+                 if hasattr(module, "layer_idx") and (module.layer_idx % config.producer_frequency == 0):
+                     module.update_predictor()
+
+         # 3) Patch the dynamic cache (past_key_values) creation. For your evaluation modes:
+         if config.eval_llm_mode in ["ExpPred", "ReplAttn"]:
+             self._prepare_cache_for_generation = self._patched_prepare_cache_for_generation.__get__(
+                 self, self.__class__
+             )
+
+     # -----------------------------------------------------------------
+     # 3) The custom `_prepare_cache_for_generation` override
+     # -----------------------------------------------------------------
+     def _patched_prepare_cache_for_generation(
+         self,
+         generation_config: GenerationConfig,
+         model_kwargs: Dict,
+         *args,
+         **kwargs
+     ):
+         """
+         This override injects a PredictorDynamicCache
+         in place of the standard 'past_key_values'.
+         """
+         if "past_key_values" not in model_kwargs or model_kwargs["past_key_values"] is None:
+             model_kwargs["past_key_values"] = PredictorDynamicCache()
+         return model_kwargs
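
Because `config.json` carries `model_type: "llama_butler"` together with an `auto_map`, these classes are normally picked up via `from_pretrained(..., trust_remote_code=True)`. When `modeling_llama_butler.py` is importable locally, they can instead be registered with the Auto classes explicitly; a sketch under that assumption:

    from transformers import AutoConfig, AutoModelForCausalLM
    from modeling_llama_butler import LlamaButlerConfig, LlamaButlerForCausalLM

    # Make AutoConfig / AutoModelForCausalLM resolve model_type "llama_butler" without trust_remote_code
    AutoConfig.register("llama_butler", LlamaButlerConfig)
    AutoModelForCausalLM.register(LlamaButlerConfig, LlamaButlerForCausalLM)

    model = AutoModelForCausalLM.from_pretrained(".")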
modify_llama.py ADDED
@@ -0,0 +1,464 @@
+ import os
+ import pdb
+ import copy
+ import math
+ import numpy as np
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, Union
+ import gc
+
+ import traceback
+ import torch
+ from torch import nn
+ import torch.utils.checkpoint
+ import torch.nn.functional as F
+ from torch.cuda.amp import autocast
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+ from transformers.models.llama.configuration_llama import LlamaConfig
+ from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, LlamaAttention, apply_rotary_pos_emb
+
+ from utils import LlamaLinearScalingRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding, repeat_kv, sorted_index_to_mask
+ from utils import calculate_hit_metrics, calculate_effective_sparsity, threshold_to_mask, SlidingWindowCache, enforce_sliding_window
+ from transformers.cache_utils import DynamicCache
+ from predictor import TokenImportancePredictorAttentive, PredictorDynamicCache, HeadImportancePredictor, attention_mse_loss, attention
+
+ from triton_kernels.flash_attn import attention
+ from triton_kernels.flash_attn_mse_loss import attention_mse_loss
+
+ # torch.backends.cuda.enable_flash_sdp(enabled=True)
+ # torch.backends.cuda.enable_mem_efficient_sdp(enabled=True)
+
+ class LlamaAttentionExperimental(nn.Module):
+     def __init__(self, config: LlamaConfig, producer=None, layer_idx=0):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.num_hidden_layers = config.num_hidden_layers
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.num_key_value_heads = config.num_key_value_heads
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+         self.max_position_embeddings = config.max_position_embeddings
+         self.rope_theta = config.rope_theta
+         self.inference_mode = False
+         self.producer = producer
+         self.layer_idx = layer_idx
+         self.token_sparse_method = None
+         self.sparse_aggression = None
+         self.stream_llm_start_size = None
+         self.dDash = None
+         self.intdim = None
+         self.attn_reduce_factor = None
+         self.head_attn_reduce_factor = None
+         self.effective_sparsity = None
+         self.min_sparse_index = None
+         self.pred_hid_size = self.hidden_size
+         self.num_tok_per_page = None
+         self.calc_hitrates = False
+         self.flash_attn = False
+         self.train_headpredictor = False
+         self.calibrate_thresholds = False
+         self.test_with_thresholds = False
+         self.old_predictor = None
+
+         if self.layer_idx > 0:
+             self.mseloss = MSELoss(reduction='none')
+             self.msemagn_loss = None
+             self.headmseloss = MSELoss(reduction='none')
+             self.headmsemagn_loss = None
+
+         if self.producer is None:  # This is the producer layer
+             self.q_importance = None  # Shared mask across layers during inference
+             self.k_importance = None
+             self.head_importances = None
+             self.actmagn_masklist = {}
+             self.available_tokens = {}
+
+         # Attention setup
+         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+         self._init_rope()
+
+     def update_predictor(self):
+         self.sparse_token_predictor = TokenImportancePredictorAttentive(
+             self.config, self.pred_hid_size, self.num_heads, self.num_layers_pred, dropout=0.1, dDash = self.dDash, \
+             intdim = self.intdim, attn_reduce_factor=self.attn_reduce_factor
+         ).to('cuda:0')
+         self.sparse_token_predictor.flash_attn = self.flash_attn
+         if self.train_headpredictor:
+             self.sparse_head_predictor = HeadImportancePredictor(
+                 self.config, self.pred_hid_size, self.num_heads, self.num_layers_pred, dropout=0.1, dDash = self.dDash, \
+                 intdim = self.intdim, attn_reduce_factor=self.head_attn_reduce_factor
+             ).to('cuda:0')
+             self.sparse_head_predictor.flash_attn = self.flash_attn
+
+     def set_token_sparsity(self):
+         assert self.token_sparse_method is not None, "Set token sparse method first!"
+         if self.token_sparse_method is not None:
+             try:
+                 mname = self.config._name_or_path.split("/")[-1]
+                 read_path = f"threshold_calibs/{mname}/{self.token_sparse_method}.pkl"
+                 threshold_model_dictionary = torch.load(read_path)
+                 self.tok_calibration_set = threshold_model_dictionary
+             except:
+                 pass
+         if self.token_sparse_method == "LazyLLM":
+             if self.layer_idx <= 9:
+                 self.sparse_aggression = 1
+             elif self.layer_idx <= 19:
+                 self.sparse_aggression = 0.7
+             elif self.layer_idx <= 28:
+                 self.sparse_aggression = 0.4
+             else:
+                 self.sparse_aggression = 0.1
+         elif "fixed" in self.token_sparse_method:
+             if self.layer_idx == 0:
+                 self.sparse_aggression = 1
+             else:
+                 self.sparse_aggression = 1 - float(self.token_sparse_method.split("_")[1].split("pc")[0])/100.
+         elif "progressive" in self.token_sparse_method:
+             pc_drop = float(self.token_sparse_method.split("_")[1].split("pc")[0])/100.
+             self.sparse_aggression = (1 - pc_drop) ** (self.layer_idx)  # (x% per layer, progressive_xpc style)
+         else:
+             raise ValueError(f"Unknown token sparsity method {self.token_sparse_method}")
+
+
+     def _init_rope(self):
+         if self.config.rope_scaling is None:
+             self.rotary_emb = LlamaRotaryEmbedding(
+                 self.config
+             )
+         else:
+             scaling_type = self.config.rope_scaling.get("type") or self.config.rope_scaling.get("rope_type")
+             scaling_factor = self.config.rope_scaling["factor"]
+             if scaling_type == "linear" or scaling_type == 'llama3':
+                 self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
+                     self.head_dim,
+                     max_position_embeddings=self.max_position_embeddings,
+                     scaling_factor=scaling_factor,
+                     base=self.rope_theta,
+                     config=self.config
+                 )
+             elif scaling_type == "dynamic":
+                 self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
+                     self.head_dim,
+                     max_position_embeddings=self.max_position_embeddings,
+                     scaling_factor=scaling_factor,
+                     base=self.rope_theta,
+                     config=self.config
+                 )
+             else:
+                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Union[DynamicCache, PredictorDynamicCache]] = None,
+         output_attentions: bool = False,
+         use_cache: bool = False,
+         padding_mask: Optional[torch.LongTensor] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[PredictorDynamicCache]]:
+         bsz, q_len, _ = hidden_states.size()
+         Ltrack = hidden_states.size(1)
+
+         if self.config.pretraining_tp > 1:
+             key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+             query_slices = self.q_proj.weight.split(
+                 (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+             )
+             key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+             value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+             query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+             query_states = torch.cat(query_states, dim=-1)
+
+             key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+             key_states = torch.cat(key_states, dim=-1)
+
+             value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+             value_states = torch.cat(value_states, dim=-1)
+         else:
+             query_states = self.q_proj(hidden_states)
+             key_states = self.k_proj(hidden_states)
+             value_states = self.v_proj(hidden_states)
+
+         evalmode = self.eval_llm_mode
+         num_tokens_to_keep = int(q_len * self.sparse_aggression)
+         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+         # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)  # AHMED: Modified this to use the newer version.
+         cos, sin = position_embeddings
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+         if use_cache:
+             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
+
+         kv_seq_len = key_states.shape[-2]
+         final_mask = None
+
+         key_states = repeat_kv(key_states, self.num_key_value_groups)
+         value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+         key_len = key_states.size(2)
+         bsz, q_len = query_states.size(0), query_states.size(2)
+
+         if attention_mask is None:
+             # We want a [q_len, kv_seq_len] boolean upper-triangular mask
+             causal_mask_2d = torch.ones(q_len, kv_seq_len,
+                                         device=hidden_states.device,
+                                         dtype=torch.bool).triu(diagonal=1)
+             # Then shape it to [bsz, 1, q_len, kv_seq_len]
+             causal_mask_4d = causal_mask_2d.unsqueeze(0).expand(bsz, 1, q_len, kv_seq_len)
+             # Now fill -inf where the mask is True
+             attention_mask = torch.full_like(causal_mask_4d, 0, dtype=hidden_states.dtype)
+             if q_len != 1:
+                 attention_mask = attention_mask.masked_fill(causal_mask_4d, float("-inf"))
+
+         if self.inference_mode:
+             min_sparse_index = self.min_sparse_index
+             with torch.no_grad():
+                 if evalmode == "ExpPred":
+                     if self.layer_idx > 0:
+                         q_importance_tensor = self.producer.q_importance[:, self.layer_idx % self.producer_frequency, :, :].float().to(query_states.device)  # [BH, Lq, D']
+                         k_importance_tensor = self.producer.k_importance[:, self.layer_idx % self.producer_frequency, :, :].float().to(key_states.device)  # [BH, Lk, D']
+                         importance_mask = torch.bmm(q_importance_tensor, k_importance_tensor.transpose(-2, -1)) / math.sqrt(self.dDash)  # [BH, Lq, Lk]
+                         importance_mask = importance_mask.view(bsz, self.num_heads, q_len, key_len)  # [B, H, Lq, Lk]
+                         attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim)
+                         if self.calc_hitrates:
+                             self.tok_hit_acc, self.tok_mean_rank_corr, self.tok_max_rank_corr = calculate_hit_metrics(
+                                 estimated_importance=importance_mask,
+                                 true_importance=attn_weights,
+                                 top_k_ratio=0.5
+                             )
+                         if self.calibrate_thresholds:
+                             ### Threshold variance investigation
+                             unadj_importance_mask = importance_mask.clone()
+                             importance_mask = torch.softmax(importance_mask + attention_mask, dim=-1)
+                             sorted_indices = torch.argsort(importance_mask, dim=-1, descending=True)
+                             sorted_indices = sorted_indices[:, :, -q_len:, :]
+                             sorted_values, sorted_ix = torch.sort(importance_mask, dim=-1)
+                             sorted_true_values, _ = torch.sort(torch.gather(unadj_importance_mask, dim=-1, index=sorted_ix), dim=-1)
+                             true_thresholds = sorted_true_values[:, :, :, int(importance_mask.size(-1) * self.sparse_aggression)]
+                             thresholds = sorted_values[:, :, :, int(importance_mask.size(-1) * self.sparse_aggression)]
+                             self.true_threshmean = true_thresholds
+                             self.threshmean = thresholds
+                         if self.test_with_thresholds:
+                             unadj_importance_mask = importance_mask.clone()
+                             perhead_thresholds = self.tok_calibration_set[self.layer_idx - 1].to(unadj_importance_mask.device)  # 0 does not have calibration data.
+                             mask_tensor = threshold_to_mask(unadj_importance_mask, perhead_thresholds, min_sparse_index, bsz, q_len, key_len)
+                         else:
+                             importance_mask = torch.softmax(importance_mask + attention_mask, dim=-1)
+                             sorted_indices = torch.argsort(importance_mask, dim=-1, descending=True)
+                             sorted_indices = sorted_indices[:, :, -q_len:, :]
+                             mask_tensor = sorted_index_to_mask(sorted_indices, attention_mask, min_sparse_index, bsz, q_len, key_len, self.sparse_aggression, self.sliding_window)
+                         ### Threshold variance investigation
+                         if self.sliding_window is not None:
+                             if not hasattr(self, "window_cache"):
+                                 self.window_cache = SlidingWindowCache(max_seq_len=1024,
+                                                                        sliding_window=self.sliding_window,
+                                                                        device=mask_tensor.device)
+                             window = self.window_cache.get_window(q_len, key_len)
+                             mask_tensor = enforce_sliding_window(mask_tensor, window)
+                         final_mask = mask_tensor
+
+                         self.final_mask_investigate = final_mask
+                         attn_weights = attn_weights + mask_tensor + attention_mask
+                     else:
+                         attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim)
+                         attn_weights = attn_weights + attention_mask
+                 else:
+                     raise ValueError(f"Unknown eval mode {evalmode}")
+                 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
+                 attn_output = torch.matmul(attn_weights, value_states)
+
+         else:
+             if self.flash_attn:
+                 if self.layer_idx > 0:
+                     # Token hit-rates cannot be calculated if using flash attention.
+                     self.tok_hit_acc = 0
+                     q_importance_tensor = self.producer.q_importance[:, self.layer_idx % self.producer_frequency, :, :].float().to(query_states.device)  # [BH, Lq, D']
+                     k_importance_tensor = self.producer.k_importance[:, self.layer_idx % self.producer_frequency, :, :].float().to(key_states.device)  # [BH, Lk, D']
+                     q_importance_tensor = q_importance_tensor.view(bsz, self.num_heads, q_len, self.dDash)
+                     k_importance_tensor = k_importance_tensor.view(bsz, self.num_heads, key_len, self.dDash)
+                     device_index = query_states.device.index
+                     assert self.lookahead == 0, "Lookahead not supported with flash attention yet. Please disable --flash_attn"
+                     with torch.cuda.device(device_index):
+                         attn_output, mse_loss = attention_mse_loss(query_states.contiguous().to(torch.float16),
+                                                                    key_states.contiguous().to(torch.float16),
+                                                                    value_states.contiguous().to(torch.float16),
+                                                                    q_importance_tensor.contiguous().to(torch.float16),
+                                                                    k_importance_tensor.contiguous().to(torch.float16),
+                                                                    True
+                                                                    )
+                         self.tok_hit_acc, self.tok_mean_rank_corr, self.tok_max_rank_corr = 0, 0, 0
+                     attn_output = attn_output.to(query_states.dtype)
+                     if not torch.isnan(mse_loss):
+                         self.msemagn_loss = mse_loss
+                     else:
+                         raise ValueError(f"NaN loss detected: {mse_loss}")
+                 else:
+                     attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=None, is_causal=True)
+             else:
+                 attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim)
+                 if self.layer_idx > 0:
+                     q_importance_tensor = self.producer.q_importance[:, self.layer_idx % self.producer_frequency, :, :].float().to(query_states.device)  # [BH, Lq, D']
+                     k_importance_tensor = self.producer.k_importance[:, self.layer_idx % self.producer_frequency, :, :].float().to(key_states.device)  # [BH, Lk, D']
+                     importance_mask = torch.bmm(q_importance_tensor, k_importance_tensor.transpose(-2, -1)) / math.sqrt(self.dDash)  # [BH, Lq, Lk]
+                     importance_mask = importance_mask.view(bsz, self.num_heads, q_len, key_len)  # [B, H, Lq, Lk]
+
+                     if self.lookahead == 0:
+                         self.msemagn_loss = self.mseloss(attn_weights, importance_mask)
+                     else:
+                         self.msemagn_loss = self.mseloss(attn_weights[:, :, self.lookahead:, :], importance_mask[:, :, :-self.lookahead, :])
+                     self.msemagn_loss = (self.msemagn_loss).mean(dim=(-1, -2))
+                     self.msemagn_loss = self.msemagn_loss.mean()
+
+                     if self.calc_hitrates:
+                         self.tok_hit_acc, self.tok_mean_rank_corr, self.tok_max_rank_corr = calculate_hit_metrics(
+                             estimated_importance=importance_mask,
+                             true_importance=attn_weights,
+                             top_k_ratio=0.5
+                         )
+
+                 if attention_mask is not None:
+                     attn_weights = attn_weights + attention_mask
+                 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
+                 attn_output = torch.matmul(attn_weights, value_states)
+
+             if self.layer_idx > 0 and self.train_headpredictor:
+                 head_importance_tensor = self.producer.head_importances[:, :, :, self.layer_idx % self.producer_frequency].float().to(attn_output.device)
+                 attn_head_weights = attn_output.mean(dim=-1).permute(0, 2, 1)
+                 self.headmsemagn_loss = self.headmseloss(attn_head_weights, head_importance_tensor).mean()
+
+                 if self.calc_hitrates:
+                     self.head_hit_acc, self.head_mean_rank_corr, self.head_max_rank_corr = calculate_hit_metrics(
+                         estimated_importance=head_importance_tensor,
+                         true_importance=attn_head_weights,
+                         top_k_ratio=0.5
+                     )
+             else:
+                 self.headmsemagn_loss = 0
+                 if self.calc_hitrates:
+                     self.head_hit_acc, self.head_mean_rank_corr, self.head_max_rank_corr = 0, 0, 0
+
+
+         checkeverytime = hasattr(self, 'test_with_thresholds')
+         if checkeverytime:
+             checkeverytime = self.test_with_thresholds
+         if final_mask is not None:
+             if self.effective_sparsity is None or checkeverytime:
+                 true_mask = final_mask + attention_mask
+                 num_deact = true_mask.bool().sum(dim=-1)  # Number of tokens disabled.
+                 causally_deact = (attention_mask.bool()).sum(dim=-1).expand_as(num_deact)  # Number of tokens disabled causally anyway
+                 additional_deact = (num_deact - causally_deact)
+                 num_active = (~attention_mask.bool()).sum(dim=-1).expand_as(num_deact)  # Number of tokens active at this position if zero-sparsity
+                 effective_sparsity = 100 * (additional_deact.float() / num_active.float()).mean().item()
+                 self.effective_sparsity = effective_sparsity
+                 print("Effective Sparsity:", effective_sparsity, "%\t Sequence Length:", q_len)
+         if self.layer_idx == 0:
+             if self.effective_sparsity is None:
+                 self.effective_sparsity = 0.0
+
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.view(bsz, -1, self.hidden_size)
+
+         if self.config.pretraining_tp > 1:
+             attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+             o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+             attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+         else:
+             attn_output = self.o_proj(attn_output)
+
+         if self.producer is None:
+             try:
+                 q_importance, k_importance = self.sparse_token_predictor(
+                     hidden_states,
+                     attention_mask=attention_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_value,  # the same single cache
+                     use_cache=use_cache,
+                     layer_idx=self.layer_idx,  # or pass 0
+                 )
+                 if self.train_headpredictor:
+                     head_importances, past_key_value_hp = self.sparse_head_predictor(
+                         hidden_states,
+                         attention_mask=attention_mask,
+                         position_ids=position_ids,
+                         past_key_value=past_key_value_hp,
+                         use_cache=use_cache
+                     )
+                     head_importances = head_importances.view(bsz, q_len, self.num_heads, self.num_hidden_layers)  # [B L H N]
+                 q_len = attn_output.size(1)
+                 k_len = k_importance.size(-1)
+             except:
+                 print(traceback.format_exc())
+                 import pdb; pdb.set_trace()
+
+             self.q_importance = q_importance
+             self.k_importance = k_importance
+
+             if self.train_headpredictor:
+                 if self.head_importances is None:
+                     self.head_importances = head_importances
+                 else:
+                     self.head_importances = torch.cat([self.head_importances, head_importances], dim=1)
+
+         if self.layer_idx == 31:
+             if q_len == 1:
+                 self.dtok += 1
+                 print(f"Primary Key-Value Shape: {past_key_value.predictor_primary_key[0].shape}, Importance: {past_key_value.predictor_importance_key[0].shape}, Tok-Decoded: {self.dtok}")
+             else:
+                 self.dtok = 0
+
+         if not output_attentions:
+             attn_weights = None
+         return attn_output, attn_weights
+
+ def convert_kvcache_experimental(model, config, producer_frequency):
+     producer_layer = None
+     producer_layer_device = None
+     layer_counter = {'idx': 0}
+
+     def recurse_convert(parent_module):
+         nonlocal producer_layer
+         nonlocal producer_layer_device
+         for name, module in parent_module._modules.items():
+             if len(list(module.children())) > 0:
+                 recurse_convert(module)
+             if isinstance(module, LlamaAttention):
+                 device = next(module.parameters()).device
+                 dtype = next(module.parameters()).dtype
+                 if layer_counter['idx'] % producer_frequency == 0:
+                     new_module = LlamaAttentionExperimental(config).to(dtype).to(device)
+                     producer_layer = new_module
+                     producer_layer_device = device
+                 else:
+                     new_module = LlamaAttentionExperimental(
+                         config,
+                         producer=producer_layer,
+                         layer_idx=layer_counter['idx']
+                     ).to(dtype).to(device)
+                 new_module.load_state_dict(module.state_dict(), strict=False)
+                 is_producer = layer_counter['idx'] % producer_frequency == 0
+                 if is_producer:
+                     print(f"Converted Producer layer '{name}' to LlamaAttentionExperimental at layer index {layer_counter['idx']}")
+                 else:
+                     print(f"Converted layer '{name}' to LlamaAttentionExperimental at layer index {layer_counter['idx']}")
+                 parent_module._modules[name] = new_module
+                 layer_counter['idx'] += 1
+     recurse_convert(model)
+     producer_layer = producer_layer.to(producer_layer_device)
+     return model
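
For reference, the `token_sparse_method` strings parsed in `set_token_sparsity` map to a per-layer keep ratio (`sparse_aggression`). A standalone sketch of the same arithmetic for the two string-encoded styles, separate from the module itself:

    def keep_ratio(token_sparse_method: str, layer_idx: int) -> float:
        # Mirrors set_token_sparsity for the "fixed_Xpc" and "progressive_Xpc" styles.
        pc = float(token_sparse_method.split("_")[1].split("pc")[0]) / 100.0
        if "fixed" in token_sparse_method:
            return 1.0 if layer_idx == 0 else 1.0 - pc   # fixed_50pc -> keep 50% after layer 0
        if "progressive" in token_sparse_method:
            return (1.0 - pc) ** layer_idx               # progressive_10pc -> 0.9 ** layer_idx
        raise ValueError(f"Unknown token sparsity method {token_sparse_method}")

    assert keep_ratio("fixed_50pc", layer_idx=5) == 0.5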
predictor.py ADDED
@@ -0,0 +1,643 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pdb
3
+ import copy
4
+ import math
5
+ import numpy as np
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple, Union
8
+ import gc
9
+
10
+ from typing import Any, Dict, List, Optional, Tuple
11
+ import traceback
12
+ import torch
13
+ from torch import nn
14
+ import torch.utils.checkpoint
15
+ import torch.nn.functional as F
16
+ from torch.cuda.amp import autocast
17
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
18
+
19
+ from transformers.models.llama.configuration_llama import LlamaConfig
20
+ from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, LlamaAttention, apply_rotary_pos_emb
21
+
22
+ from utils import LlamaLinearScalingRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding, repeat_kv, sorted_index_to_mask
23
+ from transformers.cache_utils import DynamicCache
24
+
25
+ from triton_kernels.flash_attn import attention
26
+ from triton_kernels.flash_attn_mse_loss import attention_mse_loss
27
+
28
+
29
+ class PredictorDynamicCache(DynamicCache):
30
+ def __init__(self):
31
+ super().__init__()
32
+ self.predictor_primary_key: List[Optional[torch.Tensor]] = []
33
+ self.predictor_primary_value: List[Optional[torch.Tensor]] = []
34
+ self.predictor_importance_key: List[Optional[torch.Tensor]] = []
35
+
36
+ def update_predictor_primary(
37
+ self,
38
+ key_states: torch.Tensor,
39
+ value_states: torch.Tensor,
40
+ layer_idx: int,
41
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
42
+ """
43
+ Append or create the predictor's "primary" K/V states for `layer_idx`.
44
+
45
+ shape for key_states, value_states is typically [batch_size, num_heads, seq_len, head_dim].
46
+ """
47
+ # Extend the lists so that `predictor_primary_key[layer_idx]` and
48
+ # `predictor_primary_value[layer_idx]` exist.
49
+ self._ensure_list_capacity(
50
+ self.predictor_primary_key, layer_idx, fill=None
51
+ )
52
+ self._ensure_list_capacity(
53
+ self.predictor_primary_value, layer_idx, fill=None
54
+ )
55
+
56
+ # If this is the very first time we are updating that layer's predictor cache, just assign
57
+ if self.predictor_primary_key[layer_idx] is None:
58
+ self.predictor_primary_key[layer_idx] = key_states
59
+ self.predictor_primary_value[layer_idx] = value_states
60
+ else:
61
+ # Otherwise, concatenate along the seq_len dimension (=-2 or =2 depending on your shape).
62
+ self.predictor_primary_key[layer_idx] = torch.cat(
63
+ [self.predictor_primary_key[layer_idx], key_states], dim=2
64
+ )
65
+ self.predictor_primary_value[layer_idx] = torch.cat(
66
+ [self.predictor_primary_value[layer_idx], value_states], dim=2
67
+ )
68
+
69
+ return (
70
+ self.predictor_primary_key[layer_idx],
71
+ self.predictor_primary_value[layer_idx],
72
+ )
73
+
74
+ def update_predictor_importance(
75
+ self,
76
+ key_states: torch.Tensor,
77
+ layer_idx: int,
78
+ ) -> torch.Tensor:
79
+ """
80
+ Append or create the predictor's "importance" key for `layer_idx`.
81
+ """
82
+ self._ensure_list_capacity(
83
+ self.predictor_importance_key, layer_idx, fill=None
84
+ )
85
+
86
+ if self.predictor_importance_key[layer_idx] is None:
87
+ self.predictor_importance_key[layer_idx] = key_states
88
+ else:
89
+ self.predictor_importance_key[layer_idx] = torch.cat(
90
+ [self.predictor_importance_key[layer_idx], key_states], dim=2
91
+ )
92
+ return self.predictor_importance_key[layer_idx]
93
+
94
+ def crop(self, max_length: int):
95
+ super().crop(max_length)
96
+ # Now also crop predictor caches
97
+ for idx in range(len(self.predictor_primary_key)):
98
+ if self.predictor_primary_key[idx] is not None:
99
+ self.predictor_primary_key[idx] = self.predictor_primary_key[idx][..., :max_length, :]
100
+ self.predictor_primary_value[idx] = self.predictor_primary_value[idx][..., :max_length, :]
101
+
102
+ for idx in range(len(self.predictor_importance_key)):
103
+ if self.predictor_importance_key[idx] is not None:
104
+ self.predictor_importance_key[idx] = self.predictor_importance_key[idx][..., :max_length, :]
105
+
106
+ # Remember to adjust self._seen_tokens accordingly
107
+ self._seen_tokens = min(self._seen_tokens, max_length)
108
+
109
+ def batch_split(
110
+ self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
111
+ ) -> List["PredictorDynamicCache"]:
112
+ # Use the base split logic for the standard K/V
113
+ base_splits = super().batch_split(full_batch_size, split_size, num_hidden_layers)
114
+ # `base_splits` is now a list of new DynamicCache objects. But we *actually*
115
+ # want them to be PredictorDynamicCache so we can store the predictor states.
116
+ # Easiest: we can cast and fill them.
117
+ out: List[PredictorDynamicCache] = []
118
+
119
+ for split_i, base_split in enumerate(base_splits):
120
+ # Construct an empty PredictorDynamicCache
121
+ new_cache = PredictorDynamicCache()
122
+ # Copy over the underlying fields from base_split
123
+ new_cache.key_cache = base_split.key_cache
124
+ new_cache.value_cache = base_split.value_cache
125
+ new_cache._seen_tokens = base_split._seen_tokens
126
+
127
+ # Now also slice our predictor fields
128
+ # The slice in batch dim is [i:i+split_size].
129
+ b_start = split_i * split_size
130
+ b_end = min(full_batch_size, b_start + split_size)
131
+
132
+ new_cache.predictor_primary_key = self._slice_list_tensors(
133
+ self.predictor_primary_key, b_start, b_end
134
+ )
135
+ new_cache.predictor_primary_value = self._slice_list_tensors(
136
+ self.predictor_primary_value, b_start, b_end
137
+ )
138
+ new_cache.predictor_importance_key = self._slice_list_tensors(
139
+ self.predictor_importance_key, b_start, b_end
140
+ )
141
+
142
+ out.append(new_cache)
143
+
144
+ return out
145
+
146
+ @classmethod
147
+ def from_batch_splits(cls, splits: List["PredictorDynamicCache"], num_hidden_layers: int = None) -> "PredictorDynamicCache":
148
+ # Let the base class handle the normal K/V merges
149
+ base_merged = DynamicCache.from_batch_splits(splits, num_hidden_layers=num_hidden_layers)
150
+ merged = cls()
151
+ merged.key_cache = base_merged.key_cache
152
+ merged.value_cache = base_merged.value_cache
153
+ merged._seen_tokens = base_merged._seen_tokens
154
+
155
+ # Now unify predictor states by concatenating along batch dim=0
156
+ merged.predictor_primary_key = cls._merge_list_tensors(
157
+ [split.predictor_primary_key for split in splits]
158
+ )
159
+ merged.predictor_primary_value = cls._merge_list_tensors(
160
+ [split.predictor_primary_value for split in splits]
161
+ )
162
+ merged.predictor_importance_key = cls._merge_list_tensors(
163
+ [split.predictor_importance_key for split in splits]
164
+ )
165
+
166
+ return merged
167
+
168
+ def batch_repeat_interleave(self, repeats: int):
169
+ super().batch_repeat_interleave(repeats)
170
+ self.predictor_primary_key = self._repeat_list_tensors(
171
+ self.predictor_primary_key, repeats
172
+ )
173
+ self.predictor_primary_value = self._repeat_list_tensors(
174
+ self.predictor_primary_value, repeats
175
+ )
176
+ self.predictor_importance_key = self._repeat_list_tensors(
177
+ self.predictor_importance_key, repeats
178
+ )
179
+
180
+ def batch_select_indices(self, indices: torch.Tensor):
181
+ super().batch_select_indices(indices)
182
+ self.predictor_primary_key = self._select_list_tensors(
183
+ self.predictor_primary_key, indices
184
+ )
185
+ self.predictor_primary_value = self._select_list_tensors(
186
+ self.predictor_primary_value, indices
187
+ )
188
+ self.predictor_importance_key = self._select_list_tensors(
189
+ self.predictor_importance_key, indices
190
+ )
191
+
192
+ @staticmethod
193
+ def _ensure_list_capacity(lst: list, idx: int, fill=None):
194
+ if len(lst) <= idx:
195
+ lst.extend([fill] * (idx + 1 - len(lst)))
196
+
197
+ @staticmethod
198
+ def _slice_list_tensors(
199
+ tensor_list: List[Optional[torch.Tensor]], start: int, end: int
200
+ ) -> List[Optional[torch.Tensor]]:
201
+ out = []
202
+ for t in tensor_list:
203
+ if t is None:
204
+ out.append(None)
205
+ else:
206
+ out.append(t[start:end, ...])
207
+ return out
208
+
209
+ @classmethod
210
+ def _merge_list_tensors(
211
+ cls, list_of_lists: List[List[Optional[torch.Tensor]]]
212
+ ) -> List[Optional[torch.Tensor]]:
213
+ # If no splits, return empty
214
+ if not list_of_lists:
215
+ return []
216
+
217
+ # Number of layers is length of the sub-list from the first split
218
+ max_len = len(list_of_lists[0])
219
+ merged = [None] * max_len
220
+
221
+ for layer_idx in range(max_len):
222
+ # collect that layer_idx from each split
223
+ chunk_tensors = []
224
+ for split in list_of_lists:
225
+ t = split[layer_idx] if layer_idx < len(split) else None
226
+ if t is not None:
227
+ chunk_tensors.append(t)
228
+ if len(chunk_tensors) == 0:
229
+ merged[layer_idx] = None
230
+ else:
231
+ merged[layer_idx] = torch.cat(chunk_tensors, dim=0)
232
+ return merged
233
+
234
+ @staticmethod
235
+ def _repeat_list_tensors(
236
+ tensor_list: List[Optional[torch.Tensor]], repeats: int
237
+ ) -> List[Optional[torch.Tensor]]:
238
+ out = []
239
+ for t in tensor_list:
240
+ if t is None:
241
+ out.append(None)
242
+ else:
243
+ out.append(t.repeat_interleave(repeats, dim=0))
244
+ return out
245
+
246
+ @staticmethod
247
+ def _select_list_tensors(
248
+ tensor_list: List[Optional[torch.Tensor]], indices: torch.Tensor
249
+ ) -> List[Optional[torch.Tensor]]:
250
+ out = []
251
+ for t in tensor_list:
252
+ if t is None:
253
+ out.append(None)
254
+ else:
255
+ out.append(t.index_select(0, indices))
256
+ return out
257
+
258
+
259
+ class TokenImportancePredictorAttentive(nn.Module):
260
+ def __init__(self, config, pred_hid_size, num_heads, num_hidden_layers, dDash, intdim, \
261
+ attn_reduce_factor, dropout=0.1):
262
+ """
263
+ Optimized Token Importance Predictor with parallel Q-K projections and simplified mapping.
264
+
265
+ Args:
266
+ config: Configuration object containing model parameters.
267
+ pred_hid_size (int): Hidden size for the predictor's attention layer.
268
+ num_heads (int): Number of attention heads.
269
+ num_hidden_layers (int): Number of transformer layers to predict.
270
+ dropout (float): Dropout probability.
271
+ q_downscale (int): Factor to downscale the Q dimension for efficiency.
272
+ intermediate_dim (int): Intermediate dimension for non-linear transformations in projections.
273
+ """
274
+ super().__init__()
275
+ self.config = config
276
+ self.hidden_size = pred_hid_size
277
+ self.num_heads = num_heads
278
+ self.num_hidden_layers = num_hidden_layers
279
+ self.dropout = dropout
280
+ self.head_dim = pred_hid_size // (num_heads * 4) # Predictor head dimension is not the same as the model head dimension.
281
+ self.rope_theta = config.rope_theta
282
+ self.dDash = dDash
283
+ self.intermediate_dim = intdim
284
+ self.attn_reduce_factor = attn_reduce_factor
285
+ self.max_position_embeddings = config.max_position_embeddings
286
+ self.flash_attn = False
287
+ assert pred_hid_size % (num_heads * 4) == 0, "pred_hid_size must be divisible by num_heads * 4."
288
+
289
+ # Reduce the hidden size for attention computations
290
+ self.hidden_size_reduced = self.hidden_size // self.attn_reduce_factor # For example, reduce to 1/4th
291
+ assert self.hidden_size_reduced % self.num_heads == 0, "Reduced hidden size must be divisible by num_heads"
292
+ self.attn_head_dim = self.hidden_size_reduced // self.num_heads
293
+
294
+ # Input projection to reduce hidden size
295
+ self.input_proj = nn.Linear(self.hidden_size, self.hidden_size_reduced, bias=False)
296
+
297
+ # Query, Key, Value projections for attention
298
+ self.q_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False)
299
+ self.k_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False)
300
+ self.v_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False)
301
+ # Output projection to restore hidden size
302
+ # self.o_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False)
303
+ self.attn_dropout = nn.Dropout(self.dropout)
304
+
305
+ # LayerNorm and Feed-forward network
306
+ self.norm1 = nn.LayerNorm(self.hidden_size_reduced)
307
+ self.norm2 = nn.LayerNorm(self.hidden_size)
308
+
309
+ self.ffn_hidden_size = 2 * self.hidden_size_reduced # Typical FFN hidden size
310
+ self.ffn = nn.Sequential(
311
+ nn.Linear(self.hidden_size_reduced, self.ffn_hidden_size),
312
+ nn.GELU(),
313
+ nn.Linear(self.ffn_hidden_size, self.hidden_size),
314
+ nn.Dropout(self.dropout)
315
+ )
316
+ # Add extra LayerNorm for the importance branch when not using the old design.
317
+ self.norm_importance = nn.LayerNorm(self.hidden_size)
318
+
319
+ # Define Q and K projection layers for all layers in parallel with non-linearity[]
320
+ # Output shape: [B, L, N * H * D']
321
+ self.q_proj_importance = nn.Sequential(
322
+ nn.Linear(pred_hid_size, self.intermediate_dim, bias=False),
323
+ nn.SiLU(),
324
+ nn.Linear(self.intermediate_dim, num_hidden_layers * num_heads * self.dDash, bias=False)
325
+ )
326
+ self.k_proj_importance = nn.Sequential(
327
+ nn.Linear(pred_hid_size, self.intermediate_dim, bias=False),
328
+ nn.SiLU(),
329
+ nn.Linear(self.intermediate_dim, num_hidden_layers * num_heads * self.dDash, bias=False)
330
+ )
331
+
332
+ # Initialize rotary positional embeddings
333
+ self._init_rope()
334
+ self._initialize_weights()
335
+ self.device = None
336
+
337
+ def _initialize_weights(self):
338
+ for name, module in self.named_modules():
339
+ if isinstance(module, nn.Linear):
340
+ nn.init.xavier_uniform_(module.weight) # Xavier initialization for linear layers
341
+ if module.bias is not None:
342
+ nn.init.constant_(module.bias, 0)
343
+ elif isinstance(module, nn.LayerNorm):
344
+ nn.init.constant_(module.weight, 1.0)
345
+ nn.init.constant_(module.bias, 0.0)
346
+ elif isinstance(module, nn.MultiheadAttention):
347
+ # Initialize in_proj_weight
348
+ nn.init.xavier_uniform_(module.in_proj_weight)
349
+ if module.in_proj_bias is not None:
350
+ nn.init.constant_(module.in_proj_bias, 0)
351
+
352
+ # Initialize out_proj
353
+ nn.init.xavier_uniform_(module.out_proj.weight)
354
+ if module.out_proj.bias is not None:
355
+ nn.init.constant_(module.out_proj.bias, 0)
356
+
357
+ def _init_rope(self):
358
+
359
+ # send self.config but after modifying head_dim to be self.head_dim just in the function call
360
+ config_copy = copy.deepcopy(self.config)
361
+ config_copy.rope_scaling = {
362
+ "factor": 32.0,
363
+ "high_freq_factor": 4.0,
364
+ "low_freq_factor": 1.0,
365
+ "original_max_position_embeddings": 8192,
366
+ "rope_type": "llama3"
367
+ }
368
+ config_copy.head_dim = self.attn_head_dim
369
+
370
+ # Rotary embedding for attention layer
371
+ self.rotary_emb_attn = LlamaRotaryEmbedding(
372
+ config_copy
373
+ )
374
+
375
+ config_copy.head_dim = self.dDash
376
+ # Rotary embedding for importance projection
377
+ self.rotary_emb_importance = LlamaRotaryEmbedding(
378
+ config_copy
379
+ )
380
+
381
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False, layer_idx=None):
382
+ """
383
+ Forward pass for the Optimized Token Importance Predictor.
384
+
385
+ Args:
386
+ hidden_states (torch.Tensor): Input tensor of shape [B, L, HQ].
387
+ attention_mask (torch.Tensor, optional): Attention mask of shape [B, 1, 1, L] or [B, 1, L, L].
388
+ position_ids (torch.Tensor, optional): Position IDs.
389
+ past_key_value (tuple, optional): Past key and value states.
390
+ use_cache (bool, optional): Whether to use cache.
391
+
392
+ Returns:
393
+ torch.Tensor: Importance scores of shape [B, N, H, L, L].
394
+ """
395
+ layer_idx = 0 # Guaranteed to be 0, as we only have one predictor!
396
+
397
+ # Set device if not already set
398
+ if self.device != hidden_states.device:
399
+ self.device = hidden_states.device
400
+ self.to(self.device)
401
+
402
+ B, L, E = hidden_states.size()
403
+
404
+ # Reduce hidden size
405
+ hidden_states = hidden_states.to(self.input_proj.weight.dtype)
406
+ hidden_states_reduced = self.input_proj(hidden_states) # [B, L, hidden_size_reduced]
407
+ # Compute q, k, v for attention
408
+ q = self.q_proj_attn(hidden_states_reduced) # [B, L, hidden_size_reduced]
409
+ k = self.k_proj_attn(hidden_states_reduced) # [B, L, hidden_size_reduced]
410
+ v = self.v_proj_attn(hidden_states_reduced) # [B, L, hidden_size_reduced]
411
+ # Reshape q, k, v to [B, num_heads, L, attn_head_dim]
412
+ q = q.view(B, L, self.num_heads, self.attn_head_dim).transpose(1, 2) # [B, num_heads, L, attn_head_dim]
413
+ k = k.view(B, L, self.num_heads, self.attn_head_dim).transpose(1, 2) # [B, num_heads, L, attn_head_dim]
414
+ v = v.view(B, L, self.num_heads, self.attn_head_dim).transpose(1, 2) # [B, num_heads, L, attn_head_dim]
415
+ if (past_key_value is not None
416
+ and layer_idx < len(past_key_value.predictor_primary_key)
417
+ and past_key_value.predictor_primary_key[layer_idx] is not None):
418
+ offset = past_key_value.predictor_primary_key[layer_idx].shape[2] # old_k.shape[2]
419
+ else:
420
+ offset = 0
421
+
422
+ # total seq length for new + old
423
+ kv_seq_len = offset + L
424
+
425
+ # Step 2: build position_ids for just the new chunk [offset..offset+L-1]
426
+ if position_ids is None:
427
+ # shape [B, L], e.g. [0..(offset+L-1)]
428
+ position_ids = torch.arange(offset, offset + L, dtype=torch.long, device=self.device)
429
+ position_ids = position_ids.unsqueeze(0).expand(B, L)
430
+
431
+ # Step 3: apply rotary to just the new chunk k,v with the correct offset
432
+ cos, sin = self.rotary_emb_attn(v, position_ids)
433
+ q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
434
+
435
+ # Step 4: ask the cache to append them. Then re‐assign k, v to the full cat
436
+ if use_cache and past_key_value is not None:
437
+ k, v = past_key_value.update_predictor_primary(k.detach(), v.detach(), layer_idx)
438
+ kv_seq_len = k.size(2) # now includes old + new
439
+
440
+ attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
441
+ attn_output = attn_output.to(q.dtype)
442
+ attn_output = attn_output.transpose(1, 2).contiguous().view(B, L, self.hidden_size_reduced)
443
+ attn_output = self.norm1(attn_output)
444
+ ffn_output = self.ffn(attn_output)
445
+ # Temporary measure, till old predictor fully deprecated
446
+ hidden_states = self.norm2(hidden_states + ffn_output)
447
+
448
+ B, L, E = hidden_states.size()
449
+ # Importance projections
450
+ H = self.num_heads
451
+ N = self.num_hidden_layers
452
+
453
+ hidden_states_for_importance = self.norm_importance(hidden_states)
454
+ q_importance = self.q_proj_importance(hidden_states_for_importance)
455
+ k_importance = self.k_proj_importance(hidden_states_for_importance)
456
+
457
+ # Reshape and permute to [B, N, H, L, D']
458
+ q_importance = q_importance.view(B, L, N, H, self.dDash).permute(0, 2, 3, 1, 4).contiguous() # [B, N, H, L, D']
459
+ k_importance = k_importance.view(B, L, N, H, self.dDash).permute(0, 2, 3, 1, 4).contiguous() # [B, N, H, L, D']
460
+
461
+ # Flatten N and H for efficient computation
462
+ q_importance = q_importance.view(B * N * H, L, self.dDash) # [BNH, L, D']
463
+ k_importance = k_importance.view(B * N * H, L, self.dDash) # [BNH, L, D']
464
+
465
+ # Apply rotary positional embeddings
466
+ cos, sin = self.rotary_emb_importance(k_importance, position_ids)
467
+ q_importance, k_importance = apply_rotary_pos_emb(q_importance, k_importance, cos, sin, position_ids)
468
+
469
+ if use_cache and past_key_value is not None:
470
+ k_importance = past_key_value.update_predictor_importance(k_importance.detach(), layer_idx)
471
+
472
+ k_importance = k_importance.view(B * H, N, -1, self.dDash) # [B*H, N, L, D']
473
+ q_importance = q_importance.view(B * H, N, -1, self.dDash) # [BH, N, L, D']
474
+ return q_importance, k_importance
475
+
476
+
477
+
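The forward pass above returns the low-rank importance queries and keys rather than the scores themselves. As a rough, hypothetical illustration (not part of the committed code), the per-layer, per-head importance logits can be recovered with a scaled dot product over the dDash dimension, mirroring the reference computation used later in triton_kernels/flash_attn_mse_loss.py:

import math
import torch

def importance_logits(q_importance: torch.Tensor, k_importance: torch.Tensor) -> torch.Tensor:
    # q_importance, k_importance: [B*H, num_hidden_layers, L, dDash] (see the view() calls above)
    dDash = q_importance.shape[-1]
    # One L x L importance map per (layer, head), analogous to attention logits
    return torch.matmul(q_importance, k_importance.transpose(-2, -1)) / math.sqrt(dDash)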
478
+ class HeadImportancePredictor(nn.Module):
479
+ def __init__(self, config, pred_hid_size, num_heads, num_hidden_layers, dDash, intdim, \
480
+ attn_reduce_factor, dropout=0.1):
481
+ """
482
+ Optimized Head Importance Predictor with parallel Q-K projections and simplified mapping.
483
+
484
+ Args:
485
+ config: Configuration object containing model parameters.
486
+ pred_hid_size (int): Hidden size for the predictor's attention layer.
487
+ num_heads (int): Number of attention heads.
488
+ num_hidden_layers (int): Number of transformer layers to predict.
489
+ dropout (float): Dropout probability.
490
+ dDash (int): Low-rank importance dimension.
491
+ intdim (int): Intermediate dimension for non-linear transformations in projections.
492
+ """
493
+ super().__init__()
494
+ self.is_head_predictor = None
495
+ self.config = config
496
+ self.hidden_size = pred_hid_size
497
+ self.num_heads = num_heads
498
+ self.num_hidden_layers = num_hidden_layers
499
+ self.dropout = dropout
500
+ self.head_dim = pred_hid_size // (num_heads * 4)
501
+ self.rope_theta = config.rope_theta
502
+ self.dDash = dDash
503
+ self.intermediate_dim = intdim
504
+ self.attn_reduce_factor = attn_reduce_factor
505
+ self.max_position_embeddings = config.max_position_embeddings
506
+ self.flash_attn = False
507
+
508
+ # Reduce the hidden size for attention computations
509
+ self.hidden_size_reduced = self.hidden_size // self.attn_reduce_factor # For example, reduce to 1/4th
510
+ assert self.hidden_size_reduced % self.num_heads == 0, "Reduced hidden size must be divisible by num_heads"
511
+ self.attn_head_dim = self.hidden_size_reduced // self.num_heads
512
+
513
+ # Input projection to reduce hidden size
514
+ self.input_proj = nn.Linear(self.hidden_size, self.hidden_size_reduced, bias=False)
515
+
516
+ # Query, Key, Value projections for attention
517
+ self.q_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False)
518
+ self.k_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False)
519
+ self.v_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False)
520
+ # Output projection to restore hidden size
521
+ # self.o_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False)
522
+ self.attn_dropout = nn.Dropout(self.dropout)
523
+
524
+ # LayerNorm and Feed-forward network
525
+ self.norm1 = nn.LayerNorm(self.hidden_size_reduced)
526
+ self.norm2 = nn.LayerNorm(self.hidden_size)
527
+
528
+ self.ffn_hidden_size = 4 * self.hidden_size_reduced # Typical FFN hidden size
529
+ self.ffn = nn.Sequential(
530
+ nn.Linear(self.hidden_size_reduced, self.ffn_hidden_size),
531
+ nn.GELU(),
532
+ nn.Linear(self.ffn_hidden_size, self.num_heads * self.num_hidden_layers),
533
+ )
534
+
535
+ # Initialize rotary positional embeddings
536
+ self._init_rope()
537
+ self._initialize_weights()
538
+ self.device = None
539
+
540
+ def _initialize_weights(self):
541
+ for name, module in self.named_modules():
542
+ if isinstance(module, nn.Linear):
543
+ nn.init.xavier_uniform_(module.weight) # Xavier initialization for linear layers
544
+ if module.bias is not None:
545
+ nn.init.constant_(module.bias, 0)
546
+ elif isinstance(module, nn.LayerNorm):
547
+ nn.init.constant_(module.weight, 1.0)
548
+ nn.init.constant_(module.bias, 0.0)
549
+ elif isinstance(module, nn.MultiheadAttention):
550
+ # Initialize in_proj_weight
551
+ nn.init.xavier_uniform_(module.in_proj_weight)
552
+ if module.in_proj_bias is not None:
553
+ nn.init.constant_(module.in_proj_bias, 0)
554
+
555
+ # Initialize out_proj
556
+ nn.init.xavier_uniform_(module.out_proj.weight)
557
+ if module.out_proj.bias is not None:
558
+ nn.init.constant_(module.out_proj.bias, 0)
559
+
560
+ def _init_rope(self):
561
+ config_copy = copy.deepcopy(self.config)
562
+ config_copy.head_dim = self.attn_head_dim
563
+ # Rotary embedding for attention layer
564
+ self.rotary_emb_attn = LlamaRotaryEmbedding(
565
+ config_copy
566
+ )
567
+ # Rotary embedding for importance projection
568
+ self.rotary_emb_importance = LlamaRotaryEmbedding(
569
+ config_copy
570
+ )
571
+
572
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
573
+ """
574
+ Forward pass for the head importance predictor.
575
+
576
+ Args:
577
+ hidden_states (torch.Tensor): Input tensor of shape [B, L, HQ].
578
+ attention_mask (torch.Tensor, optional): Attention mask of shape [B, 1, 1, L] or [B, 1, L, L].
579
+ position_ids (torch.Tensor, optional): Position IDs.
580
+ past_key_value (dict, optional): Cache dict holding the predictor's past key and value states under the 'primary' key.
581
+ use_cache (bool, optional): Whether to use cache.
582
+
583
+ Returns:
584
+ Tuple[torch.Tensor, dict]: Head importance scores of shape [B, L, num_heads * num_hidden_layers] and the updated past_key_value.
585
+ """
586
+ # Set device if not already set
587
+ if self.device != hidden_states.device:
588
+ self.device = hidden_states.device
589
+ self.to(self.device)
590
+
591
+ B, L, E = hidden_states.size()
592
+ if past_key_value is None:
593
+ past_key_value = {}
594
+ # if L == 1:
595
+ # import pdb; pdb.set_trace()
596
+ past_primary = past_key_value.get('primary', None)
597
+ # Reduce hidden size
598
+ hidden_states = hidden_states.to(self.input_proj.weight.dtype)
599
+ hidden_states_reduced = self.input_proj(hidden_states) # [B, L, hidden_size_reduced]
600
+ # Compute q, k, v for attention
601
+ q = self.q_proj_attn(hidden_states_reduced) # [B, L, hidden_size_reduced]
602
+ k = self.k_proj_attn(hidden_states_reduced) # [B, L, hidden_size_reduced]
603
+ v = self.v_proj_attn(hidden_states_reduced) # [B, L, hidden_size_reduced]
604
+ # Reshape q, k, v to [B, num_heads, L, attn_head_dim]
605
+ q = q.view(B, L, self.num_heads, self.attn_head_dim).transpose(1, 2) # [B, num_heads, L, attn_head_dim]
606
+ k = k.view(B, L, self.num_heads, self.attn_head_dim).transpose(1, 2) # [B, num_heads, L, attn_head_dim]
607
+ v = v.view(B, L, self.num_heads, self.attn_head_dim).transpose(1, 2) # [B, num_heads, L, attn_head_dim]
608
+ # Compute kv_seq_len before concatenation
609
+ if past_primary is not None:
610
+ past_L = past_primary[0].shape[2]
611
+ kv_seq_len = past_L + L
612
+ else:
613
+ kv_seq_len = L
614
+
615
+ # Apply rotary positional embeddings based on kv_seq_len
616
+ cos, sin = self.rotary_emb_attn(v, position_ids)
617
+ if position_ids is None:
618
+ position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=self.device)
619
+ position_ids = position_ids.unsqueeze(0).expand(B, kv_seq_len)
620
+
621
+ if past_primary is not None:
622
+ # Concatenate past k and v
623
+ k = torch.cat([past_primary[0], k], dim=2) # [B, num_heads, past_L + L, attn_head_dim]
624
+ v = torch.cat([past_primary[1], v], dim=2) # [B, num_heads, past_L + L, attn_head_dim]
625
+
626
+ # Apply rotary embeddings after concatenation
627
+ q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
628
+
629
+ # Update cache if use_cache is True
630
+ if use_cache:
631
+ past_key_value['primary'] = (k.detach(), v.detach())
632
+
633
+ # if self.flash_attn:
634
+ # sm_scale = 1.0 / math.sqrt(self.attn_head_dim)
635
+ # attn_output = attention(q.contiguous().to(torch.float16), k.contiguous().to(torch.float16), v.contiguous().to(torch.float16), True, sm_scale).to(q.dtype)
636
+ # else:
637
+ # attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
638
+ attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
639
+ attn_output = attn_output.to(q.dtype)
640
+ attn_output = attn_output.transpose(1, 2).contiguous().view(B, L, self.hidden_size_reduced)
641
+ attn_output = self.norm1(attn_output)
642
+ head_importances = self.ffn(attn_output)
643
+ return head_importances, past_key_value
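Since the FFN above ends in a Linear layer of width num_heads * num_hidden_layers, head_importances carries one score per (layer, head) for every token. A minimal unflattening sketch (the [layers, heads] ordering is an assumption here, not something the committed code pins down):

import torch

def split_head_importances(head_importances: torch.Tensor, num_hidden_layers: int, num_heads: int) -> torch.Tensor:
    # [B, L, num_hidden_layers * num_heads] -> [B, L, num_hidden_layers, num_heads]
    B, L, _ = head_importances.shape
    return head_importances.view(B, L, num_hidden_layers, num_heads)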
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b6022d8733580ff4744a5e7002122b704a1e6426980a07db95e717d917ec553
3
+ size 4992957478
triton_kernels/__pycache__/flash_attn.cpython-310.pyc ADDED
Binary file (11.4 kB). View file
 
triton_kernels/__pycache__/flash_attn.cpython-39.pyc ADDED
Binary file (11.3 kB). View file
 
triton_kernels/__pycache__/flash_attn_mse_loss.cpython-310.pyc ADDED
Binary file (16.2 kB). View file
 
triton_kernels/__pycache__/flash_attn_mse_loss.cpython-39.pyc ADDED
Binary file (16 kB). View file
 
triton_kernels/flash_attn.py ADDED
@@ -0,0 +1,543 @@
1
+ import torch
2
+
3
+ import triton
4
+ import triton.language as tl
5
+
6
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7
+
8
+
9
+ def is_hip():
10
+ return triton.runtime.driver.active.get_current_target().backend == "hip"
11
+
12
+
13
+ @triton.jit
14
+ def _attn_fwd_inner(acc, l_i, m_i, q, #
15
+ K_block_ptr, V_block_ptr, #
16
+ start_m, qk_scale, #
17
+ BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, #
18
+ STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
19
+ N_CTX: tl.constexpr, fp8_v: tl.constexpr):
20
+ # range of values handled by this stage
21
+ if STAGE == 1:
22
+ lo, hi = 0, start_m * BLOCK_M
23
+ elif STAGE == 2:
24
+ lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
25
+ lo = tl.multiple_of(lo, BLOCK_M)
26
+ # causal = False
27
+ else:
28
+ lo, hi = 0, N_CTX
29
+ K_block_ptr = tl.advance(K_block_ptr, (0, lo))
30
+ V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
31
+ # loop over k, v and update accumulator
32
+ for start_n in range(lo, hi, BLOCK_N):
33
+ start_n = tl.multiple_of(start_n, BLOCK_N)
34
+ # -- compute qk ----
35
+ k = tl.load(K_block_ptr)
36
+ qk = tl.dot(q, k)
37
+ if STAGE == 2:
38
+ mask = offs_m[:, None] >= (start_n + offs_n[None, :])
39
+ qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
40
+ m_ij = tl.maximum(m_i, tl.max(qk, 1))
41
+ qk -= m_ij[:, None]
42
+ else:
43
+ m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
44
+ qk = qk * qk_scale - m_ij[:, None]
45
+ p = tl.math.exp2(qk)
46
+ l_ij = tl.sum(p, 1)
47
+ # -- update m_i and l_i
48
+ alpha = tl.math.exp2(m_i - m_ij)
49
+ l_i = l_i * alpha + l_ij
50
+ # -- update output accumulator --
51
+ acc = acc * alpha[:, None]
52
+ # update acc
53
+ v = tl.load(V_block_ptr)
54
+ if fp8_v:
55
+ p = p.to(tl.float8e5)
56
+ else:
57
+ p = p.to(tl.float16)
58
+ acc = tl.dot(p, v, acc)
59
+ # update m_i and l_i
60
+ m_i = m_ij
61
+ V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
62
+ K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
63
+ return acc, l_i, m_i
64
+
65
+
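_attn_fwd_inner is the streaming-softmax (FlashAttention-style) inner loop: m_i tracks the running row maximum, l_i the running normalizer, and acc the un-normalized output, all rescaled by alpha whenever the maximum grows. A plain PyTorch sketch of the same recurrence for a single (batch, head), ignoring the causal mask and the base-2 exponent trick the kernel uses:

import torch

def online_softmax_attention(q, k, v, block=64, scale=1.0):
    # q: [M, D], k/v: [L, D]; keys and values are consumed block by block, as in the kernel
    m = torch.full((q.shape[0],), float("-inf"))
    l = torch.zeros(q.shape[0])
    acc = torch.zeros(q.shape[0], v.shape[1])
    for s in range(0, k.shape[0], block):
        qk = (q @ k[s:s + block].T) * scale              # partial logits for this block
        m_new = torch.maximum(m, qk.max(dim=1).values)   # updated running maximum
        p = torch.exp(qk - m_new[:, None])
        alpha = torch.exp(m - m_new)                     # rescales previously accumulated stats
        l = l * alpha + p.sum(dim=1)
        acc = acc * alpha[:, None] + p @ v[s:s + block]
        m = m_new
    return acc / l[:, None]                              # same normalization as the kernel epilogue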
66
+ # We don't run auto-tuning every time to keep the tutorial fast. Keeping
67
+ # the code below and commenting out the equivalent parameters is convenient for
68
+ # re-tuning.
69
+ configs = [
70
+ triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
71
+ for BM in [16, 32, 64, 128]\
72
+ for BN in [16, 32]\
73
+ for s in ([1] if is_hip() else [3, 4, 7])\
74
+ for w in [4, 8]\
75
+ ]
76
+
77
+
78
+ def keep(conf):
79
+ BLOCK_M = conf.kwargs["BLOCK_M"]
80
+ BLOCK_N = conf.kwargs["BLOCK_N"]
81
+ if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8:
82
+ return False
83
+ return True
84
+
85
+
86
+ @triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"])
87
+ @triton.jit
88
+ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
89
+ stride_qz, stride_qh, stride_qm, stride_qk, #
90
+ stride_kz, stride_kh, stride_kn, stride_kk, #
91
+ stride_vz, stride_vh, stride_vk, stride_vn, #
92
+ stride_oz, stride_oh, stride_om, stride_on, #
93
+ Z, H, N_CTX, #
94
+ HEAD_DIM: tl.constexpr, #
95
+ BLOCK_M: tl.constexpr, #
96
+ BLOCK_N: tl.constexpr, #
97
+ STAGE: tl.constexpr #
98
+ ):
99
+ tl.static_assert(BLOCK_N <= HEAD_DIM)
100
+ start_m = tl.program_id(0)
101
+ off_hz = tl.program_id(1)
102
+ off_z = off_hz // H
103
+ off_h = off_hz % H
104
+ qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
105
+
106
+ # block pointers
107
+ Q_block_ptr = tl.make_block_ptr(
108
+ base=Q + qvk_offset,
109
+ shape=(N_CTX, HEAD_DIM),
110
+ strides=(stride_qm, stride_qk),
111
+ offsets=(start_m * BLOCK_M, 0),
112
+ block_shape=(BLOCK_M, HEAD_DIM),
113
+ order=(1, 0),
114
+ )
115
+ v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
116
+ V_block_ptr = tl.make_block_ptr(
117
+ base=V + qvk_offset,
118
+ shape=(N_CTX, HEAD_DIM),
119
+ strides=(stride_vk, stride_vn),
120
+ offsets=(0, 0),
121
+ block_shape=(BLOCK_N, HEAD_DIM),
122
+ order=v_order,
123
+ )
124
+ K_block_ptr = tl.make_block_ptr(
125
+ base=K + qvk_offset,
126
+ shape=(HEAD_DIM, N_CTX),
127
+ strides=(stride_kk, stride_kn),
128
+ offsets=(0, 0),
129
+ block_shape=(HEAD_DIM, BLOCK_N),
130
+ order=(0, 1),
131
+ )
132
+ O_block_ptr = tl.make_block_ptr(
133
+ base=Out + qvk_offset,
134
+ shape=(N_CTX, HEAD_DIM),
135
+ strides=(stride_om, stride_on),
136
+ offsets=(start_m * BLOCK_M, 0),
137
+ block_shape=(BLOCK_M, HEAD_DIM),
138
+ order=(1, 0),
139
+ )
140
+ # initialize offsets
141
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
142
+ offs_n = tl.arange(0, BLOCK_N)
143
+ # initialize pointer to m and l
144
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
145
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
146
+ acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
147
+ # load scales
148
+ qk_scale = sm_scale
149
+ qk_scale *= 1.44269504 # 1/log(2)
150
+ # load q: it will stay in SRAM throughout
151
+ q = tl.load(Q_block_ptr)
152
+ # stage 1: off-band
153
+ # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
154
+ # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
155
+ if STAGE & 1:
156
+ acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #
157
+ start_m, qk_scale, #
158
+ BLOCK_M, HEAD_DIM, BLOCK_N, #
159
+ 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 #
160
+ )
161
+ # stage 2: on-band
162
+ if STAGE & 2:
163
+ # barrier makes it easier for the compiler to schedule the
164
+ # two loops independently
165
+ acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #
166
+ start_m, qk_scale, #
167
+ BLOCK_M, HEAD_DIM, BLOCK_N, #
168
+ 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 #
169
+ )
170
+ # epilogue
171
+ m_i += tl.math.log2(l_i)
172
+ acc = acc / l_i[:, None]
173
+ m_ptrs = M + off_hz * N_CTX + offs_m
174
+ tl.store(m_ptrs, m_i)
175
+ tl.store(O_block_ptr, acc.to(Out.type.element_ty))
176
+
177
+
178
+ @triton.jit
179
+ def _attn_bwd_preprocess(O, DO, #
180
+ Delta, #
181
+ Z, H, N_CTX, #
182
+ BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr #
183
+ ):
184
+ off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
185
+ off_hz = tl.program_id(1)
186
+ off_n = tl.arange(0, HEAD_DIM)
187
+ # load
188
+ o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :])
189
+ do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32)
190
+ delta = tl.sum(o * do, axis=1)
191
+ # write-back
192
+ tl.store(Delta + off_hz * N_CTX + off_m, delta)
193
+
194
+
195
+ # The main inner-loop logic for computing dK and dV.
196
+ @triton.jit
197
+ def _attn_bwd_dkdv(dk, dv, #
198
+ Q, k, v, sm_scale, #
199
+ DO, #
200
+ M, D, #
201
+ # shared by Q/K/V/DO.
202
+ stride_tok, stride_d, #
203
+ H, N_CTX, BLOCK_M1: tl.constexpr, #
204
+ BLOCK_N1: tl.constexpr, #
205
+ HEAD_DIM: tl.constexpr, #
206
+ # Filled in by the wrapper.
207
+ start_n, start_m, num_steps, #
208
+ MASK: tl.constexpr):
209
+ offs_m = start_m + tl.arange(0, BLOCK_M1)
210
+ offs_n = start_n + tl.arange(0, BLOCK_N1)
211
+ offs_k = tl.arange(0, HEAD_DIM)
212
+ qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
213
+ do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
214
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
215
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
216
+ curr_m = start_m
217
+ step_m = BLOCK_M1
218
+ for blk_idx in range(num_steps):
219
+ qT = tl.load(qT_ptrs)
220
+ # Load m before computing qk to reduce pipeline stall.
221
+ offs_m = curr_m + tl.arange(0, BLOCK_M1)
222
+ m = tl.load(M + offs_m)
223
+ qkT = tl.dot(k, qT)
224
+ pT = tl.math.exp2(qkT - m[None, :])
225
+ # Autoregressive masking.
226
+ if MASK:
227
+ mask = (offs_m[None, :] >= offs_n[:, None])
228
+ pT = tl.where(mask, pT, 0.0)
229
+ do = tl.load(do_ptrs)
230
+ # Compute dV.
231
+ ppT = pT
232
+ ppT = ppT.to(tl.float16)
233
+ dv += tl.dot(ppT, do)
234
+ # D (= delta) is pre-divided by ds_scale.
235
+ Di = tl.load(D + offs_m)
236
+ # Compute dP and dS.
237
+ dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
238
+ dsT = pT * (dpT - Di[None, :])
239
+ dsT = dsT.to(tl.float16)
240
+ dk += tl.dot(dsT, tl.trans(qT))
241
+ # Increment pointers.
242
+ curr_m += step_m
243
+ qT_ptrs += step_m * stride_tok
244
+ do_ptrs += step_m * stride_tok
245
+ return dk, dv
246
+
247
+
248
+ # the main inner-loop logic for computing dQ
249
+ @triton.jit
250
+ def _attn_bwd_dq(dq, q, K, V, #
251
+ do, m, D,
252
+ # shared by Q/K/V/DO.
253
+ stride_tok, stride_d, #
254
+ H, N_CTX, #
255
+ BLOCK_M2: tl.constexpr, #
256
+ BLOCK_N2: tl.constexpr, #
257
+ HEAD_DIM: tl.constexpr,
258
+ # Filled in by the wrapper.
259
+ start_m, start_n, num_steps, #
260
+ MASK: tl.constexpr):
261
+ offs_m = start_m + tl.arange(0, BLOCK_M2)
262
+ offs_n = start_n + tl.arange(0, BLOCK_N2)
263
+ offs_k = tl.arange(0, HEAD_DIM)
264
+ kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
265
+ vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
266
+ # D (= delta) is pre-divided by ds_scale.
267
+ Di = tl.load(D + offs_m)
268
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
269
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
270
+ curr_n = start_n
271
+ step_n = BLOCK_N2
272
+ for blk_idx in range(num_steps):
273
+ kT = tl.load(kT_ptrs)
274
+ vT = tl.load(vT_ptrs)
275
+ qk = tl.dot(q, kT)
276
+ p = tl.math.exp2(qk - m)
277
+ # Autoregressive masking.
278
+ if MASK:
279
+ offs_n = curr_n + tl.arange(0, BLOCK_N2)
280
+ mask = (offs_m[:, None] >= offs_n[None, :])
281
+ p = tl.where(mask, p, 0.0)
282
+ # Compute dP and dS.
283
+ dp = tl.dot(do, vT).to(tl.float32)
284
+ ds = p * (dp - Di[:, None])
285
+ ds = ds.to(tl.float16)
286
+ # Compute dQ.
287
+ # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
288
+ dq += tl.dot(ds, tl.trans(kT))
289
+ # Increment pointers.
290
+ curr_n += step_n
291
+ kT_ptrs += step_n * stride_tok
292
+ vT_ptrs += step_n * stride_tok
293
+ return dq
294
+
295
+
296
+ @triton.jit
297
+ def _attn_bwd(Q, K, V, sm_scale, #
298
+ DO, #
299
+ DQ, DK, DV, #
300
+ M, D,
301
+ # shared by Q/K/V/DO.
302
+ stride_z, stride_h, stride_tok, stride_d, #
303
+ H, N_CTX, #
304
+ BLOCK_M1: tl.constexpr, #
305
+ BLOCK_N1: tl.constexpr, #
306
+ BLOCK_M2: tl.constexpr, #
307
+ BLOCK_N2: tl.constexpr, #
308
+ BLK_SLICE_FACTOR: tl.constexpr, #
309
+ HEAD_DIM: tl.constexpr):
310
+ LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
311
+
312
+ bhid = tl.program_id(2)
313
+ off_chz = (bhid * N_CTX).to(tl.int64)
314
+ adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
315
+ pid = tl.program_id(0)
316
+
317
+ # offset pointers for batch/head
318
+ Q += adj
319
+ K += adj
320
+ V += adj
321
+ DO += adj
322
+ DQ += adj
323
+ DK += adj
324
+ DV += adj
325
+ M += off_chz
326
+ D += off_chz
327
+
328
+ # load scales
329
+ offs_k = tl.arange(0, HEAD_DIM)
330
+
331
+ start_n = pid * BLOCK_N1
332
+ start_m = start_n
333
+
334
+ MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
335
+ offs_n = start_n + tl.arange(0, BLOCK_N1)
336
+
337
+ dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
338
+ dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
339
+
340
+ # load K and V: they stay in SRAM throughout the inner loop.
341
+ k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
342
+ v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
343
+
344
+ num_steps = BLOCK_N1 // MASK_BLOCK_M1
345
+
346
+ dk, dv = _attn_bwd_dkdv(dk, dv, #
347
+ Q, k, v, sm_scale, #
348
+ DO, #
349
+ M, D, #
350
+ stride_tok, stride_d, #
351
+ H, N_CTX, #
352
+ MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, #
353
+ start_n, start_m, num_steps, #
354
+ MASK=True #
355
+ )
356
+
357
+ start_m += num_steps * MASK_BLOCK_M1
358
+ num_steps = (N_CTX - start_m) // BLOCK_M1
359
+
360
+ # Compute dK and dV for non-masked blocks.
361
+ dk, dv = _attn_bwd_dkdv( #
362
+ dk, dv, #
363
+ Q, k, v, sm_scale, #
364
+ DO, #
365
+ M, D, #
366
+ stride_tok, stride_d, #
367
+ H, N_CTX, #
368
+ BLOCK_M1, BLOCK_N1, HEAD_DIM, #
369
+ start_n, start_m, num_steps, #
370
+ MASK=False #
371
+ )
372
+
373
+ dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
374
+ tl.store(dv_ptrs, dv)
375
+
376
+ # Write back dK.
377
+ dk *= sm_scale
378
+ dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
379
+ tl.store(dk_ptrs, dk)
380
+
381
+ # THIS BLOCK DOES DQ:
382
+ start_m = pid * BLOCK_M2
383
+ end_n = start_m + BLOCK_M2
384
+
385
+ MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
386
+ offs_m = start_m + tl.arange(0, BLOCK_M2)
387
+
388
+ q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
389
+ dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
390
+ do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
391
+
392
+ m = tl.load(M + offs_m)
393
+ m = m[:, None]
394
+
395
+ # Compute dQ for masked (diagonal) blocks.
396
+ # NOTE: This code scans each row of QK^T backward (from right to left,
397
+ # but inside each call to _attn_bwd_dq, from left to right), but that's
398
+ # not due to anything important. I just wanted to reuse the loop
399
+ # structure for dK & dV above as much as possible.
400
+ num_steps = BLOCK_M2 // MASK_BLOCK_N2
401
+ dq = _attn_bwd_dq(dq, q, K, V, #
402
+ do, m, D, #
403
+ stride_tok, stride_d, #
404
+ H, N_CTX, #
405
+ BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, #
406
+ start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, #
407
+ MASK=True #
408
+ )
409
+ end_n -= num_steps * MASK_BLOCK_N2
410
+ # stage 2
411
+ num_steps = end_n // BLOCK_N2
412
+ dq = _attn_bwd_dq(dq, q, K, V, #
413
+ do, m, D, #
414
+ stride_tok, stride_d, #
415
+ H, N_CTX, #
416
+ BLOCK_M2, BLOCK_N2, HEAD_DIM, #
417
+ start_m, end_n - num_steps * BLOCK_N2, num_steps, #
418
+ MASK=False #
419
+ )
420
+ # Write back dQ.
421
+ dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
422
+ dq *= LN2
423
+ tl.store(dq_ptrs, dq)
424
+
425
+
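A note on the base-2 bookkeeping shared by the forward and backward kernels (a reading of the constants above, stated informally): since exp(x) = 2^(x * log2(e)), the forward multiplies the logits by sm_scale * 1.44269504 and evaluates exp2 instead of exp,

    exp(sm_scale * q.k - m) = exp2(sm_scale * log2(e) * q.k - m'),  with m' = m * log2(e).

In the backward, k is pre-scaled by sm_scale * RCP_LN2 so that exp2(qk^T - M) directly reproduces the softmax probabilities; the dq accumulated against that pre-scaled k therefore carries an extra 1/ln(2) factor, which the final dq *= LN2 removes, while dk is instead rescaled explicitly by dk *= sm_scale.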
426
+ class _attention(torch.autograd.Function):
427
+
428
+ @staticmethod
429
+ def forward(ctx, q, k, v, causal, sm_scale):
430
+ # shape constraints
431
+ HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
432
+ # when v is in float8_e5m2 it is transposed.
433
+ HEAD_DIM_V = v.shape[-1]
434
+ assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
435
+ assert HEAD_DIM_K in {16, 32, 64, 128, 256}
436
+ o = torch.empty_like(q)
437
+ stage = 3 if causal else 1
438
+ extra_kern_args = {}
439
+ # Tuning for AMD target
440
+ if is_hip():
441
+ waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2
442
+ extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}
443
+
444
+ grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
445
+ M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
446
+ _attn_fwd[grid](
447
+ q, k, v, sm_scale, M, o, #
448
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
449
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
450
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
451
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
452
+ q.shape[0], q.shape[1], #
453
+ N_CTX=q.shape[2], #
454
+ HEAD_DIM=HEAD_DIM_K, #
455
+ STAGE=stage, #
456
+ **extra_kern_args)
457
+
458
+ ctx.save_for_backward(q, k, v, o, M)
459
+ ctx.grid = grid
460
+ ctx.sm_scale = sm_scale
461
+ ctx.HEAD_DIM = HEAD_DIM_K
462
+ ctx.causal = causal
463
+ return o
464
+
465
+ @staticmethod
466
+ def backward(ctx, do):
467
+ q, k, v, o, M = ctx.saved_tensors
468
+ do = do.contiguous()
469
+ assert do.is_contiguous()
470
+ assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
471
+ dq = torch.empty_like(q)
472
+ dk = torch.empty_like(k)
473
+ dv = torch.empty_like(v)
474
+ BATCH, N_HEAD, N_CTX = q.shape[:3]
475
+ PRE_BLOCK = 128
476
+ NUM_WARPS, NUM_STAGES = 4, 5
477
+ BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
478
+ BLK_SLICE_FACTOR = 2
479
+ RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
480
+ arg_k = k
481
+ arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
482
+ PRE_BLOCK = 128
483
+ assert N_CTX % PRE_BLOCK == 0
484
+ pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
485
+ delta = torch.empty_like(M)
486
+ _attn_bwd_preprocess[pre_grid](
487
+ o, do, #
488
+ delta, #
489
+ BATCH, N_HEAD, N_CTX, #
490
+ BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM #
491
+ )
492
+ grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
493
+ _attn_bwd[grid](
494
+ q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, #
495
+ M, delta, #
496
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
497
+ N_HEAD, N_CTX, #
498
+ BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, #
499
+ BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, #
500
+ BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
501
+ HEAD_DIM=ctx.HEAD_DIM, #
502
+ num_warps=NUM_WARPS, #
503
+ num_stages=NUM_STAGES #
504
+ )
505
+
506
+ return dq, dk, dv, None, None
507
+
508
+
509
+ attention = _attention.apply
510
+
511
+ from torch import nn
512
+ from torch.nn import MSELoss
513
+ import math
514
+ if __name__ == '__main__':
515
+ B, H, N_CTX, D = 1, 32, 84, 16
516
+ causal = True
517
+ DTYPE = torch.float32
518
+ torch.manual_seed(20)
519
+ gain_q = math.sqrt(5.0 / D)
520
+ gain_k = math.sqrt(5.0 / D)
521
+ gain_v = math.sqrt(3.0 / D)
522
+
523
+ q = (torch.randn((B, H, N_CTX, D), dtype=DTYPE, device=DEVICE) * gain_q).requires_grad_()
524
+ k = (torch.randn((B, H, N_CTX, D), dtype=DTYPE, device=DEVICE) * gain_k).requires_grad_()
525
+ v = (torch.randn((B, H, N_CTX, D), dtype=DTYPE, device=DEVICE) * gain_v).requires_grad_()
526
+
527
+ sm_scale = 1.0 / math.sqrt(D)
528
+ att_output_triton = attention(q.to(torch.float16), k.to(torch.float16), v.to(torch.float16), causal, sm_scale)
529
+
530
+ M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
531
+
532
+ attn_weights = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
533
+ if causal:
534
+ causal_mask = torch.triu(torch.ones(N_CTX, N_CTX, device=attn_weights.device), diagonal=1).bool()
535
+ attn_weights = attn_weights.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
536
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(v.dtype)
537
+ attn_output_original = torch.matmul(attn_weights, v)
538
+
539
+ print(f'att_output_triton: {att_output_triton}')
540
+ print(f'attn_output_original: {attn_output_original}')
541
+ # assert torch.allclose(attn_output_original, att_output_triton, atol=1e-2, rtol=0), f"Error: {torch.mean(torch.abs(attn_output_original - att_output_triton))}"
542
+ perc_diff = 100 * torch.abs(att_output_triton - attn_output_original).mean() / torch.abs(attn_output_original).mean()
543
+ print(f'passed test for attention output with {att_output_triton.mean()} vs {attn_output_original.mean()}, \t\t\t\tpercentage diff: {perc_diff.item()}%')
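A shorter cross-check, equivalent to the manual softmax reference above, can be written with PyTorch's built-in scaled_dot_product_attention (the explicit scale= argument assumes a reasonably recent PyTorch; this snippet is illustrative and not part of the commit):

# q, k, v, sm_scale and `attention` as defined in the __main__ block above
out_tri = attention(q.to(torch.float16), k.to(torch.float16), v.to(torch.float16), True, sm_scale)
out_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True, scale=sm_scale)
print((out_tri.float() - out_ref).abs().mean() / out_ref.abs().mean())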
triton_kernels/flash_attn_mse_loss.py ADDED
@@ -0,0 +1,733 @@
1
+ import torch
2
+
3
+ import triton
4
+ import triton.language as tl
5
+
6
+ import torch
7
+ import math
8
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
+
10
+
11
+ def is_hip():
12
+ return triton.runtime.driver.active.get_current_target().backend == "hip"
13
+
14
+ @triton.jit
15
+ def _attn_fwd_inner(acc, l_i, m_i, q, q_imp, #
16
+ K_block_ptr, K_imp_block_ptr, V_block_ptr, #
17
+ mse_loss, #
18
+ start_m, qk_sqrt, qk_scale, qk_sqrt_imp, qk_scale_imp, #
19
+ BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, NUM_ELEMENTS: tl.constexpr, BLOCK_N: tl.constexpr, #
20
+ STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
21
+ N_CTX: tl.constexpr, fp8_v: tl.constexpr):
22
+ # range of values handled by this stage
23
+ if STAGE == 1:
24
+ lo, hi = 0, start_m * BLOCK_M
25
+ elif STAGE == 2:
26
+ lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
27
+ lo = tl.multiple_of(lo, BLOCK_M)
28
+ # causal = False
29
+ else:
30
+ lo, hi = 0, N_CTX
31
+ K_block_ptr = tl.advance(K_block_ptr, (0, lo))
32
+ K_imp_block_ptr = tl.advance(K_imp_block_ptr, (0, lo))
33
+ V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
34
+ mse_contrib_total = 0.0
35
+ # loop over k, v and update accumulator
36
+ for start_n in range(lo, hi, BLOCK_N):
37
+ start_n = tl.multiple_of(start_n, BLOCK_N)
38
+ # -- compute qk ----
39
+ k = tl.load(K_block_ptr).to(tl.float32)
40
+ k_imp = tl.load(K_imp_block_ptr).to(tl.float32)
41
+ qf32 = q.to(tl.float32)
42
+ qf32_imp = q_imp.to(tl.float32)
43
+ qk = tl.dot(qf32, k)
44
+ qk_imp = tl.dot(qf32_imp, k_imp)
45
+ diff = (qk * qk_sqrt - qk_imp * qk_sqrt_imp)
46
+ diff_sqr = diff * diff
47
+ mse_contrib = tl.sum(diff_sqr)
48
+ # tl.atomic_add(mse_loss, mse_contrib)
49
+ mse_contrib_total += mse_contrib
50
+ if STAGE == 2:
51
+ mask = offs_m[:, None] >= (start_n + offs_n[None, :])
52
+ qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
53
+ m_ij = tl.maximum(m_i, tl.max(qk, 1))
54
+ qk -= m_ij[:, None]
55
+ else:
56
+ m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
57
+ qk = qk * qk_scale - m_ij[:, None]
58
+ p = tl.math.exp2(qk)
59
+ l_ij = tl.sum(p, 1)
60
+ # -- update m_i and l_i
61
+ alpha = tl.math.exp2(m_i - m_ij)
62
+ l_i = l_i * alpha + l_ij
63
+ # -- update output accumulator --
64
+ acc = acc * alpha[:, None]
65
+ # update acc
66
+ v = tl.load(V_block_ptr)
67
+ p = p.to(v.type.element_ty)
68
+ acc = tl.dot(p, v, acc)
69
+ # update m_i and l_i
70
+ m_i = m_ij
71
+ V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
72
+ K_imp_block_ptr = tl.advance(K_imp_block_ptr, (0, BLOCK_N))
73
+ K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
74
+
75
+ if STAGE == 2:
76
+ for start_n in range(hi, N_CTX, BLOCK_N):
77
+ start_n = tl.multiple_of(start_n, BLOCK_N)
78
+ # -- compute qk ----
79
+ k = tl.load(K_block_ptr).to(tl.float32)
80
+ k_imp = tl.load(K_imp_block_ptr).to(tl.float32)
81
+ qf32 = q.to(tl.float32)
82
+ qf32_imp = q_imp.to(tl.float32)
83
+ qk = tl.dot(qf32, k)
84
+ qk_imp = tl.dot(qf32_imp, k_imp)
85
+ diff = (qk * qk_sqrt - qk_imp * qk_sqrt_imp)
86
+ diff_sqr = diff * diff
87
+ mse_contrib = tl.sum(diff_sqr)
88
+ # tl.atomic_add(mse_loss, mse_contrib)
89
+ mse_contrib_total += mse_contrib
90
+ K_imp_block_ptr = tl.advance(K_imp_block_ptr, (0, BLOCK_N))
91
+ K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
92
+ mse_contrib_total /= NUM_ELEMENTS
93
+ tl.atomic_add(mse_loss, mse_contrib_total)
94
+ return acc, l_i, m_i
95
+
96
+
97
+ # We don't run auto-tuning every time to keep the tutorial fast. Keeping
98
+ # the code below and commenting out the equivalent parameters is convenient for
99
+ # re-tuning.
100
+ # configs = [
101
+ # triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
102
+ # for BM in [32]\
103
+ # for BN in [32]\
104
+ # for s in ([1] if is_hip() else [3, 4, 7])\
105
+ # for w in [4, 8]\
106
+ # ]
107
+
108
+ fixed_config = triton.Config(
109
+ {'BLOCK_M': 16, 'BLOCK_N': 16},
110
+ num_stages=4,
111
+ num_warps=8
112
+ )
113
+
114
+ def keep(conf):
115
+ BLOCK_M = conf.kwargs["BLOCK_M"]
116
+ BLOCK_N = conf.kwargs["BLOCK_N"]
117
+ if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8:
118
+ return False
119
+ return True
120
+
121
+
122
+ # @triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"])
123
+ @triton.jit
124
+ def _attn_fwd(Q, K, V, sm_scale, sm_scale_imp, M, Out, #
125
+ Q_importance, K_importance, mse_loss,
126
+ stride_qz, stride_qh, stride_qm, stride_qk, #
127
+ stride_kz, stride_kh, stride_kn, stride_kk, #
128
+ stride_vz, stride_vh, stride_vk, stride_vn, #
129
+ stride_oz, stride_oh, stride_om, stride_on, #
130
+ stride_qz_imp, stride_qh_imp, stride_qm_imp, stride_qk_imp, #
131
+ stride_kz_imp, stride_kh_imp, stride_kn_imp, stride_kk_imp, #
132
+ Z, H, N_CTX, #
133
+ HEAD_DIM: tl.constexpr, #
134
+ D_DASH: tl.constexpr, #
135
+ NUM_ELEMENTS: tl.constexpr, #
136
+ BLOCK_M: tl.constexpr, #
137
+ BLOCK_N: tl.constexpr, #
138
+ STAGE: tl.constexpr #
139
+ ):
140
+ tl.static_assert(BLOCK_N <= HEAD_DIM)
141
+ tl.static_assert(BLOCK_N <= D_DASH)
142
+ start_m = tl.program_id(0)
143
+ off_hz = tl.program_id(1)
144
+ off_z = off_hz // H
145
+ off_h = off_hz % H
146
+ qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
147
+ qv_imp_offset = off_z.to(tl.int64) * stride_qz_imp + off_h.to(tl.int64) * stride_qh_imp
148
+
149
+ # block pointers
150
+ Q_block_ptr = tl.make_block_ptr(
151
+ base=Q + qvk_offset,
152
+ shape=(N_CTX, HEAD_DIM),
153
+ strides=(stride_qm, stride_qk),
154
+ offsets=(start_m * BLOCK_M, 0),
155
+ block_shape=(BLOCK_M, HEAD_DIM),
156
+ order=(1, 0),
157
+ )
158
+
159
+ v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
160
+ V_block_ptr = tl.make_block_ptr(
161
+ base=V + qvk_offset,
162
+ shape=(N_CTX, HEAD_DIM),
163
+ strides=(stride_vk, stride_vn),
164
+ offsets=(0, 0),
165
+ block_shape=(BLOCK_N, HEAD_DIM),
166
+ order=v_order,
167
+ )
168
+ K_block_ptr = tl.make_block_ptr(
169
+ base=K + qvk_offset,
170
+ shape=(HEAD_DIM, N_CTX),
171
+ strides=(stride_kk, stride_kn),
172
+ offsets=(0, 0),
173
+ block_shape=(HEAD_DIM, BLOCK_N),
174
+ order=(0, 1),
175
+ )
176
+ O_block_ptr = tl.make_block_ptr(
177
+ base=Out + qvk_offset,
178
+ shape=(N_CTX, HEAD_DIM),
179
+ strides=(stride_om, stride_on),
180
+ offsets=(start_m * BLOCK_M, 0),
181
+ block_shape=(BLOCK_M, HEAD_DIM),
182
+ order=(1, 0),
183
+ )
184
+ Q_imp_block_ptr = tl.make_block_ptr(
185
+ base=Q_importance + qv_imp_offset,
186
+ shape=(N_CTX, D_DASH),
187
+ strides=(stride_qm_imp, stride_qk_imp),
188
+ offsets=(start_m * BLOCK_M, 0),
189
+ block_shape=(BLOCK_M, D_DASH),
190
+ order=(1, 0),
191
+ )
192
+ K_imp_block_ptr = tl.make_block_ptr(
193
+ base=K_importance + qv_imp_offset,
194
+ shape=(D_DASH, N_CTX),
195
+ strides=(stride_kk_imp, stride_kn_imp),
196
+ offsets=(0, 0),
197
+ block_shape=(D_DASH, BLOCK_N),
198
+ order=(0, 1),
199
+ )
200
+ # initialize offsets
201
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
202
+ offs_n = tl.arange(0, BLOCK_N)
203
+ # initialize pointer to m and l
204
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
205
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
206
+ acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
207
+ # load scales
208
+ qk_sqrt = sm_scale
209
+ qk_scale = qk_sqrt * 1.44269504 # 1/log(2)
210
+ qk_sqrt_imp = sm_scale_imp
211
+ qk_scale_imp = qk_sqrt_imp * 1.44269504
212
+ # load q: it will stay in SRAM throughout
213
+ q = tl.load(Q_block_ptr)
214
+ q_imp = tl.load(Q_imp_block_ptr)
215
+ # stage 1: off-band
216
+ # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
217
+ # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
218
+ # _attn_fwd_get_loss(q, q_imp, K_block_ptr, K_imp_block_ptr, V_block_ptr, mse_loss, start_m, qk_scale, qk_scale_imp, BLOCK_M, HEAD_DIM, BLOCK_N, STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5)
219
+ if STAGE & 1:
220
+ acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, q_imp, K_block_ptr, K_imp_block_ptr, V_block_ptr, #
221
+ mse_loss, #
222
+ start_m, qk_sqrt, qk_scale, qk_sqrt_imp, qk_scale_imp, #
223
+ BLOCK_M, HEAD_DIM, NUM_ELEMENTS, BLOCK_N, #
224
+ 4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 #
225
+ )
226
+ # stage 2: on-band
227
+ if STAGE & 2:
228
+ # barrier makes it easier for the compiler to schedule the
229
+ # two loops independently
230
+ acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, q_imp, K_block_ptr, K_imp_block_ptr, V_block_ptr, #
231
+ mse_loss, #
232
+ start_m, qk_sqrt, qk_scale, qk_sqrt_imp, qk_scale_imp, #
233
+ BLOCK_M, HEAD_DIM, NUM_ELEMENTS, BLOCK_N, #
234
+ 2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5 #
235
+ )
236
+ # epilogue
237
+ m_i += tl.math.log2(l_i)
238
+ acc = acc / l_i[:, None]
239
+ m_ptrs = M + off_hz * N_CTX + offs_m
240
+ tl.store(m_ptrs, m_i)
241
+ tl.store(O_block_ptr, acc.to(Out.type.element_ty))
242
+
243
+
244
+ @triton.jit
245
+ def _attn_bwd_preprocess(O, DO, #
246
+ Delta, #
247
+ Z, H, N_CTX, #
248
+ BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr #
249
+ ):
250
+ off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
251
+ off_hz = tl.program_id(1)
252
+ off_n = tl.arange(0, HEAD_DIM)
253
+ # load
254
+ o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :])
255
+ do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32)
256
+ delta = tl.sum(o * do, axis=1)
257
+ # write-back
258
+ tl.store(Delta + off_hz * N_CTX + off_m, delta)
259
+
260
+
261
+ # The main inner-loop logic for computing dK and dV.
262
+ @triton.jit
263
+ def _attn_bwd_dkdv(dk, dv, #
264
+ Q, k, v, sm_scale, #
265
+ DO, #
266
+ M, D, #
267
+ # shared by Q/K/V/DO.
268
+ stride_tok, stride_d, #
269
+ H, N_CTX, BLOCK_M1: tl.constexpr, #
270
+ BLOCK_N1: tl.constexpr, #
271
+ HEAD_DIM: tl.constexpr, #
272
+ # Filled in by the wrapper.
273
+ start_n, start_m, num_steps, #
274
+ MASK: tl.constexpr):
275
+ offs_m = start_m + tl.arange(0, BLOCK_M1)
276
+ offs_n = start_n + tl.arange(0, BLOCK_N1)
277
+ offs_k = tl.arange(0, HEAD_DIM)
278
+ qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
279
+ do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
280
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
281
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
282
+ curr_m = start_m
283
+ step_m = BLOCK_M1
284
+ for blk_idx in range(num_steps):
285
+ qT = tl.load(qT_ptrs)
286
+ # Load m before computing qk to reduce pipeline stall.
287
+ offs_m = curr_m + tl.arange(0, BLOCK_M1)
288
+ m = tl.load(M + offs_m)
289
+ qkT = tl.dot(k, qT)
290
+ pT = tl.math.exp2(qkT - m[None, :])
291
+ # Autoregressive masking.
292
+ if MASK:
293
+ mask = (offs_m[None, :] >= offs_n[:, None])
294
+ pT = tl.where(mask, pT, 0.0)
295
+ do = tl.load(do_ptrs)
296
+ # Compute dV.
297
+ ppT = pT
298
+ ppT = ppT.to(tl.float16)
299
+ dv += tl.dot(ppT, do)
300
+ # D (= delta) is pre-divided by ds_scale.
301
+ Di = tl.load(D + offs_m)
302
+ # Compute dP and dS.
303
+ dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
304
+ dsT = pT * (dpT - Di[None, :])
305
+ dsT = dsT.to(tl.float16)
306
+ dk += tl.dot(dsT, tl.trans(qT))
307
+ # Increment pointers.
308
+ curr_m += step_m
309
+ qT_ptrs += step_m * stride_tok
310
+ do_ptrs += step_m * stride_tok
311
+ return dk, dv
312
+
313
+
314
+ # the main inner-loop logic for computing dQ
315
+ @triton.jit
316
+ def _attn_bwd_dq(dq, q, K, V, #
317
+ do, m, D,
318
+ # shared by Q/K/V/DO.
319
+ stride_tok, stride_d, #
320
+ H, N_CTX, #
321
+ BLOCK_M2: tl.constexpr, #
322
+ BLOCK_N2: tl.constexpr, #
323
+ HEAD_DIM: tl.constexpr,
324
+ # Filled in by the wrapper.
325
+ start_m, start_n, num_steps, #
326
+ MASK: tl.constexpr):
327
+ offs_m = start_m + tl.arange(0, BLOCK_M2)
328
+ offs_n = start_n + tl.arange(0, BLOCK_N2)
329
+ offs_k = tl.arange(0, HEAD_DIM)
330
+ kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
331
+ vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
332
+ # D (= delta) is pre-divided by ds_scale.
333
+ Di = tl.load(D + offs_m)
334
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
335
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
336
+ curr_n = start_n
337
+ step_n = BLOCK_N2
338
+ for blk_idx in range(num_steps):
339
+ kT = tl.load(kT_ptrs)
340
+ vT = tl.load(vT_ptrs)
341
+ qk = tl.dot(q, kT)
342
+ p = tl.math.exp2(qk - m)
343
+ # Autoregressive masking.
344
+ if MASK:
345
+ offs_n = curr_n + tl.arange(0, BLOCK_N2)
346
+ mask = (offs_m[:, None] >= offs_n[None, :])
347
+ p = tl.where(mask, p, 0.0)
348
+ # Compute dP and dS.
349
+ dp = tl.dot(do, vT).to(tl.float32)
350
+ ds = p * (dp - Di[:, None])
351
+ ds = ds.to(tl.float16)
352
+ # Compute dQ.
353
+ # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
354
+ dq += tl.dot(ds, tl.trans(kT))
355
+ # Increment pointers.
356
+ curr_n += step_n
357
+ kT_ptrs += step_n * stride_tok
358
+ vT_ptrs += step_n * stride_tok
359
+ return dq
360
+
361
+ @triton.jit
362
+ def _attn_bwd_dk_imp(
363
+ Q, Q_imp, k, k_imp, sm_scale, sm_scale_imp, num_elements, #
364
+ DMSE, #
365
+ # shared by Q/K/V/DO.
366
+ stride_tok, stride_d, #
367
+ stride_tok_imp, stride_d_imp, #
368
+ H, N_CTX, BLOCK_M1: tl.constexpr, #
369
+ BLOCK_N1: tl.constexpr, #
370
+ HEAD_DIM: tl.constexpr, #
371
+ D_DASH: tl.constexpr, #
372
+ # Filled in by the wrapper.
373
+ start_n, start_m, num_steps, is_float16):
374
+ dk_imp = tl.zeros([BLOCK_N1, D_DASH], dtype=tl.float32)
375
+ offs_m = start_m + tl.arange(0, BLOCK_M1)
376
+ offs_k = tl.arange(0, HEAD_DIM)
377
+ offs_k_imp = tl.arange(0, D_DASH)
378
+ qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
379
+ qT_imp_ptrs = Q_imp + offs_m[None, :] * stride_tok_imp + offs_k_imp[:, None] * stride_d_imp
380
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
381
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
382
+ curr_m = start_m
383
+ step_m = BLOCK_M1
384
+ dmse_eval = tl.load(DMSE)
385
+ for blk_idx in range(num_steps):
386
+ qT = tl.load(qT_ptrs)
387
+ qT_imp = tl.load(qT_imp_ptrs)
388
+ # Load m before computing qk to reduce pipeline stall.
389
+ offs_m = curr_m + tl.arange(0, BLOCK_M1)
390
+ qkT = tl.dot(k, qT)
391
+ qkT_imp = tl.dot(k_imp, qT_imp)
392
+ diff = (qkT_imp * sm_scale_imp - qkT * sm_scale)
393
+ tmp = dmse_eval * 2.0 * (1 / num_elements) * sm_scale_imp
394
+ diff = diff.to(tl.float16)
395
+ qT_imp = qT_imp.to(tl.float16)
396
+ dk_imp += tmp * tl.dot(diff, tl.trans(qT_imp))
397
+ if not is_float16:
398
+ dk_imp = dk_imp.to(tl.float32)
399
+
400
+ # Increment pointers.
401
+ curr_m += step_m
402
+ qT_ptrs += step_m * stride_tok
403
+ qT_imp_ptrs += step_m * stride_tok_imp
404
+ return dk_imp
405
+
406
+ @triton.jit
407
+ def _attn_bwd_dq_imp(
408
+ q, q_imp, K, K_imp, sm_scale, sm_scale_imp, num_elements, #
409
+ DMSE, #
410
+ # shared by Q/K/V/DO.
411
+ stride_tok, stride_d, #
412
+ stride_tok_imp, stride_d_imp, #
413
+ H, N_CTX, BLOCK_M1: tl.constexpr, #
414
+ BLOCK_N1: tl.constexpr, #
415
+ HEAD_DIM: tl.constexpr, #
416
+ D_DASH: tl.constexpr, #
417
+ # Filled in by the wrapper.
418
+ start_n, start_m, num_steps, is_float16):
419
+ dq_imp = tl.zeros([BLOCK_N1, D_DASH], dtype=tl.float32)
420
+ offs_m = start_m + tl.arange(0, BLOCK_M1)
421
+ offs_k = tl.arange(0, HEAD_DIM)
422
+ offs_k_imp = tl.arange(0, D_DASH)
423
+ kT_ptrs = K + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
424
+ kT_imp_ptrs = K_imp + offs_m[None, :] * stride_tok_imp + offs_k_imp[:, None] * stride_d_imp
425
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
426
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
427
+ curr_m = start_m
428
+ step_m = BLOCK_M1
429
+ dmse_eval = tl.load(DMSE)
430
+ for blk_idx in range(num_steps):
431
+ kT = tl.load(kT_ptrs)
432
+ kT_imp = tl.load(kT_imp_ptrs)
433
+ # Load m before computing qk to reduce pipeline stall.
434
+ offs_m = curr_m + tl.arange(0, BLOCK_M1)
435
+ qkT = tl.dot(q, kT)
436
+ qkT_imp = tl.dot(q_imp, kT_imp)
437
+ diff = (qkT_imp * sm_scale_imp - qkT * sm_scale)
438
+ tmp = dmse_eval * 2.0 * (1 / num_elements) * sm_scale_imp
439
+ diff = diff.to(tl.float16)
440
+ kT_imp = kT_imp.to(tl.float16)
441
+ dq_imp += tmp * tl.dot(diff, tl.trans(kT_imp))
442
+ if not is_float16:
443
+ dq_imp = dq_imp.to(tl.float32)
444
+ # Increment pointers.
445
+ curr_m += step_m
446
+ kT_ptrs += step_m * stride_tok
447
+ kT_imp_ptrs += step_m * stride_tok_imp
448
+ return dq_imp
449
+
450
+
451
+ @triton.jit
452
+ def _attn_bwd(Q, Q_imp, K, K_imp, sm_scale, sm_scale_imp, num_elements, #
453
+ DQ_imp, DK_imp, DMSE, #
454
+ # shared by Q/K/V/DO.
455
+ stride_z, stride_h, stride_tok, stride_d, #
456
+ stride_z_imp, stride_h_imp, stride_tok_imp, stride_d_imp, #
457
+ H, N_CTX, #
458
+ BLOCK_M1: tl.constexpr, #
459
+ BLOCK_N1: tl.constexpr, #
460
+ HEAD_DIM: tl.constexpr,
461
+ D_DASH: tl.constexpr):
462
+ LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
463
+
464
+ bhid = tl.program_id(2)
465
+ off_chz = (bhid * N_CTX).to(tl.int64)
466
+ adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
467
+ adj_imp = (stride_h_imp * (bhid % H) + stride_z_imp * (bhid // H)).to(tl.int64)
468
+ pid = tl.program_id(0)
469
+
470
+ # offset pointers for batch/head
471
+ Q += adj
472
+ K += adj
473
+ Q_imp += adj_imp
474
+ K_imp += adj_imp
475
+ DQ_imp += adj_imp
476
+ DK_imp += adj_imp
477
+
478
+ # load scales
479
+ offs_k = tl.arange(0, HEAD_DIM)
480
+ offs_k_imp = tl.arange(0, D_DASH)
481
+
482
+ start_n = pid * BLOCK_N1
483
+ start_m = 0
484
+ offs_n = start_n + tl.arange(0, BLOCK_N1)
485
+ num_steps = N_CTX // BLOCK_M1
486
+ k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
487
+ k_imp = tl.load(K_imp + offs_n[:, None] * stride_tok_imp + offs_k_imp[None, :] * stride_d_imp)
488
+ dk_imp = _attn_bwd_dk_imp(
489
+ Q, Q_imp, k, k_imp, sm_scale, sm_scale_imp, num_elements, #
490
+ DMSE, #
491
+ stride_tok, stride_d, #
492
+ stride_tok_imp, stride_d_imp, #
493
+ H, N_CTX, BLOCK_M1, #
494
+ BLOCK_N1, HEAD_DIM, D_DASH, #
495
+ start_n, start_m, num_steps, k.dtype == tl.float16 #
496
+ )
497
+ dk_imp_ptrs = DK_imp + offs_n[:, None] * stride_tok_imp + offs_k_imp[None, :] * stride_d_imp
498
+ tl.store(dk_imp_ptrs, dk_imp)
499
+
500
+ start_n = pid * BLOCK_N1
501
+ start_m = 0
502
+ num_steps = N_CTX // BLOCK_M1
503
+ q = tl.load(Q + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
504
+ q_imp = tl.load(Q_imp + offs_n[:, None] * stride_tok_imp + offs_k_imp[None, :] * stride_d_imp)
505
+ dq_imp = _attn_bwd_dq_imp(
506
+ q, q_imp, K, K_imp, sm_scale, sm_scale_imp, num_elements, #
507
+ DMSE, #
508
+ stride_tok, stride_d, #
509
+ stride_tok_imp, stride_d_imp, #
510
+ H, N_CTX, BLOCK_M1, #
511
+ BLOCK_N1, HEAD_DIM, D_DASH, #
512
+ start_n, start_m, num_steps, q.dtype == tl.float16 #
513
+ )
514
+ dq_imp_ptrs = DQ_imp + offs_n[:, None] * stride_tok_imp + offs_k_imp[None, :] * stride_d_imp
515
+ tl.store(dq_imp_ptrs, dq_imp)
516
+
517
+
518
+ class _attention_mse_loss(torch.autograd.Function):
519
+
520
+ @staticmethod
521
+ def forward(ctx, q, k, v, q_importance, k_importance, causal):
522
+ # shape constraints
523
+ HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
524
+ # when v is in float8_e5m2 it is transposed.
525
+ HEAD_DIM_V = v.shape[-1]
526
+ assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
527
+ assert HEAD_DIM_K in {16, 32, 64, 128, 256}
528
+ D_DASH = q_importance.shape[-1]
529
+ assert D_DASH == k_importance.shape[-1], "q_importance and k_importance must have the same last dimension"
530
+ sm_scale = 1.0 / math.sqrt(HEAD_DIM_Q)
531
+ sm_scale_imp = 1.0 / math.sqrt(D_DASH)
532
+ o = torch.empty_like(q)
533
+ stage = 3 if causal else 1
534
+ extra_kern_args = {}
535
+ # Tuning for AMD target
536
+ if is_hip():
537
+ waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2
538
+ extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}
539
+
540
+ grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
541
+ M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
542
+ mse_loss = torch.zeros(1, device=q.device, dtype=torch.float32)
543
+ num_elements = (q.shape[0] * q.shape[1] * q.shape[2] * k.shape[2])
544
+ _attn_fwd[grid](
545
+ q, k, v, sm_scale, sm_scale_imp, M, o, #
546
+ q_importance, k_importance, mse_loss, #
547
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
548
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
549
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
550
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
551
+ q_importance.stride(0), q_importance.stride(1), q_importance.stride(2), q_importance.stride(3), #
552
+ k_importance.stride(0), k_importance.stride(1), k_importance.stride(2), k_importance.stride(3), #
553
+ Z=q.shape[0], H=q.shape[1], #
554
+ N_CTX=q.shape[2], #
555
+ HEAD_DIM=HEAD_DIM_K, #
556
+ D_DASH=D_DASH, #
557
+ NUM_ELEMENTS=num_elements, #
558
+ STAGE=stage, #
559
+ **fixed_config.kwargs)
560
+ ctx.save_for_backward(q, q_importance, k, k_importance, v, o, M)
561
+ ctx.grid = grid
562
+ ctx.sm_scale = sm_scale
563
+ ctx.sm_scale_imp = sm_scale_imp
564
+ ctx.HEAD_DIM = HEAD_DIM_K
565
+ ctx.D_DASH = D_DASH
566
+ ctx.num_elements = num_elements
567
+ ctx.causal = causal
568
+ return o, mse_loss
569
+
570
+ @staticmethod
571
+ def backward(ctx, do, dmse):
572
+ q, q_importance, k, k_importance, v, o, M = ctx.saved_tensors
573
+ do = do.contiguous()
574
+ assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
575
+ dq = torch.empty_like(q)
576
+ dq_imp = torch.empty_like(q_importance)
577
+ dk = torch.empty_like(k)
578
+ dk_imp = torch.empty_like(k_importance)
579
+ dv = torch.empty_like(v)
580
+ BATCH, N_HEAD, N_CTX = q.shape[:3]
581
+ NUM_WARPS, NUM_STAGES = 4, 5
582
+ BLOCK_M1, BLOCK_N1 = 16, 16
583
+ RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
584
+
585
+ # PRE_BLOCK = 128
586
+ # assert N_CTX % PRE_BLOCK == 0
587
+ # pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
588
+ # delta = torch.empty_like(M)
589
+ # _attn_bwd_preprocess[pre_grid](
590
+ # o, do, #
591
+ # delta, #
592
+ # BATCH, N_HEAD, N_CTX, #
593
+ # BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM #
594
+ # )
595
+ grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
596
+ _attn_bwd[grid](
597
+ q, q_importance, k, k_importance, ctx.sm_scale, ctx.sm_scale_imp, ctx.num_elements, #
598
+ dq_imp, dk_imp, dmse, #
599
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
600
+ q_importance.stride(0), q_importance.stride(1), q_importance.stride(2), q_importance.stride(3), #
601
+ N_HEAD, N_CTX, #
602
+ BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, #
603
+ HEAD_DIM=ctx.HEAD_DIM, #
604
+ D_DASH=ctx.D_DASH, #
605
+ num_warps=NUM_WARPS, #
606
+ num_stages=NUM_STAGES #
607
+ )
608
+
609
+ return None, None, None, dq_imp, dk_imp, None
610
+
611
+
612
+ attention_mse_loss = _attention_mse_loss.apply
613
+
614
+ from torch import nn
615
+ from torch.nn import MSELoss
616
+ import time
617
+ if __name__ == '__main__':
618
+ B, H, N_CTX, D = 1, 32, 512, 128
619
+ D_DASH = 32
620
+ causal = True
621
+ DTYPE = torch.float32
622
+ LOAD_WEIGHT = False
623
+ import os
624
+ if LOAD_WEIGHT and os.path.exists("export_params.pt"):
625
+ print("[Info] Detected export_params.pt, loading saved tensors...")
626
+ debug_tensors = torch.load("export_params.pt", map_location=DEVICE)
627
+
628
+ q = debug_tensors["q"].detach().clone().requires_grad_().contiguous()
629
+ k = debug_tensors["k"].detach().clone().requires_grad_().contiguous()
630
+ v = debug_tensors["v"].detach().clone().requires_grad_().contiguous()
631
+ q_importance = debug_tensors["q_importance"].detach().clone().requires_grad_().contiguous()
632
+ k_importance = debug_tensors["k_importance"].detach().clone().requires_grad_().contiguous()
633
+
634
+ print("[Success] Tensors loaded successfully!")
635
+ print(f"q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}")
636
+ print(f"q_importance.shape: {q_importance.shape}, k_importance.shape: {k_importance.shape}")
637
+ else:
638
+ print("[Info] No export_params.pt found, initializing random tensors...")
639
+ DTYPE = torch.float32
640
+
641
+ gain_q = math.sqrt(5.0 / D)
642
+ gain_k = math.sqrt(5.0 / D)
643
+ gain_v = math.sqrt(3.0 / D)
644
+
645
+ q = (torch.randn((B, H, N_CTX, D), dtype=DTYPE, device=DEVICE) * gain_q).requires_grad_()
646
+ k = (torch.randn((B, H, N_CTX, D), dtype=DTYPE, device=DEVICE) * gain_k).requires_grad_()
647
+ v = (torch.randn((B, H, N_CTX, D), dtype=DTYPE, device=DEVICE) * gain_v).requires_grad_()
648
+
649
+ gain_q_imp = math.sqrt(5.0 / D_DASH)
650
+ gain_k_imp = math.sqrt(5.0 / D_DASH)
651
+ q_importance = (torch.randn((B, H, N_CTX, D_DASH), dtype=DTYPE, device=DEVICE) * gain_q_imp + 0.1).requires_grad_()
652
+ k_importance = (torch.randn((B, H, N_CTX, D_DASH), dtype=DTYPE, device=DEVICE) * gain_k_imp - 0.1).requires_grad_()
653
+ print("[Info] Random tensors initialized.")
654
+
655
+ # warm up for the triton implementation
656
+ attn_output, mse_loss_triton = attention_mse_loss(q.to(torch.float16),
657
+ k.to(torch.float16),
658
+ v.to(torch.float16),
659
+ q_importance.to(torch.float16),
660
+ k_importance.to(torch.float16), True)
661
+
662
+
663
+ mse_loss_triton.backward()
664
+ q_importance.grad = None
665
+ k_importance.grad = None
666
+
667
+ # warm up for the original implementation
668
+ attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(D)
669
+ importance_mask = torch.matmul(q_importance, k_importance.transpose(-2, -1)) / math.sqrt(D_DASH) # [B, H, Lq, Lk]
670
+ mse_func = MSELoss(reduction='none')
671
+ mse_loss_original = mse_func(attn_weights, importance_mask)
672
+ mse_loss_original = mse_loss_original.mean()
673
+ if causal:
674
+ causal_mask = torch.triu(torch.ones(N_CTX, N_CTX, device=attn_weights.device), diagonal=1).bool()
675
+ attn_weights = attn_weights.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
676
+ attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).to(v.dtype)
677
+ attn_output_original = torch.matmul(attn_weights, v)
678
+ mse_loss_original.backward()
679
+ q_importance.grad = None
680
+ k_importance.grad = None
681
+
682
+ # triton implementation
683
+ tri_start = time.time()
684
+ att_output_triton, mse_loss_triton = attention_mse_loss(q, k, v, q_importance, k_importance, causal)
685
+ mse_loss_triton.backward()
686
+ tri_dq_imp, q_importance.grad = q_importance.grad.clone(), None
687
+ tri_dk_imp, k_importance.grad = k_importance.grad.clone(), None
688
+ tri_end = time.time()
689
+
690
+ # original implementation
691
+ ori_start = time.time()
692
+ attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(D)
693
+ importance_mask = torch.matmul(q_importance, k_importance.transpose(-2, -1)) / math.sqrt(D_DASH) # [B, H, Lq, Lk]
694
+ mse_func = MSELoss(reduction='none')
695
+ mse_loss_original = mse_func(attn_weights, importance_mask)
696
+ mse_loss_original = mse_loss_original.mean()
697
+ if causal:
699
+ causal_mask = torch.triu(torch.ones(N_CTX, N_CTX, device=attn_weights.device), diagonal=1).bool()
700
+ attn_weights = attn_weights.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
701
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(v.dtype)
702
+ attn_output_original = torch.matmul(attn_weights, v)
703
+
704
+ mse_loss_original.backward()
705
+ ref_dk_imp, k_importance.grad = k_importance.grad.clone(), None
706
+ ref_dq_imp, q_importance.grad = q_importance.grad.clone(), None
707
+ ori_end = time.time()
708
+
709
+ print(f'mse_loss_triton: {mse_loss_triton}')
710
+ print(f'mse_loss_original: {mse_loss_original}')
711
+ print(f'mean of ref_dk_imp: {ref_dk_imp.mean()}')
712
+ print(f'mean of tri_dk_imp: {tri_dk_imp.mean()}')
713
+ # print(f'error: {torch.abs(ref_dk_imp - tri_dk_imp).mean()}')
714
+ print(f'ref_dq_imp: {ref_dq_imp[0][0]}')
715
+ print(f'tri_dq_imp: {tri_dq_imp[0][0]}')
716
+ print(f'mean of ref_dq_imp: {ref_dq_imp.mean()}')
717
+ print(f'mean of tri_dq_imp: {tri_dq_imp.mean()}')
718
+ # print(f'error: {torch.abs(ref_dq_imp - tri_dq_imp).mean()}')
719
+
720
+ # assert torch.allclose(attn_output_original, att_output_triton, atol=1e-2, rtol=0), f'{attn_output_original.mean()} vs {att_output_triton.mean()}'
721
+ perc_diff = 100 * torch.abs(attn_output_original - att_output_triton).mean() / torch.abs(attn_output_original).mean()
722
+ print(f'passed test for attention output with {attn_output_original.mean()} vs {att_output_triton.mean()}, \t\t\tpercentage diff: {perc_diff}%')
723
+ # assert torch.allclose(mse_loss_triton, mse_loss_original, atol=1e-1, rtol=0), f'{mse_loss_triton} vs {mse_loss_original}'
724
+ perc_diff = 100 * torch.abs(mse_loss_triton - mse_loss_original) / torch.abs(mse_loss_original)
725
+ print(f'passed test for mse loss with {mse_loss_triton.item()} vs {mse_loss_original.item()}, \t\t\t\tpercentage diff: {perc_diff.item()}%')
726
+ # assert torch.allclose(ref_dk_imp, tri_dk_imp, atol=1e-1, rtol=0), f'{ref_dk_imp.mean()} vs {tri_dk_imp.mean()}'
727
+ perc_diff = 100 * torch.abs(ref_dk_imp - tri_dk_imp).mean() / torch.abs(ref_dk_imp).mean()
728
+ print(f'passed test for dk_imp with {ref_dk_imp.mean()} vs {tri_dk_imp.mean()}, \t\tpercentage diff: {perc_diff}%')
729
+ # assert torch.allclose(ref_dq_imp, tri_dq_imp, atol=1e-1, rtol=0), f'{ref_dq_imp.mean()} vs {tri_dq_imp.mean()}'
730
+ perc_diff = 100 * torch.abs(ref_dq_imp - tri_dq_imp).mean() / torch.abs(ref_dq_imp).mean()
731
+ print(f'passed test for dq_imp with {ref_dq_imp.mean()} vs {tri_dq_imp.mean()}, \t\tpercentage diff: {perc_diff}%')
732
+ print(f'original time: {ori_end - ori_start}')
733
+ print(f'triton time: {tri_end - tri_start}')
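Note: the timings above wrap asynchronous CUDA launches with time.time() and no synchronization, so they may not reflect actual kernel cost. A minimal sketch of a synchronized timing helper (a hypothetical `timed` wrapper, not part of this commit; assumes a CUDA device is available):

import time
import torch

def timed(fn, *args, **kwargs):
    # Synchronize so time.time() brackets the GPU work launched by fn.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    out = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return out, time.time() - start

# Usage sketch: (attn_out, loss), elapsed = timed(attention_mse_loss, q, k, v, q_importance, k_importance, causal)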
utils.py ADDED
@@ -0,0 +1,1521 @@
1
+ import os
2
+ import pdb
3
+ import copy
4
+ import math
5
+ import numpy as np
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple, Union
8
+ import gc
9
+ import matplotlib.pyplot as plt
10
+
11
+ import traceback
12
+ import torch
13
+ from torch import nn
14
+ import torch.utils.checkpoint
15
+ import torch.nn.functional as F
16
+ from torch.cuda.amp import autocast
17
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
18
+
19
+ from scipy.stats import spearmanr
20
+ from transformers.models.llama.configuration_llama import LlamaConfig
21
+ from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, LlamaAttention, LlamaRMSNorm, apply_rotary_pos_emb
22
+ from torch.utils.data import Dataset, DataLoader
23
+ import torch
24
+ from typing import Tuple
25
+ import torch
26
+ import numpy as np
27
+ import matplotlib.pyplot as plt
28
+ import seaborn as sns
29
+ import re
30
+ import time
31
+ import matplotlib.cm as cm
32
+ from scipy.spatial.distance import cosine
33
+ from tqdm import tqdm
34
+
35
+ class SlidingWindowCache:
36
+ def __init__(self, max_seq_len, sliding_window, device):
37
+ self.sliding_window = sliding_window
38
+ self.device = device
39
+ if sliding_window is None:
40
+ self.max_seq_len = 0
41
+ self.window = None
42
+ else:
43
+ self.max_seq_len = max_seq_len
44
+ self.window = self._create_window(self.max_seq_len)
45
+
46
+ def _create_window(self, seq_len):
47
+ idx = torch.arange(seq_len, device=self.device)
48
+ query = idx.unsqueeze(1) # [seq_len, 1]
49
+ key = idx.unsqueeze(0) # [1, seq_len]
50
+ win = (key >= (query - self.sliding_window + 1)) & (key <= query)
51
+ return win.unsqueeze(0).unsqueeze(0) # [1,1,seq_len,seq_len]
52
+
53
+ def get_window(self, q_len, key_len):
54
+ if self.sliding_window is None:
55
+ return None
56
+ req = max(q_len, key_len)
57
+ if req > self.max_seq_len:
58
+ self.max_seq_len = req
59
+ self.window = self._create_window(self.max_seq_len)
60
+ return self.window[:, :, :q_len, :key_len]
61
+
62
+ def enforce_sliding_window(mask_tensor, window):
63
+ if window is None:
64
+ return mask_tensor
65
+ return mask_tensor.masked_fill(window, 0.0)
66
+
67
+
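A small usage sketch for the two helpers above (hypothetical shapes; assumes utils.py is importable so SlidingWindowCache and enforce_sliding_window are in scope):

import torch

cache = SlidingWindowCache(max_seq_len=8, sliding_window=3, device="cpu")
win = cache.get_window(q_len=8, key_len=8)        # [1, 1, 8, 8] bool, True inside each query's trailing window
mask = torch.full((1, 1, 8, 8), float("-inf"))    # start from a fully masked additive mask
mask = enforce_sliding_window(mask, win)          # window positions are forced back to 0.0 (kept)
assert (mask[0, 0, 7, 5:] == 0).all()             # the last query keeps its 3 most recent keys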
68
+ def sanitize_filename(name):
69
+ return re.sub(r'[<>:"/\\|?*\'\[\]]', '_', name)
70
+
71
+ def args_to_name(args, timestamp=True):
72
+ args_dict = vars(args)
73
+ args_dict = args_dict.copy()
74
+ # remove longbench_datasets, task_list from args_dict
75
+ args_dict.pop("longbench_datasets", None)
76
+ args_dict.pop("task_list", None)
77
+ model_descr = list(args_dict.values())
78
+ # Split the model description into two parts
79
+ split_point = len(model_descr) // 2
80
+ folder_part = model_descr[:split_point]
81
+ file_part = model_descr[split_point:]
82
+ # Create a sanitized folder name from the first part
83
+ folder_name = "_".join([str(elem) for elem in folder_part])
84
+ folder_name = sanitize_filename(folder_name)
85
+ # Create a sanitized file name from the second part
86
+ file_name = "_".join([str(elem) for elem in file_part])
87
+ file_name = sanitize_filename(file_name)
88
+ # Add timestamp to file name
89
+ if timestamp:
90
+ timestamp = time.strftime("%Y%m%d-%H%M%S")
91
+ file_name = file_name + "_" + timestamp
92
+ return folder_name, file_name
93
+
94
+
95
+ def snapkv_mask_only(self, query_states, key_states, value_states, attention_mask=None):
96
+ """
97
+ 'Mask-only' version of SnapKV that does not gather/slice the actual key_states:
98
+ - If q_len < max_capacity_prompt, do nothing.
99
+ - Else, we compute the 'top prefix tokens' using the last window_size queries,
100
+ plus the last window_size tokens themselves.
101
+ - Then we create a single-step mask that is -inf for all other tokens.
102
+
103
+ We store that single-step mask in self.snapkv_cache so that
104
+ on the next decode step (q_len=1) we can re-apply it.
105
+ """
106
+ bsz, num_heads, q_len, head_dim = query_states.shape
107
+ # Ensure prefix-phase
108
+ assert key_states.shape[-2] == query_states.shape[-2], "Prefix shape mismatch"
109
+
110
+ # If no compression is needed, just return the normal outputs
111
+ if q_len < self.max_capacity_prompt:
112
+ return None # signals: no special mask built
113
+
114
+ # 1) Compute local attention (like SnapKV: last window_size queries vs entire prefix)
115
+ obs = self.window_size
116
+ if obs > q_len:
117
+ obs = q_len # if the prompt is shorter than window_size
118
+ attn_logits = torch.matmul(query_states[..., -obs:, :],
119
+ key_states.transpose(-2, -1)) / math.sqrt(head_dim)
120
+ # shape [bsz, num_heads, obs, kv_seq_len]
121
+
122
+ # 2) Build a local triangular mask of shape (obs, obs) for the last window_size queries
123
+ mask = torch.full((obs, obs), float('-inf'), device=attn_logits.device, dtype=attn_logits.dtype)
124
+ idxs = torch.arange(obs, device=mask.device)
125
+ mask.masked_fill_(idxs < idxs.unsqueeze(-1), 0) # lower-tri (including diagonal)=0, above diag=-inf
126
+ local_mask = mask.unsqueeze(0).unsqueeze(0) # [1,1,obs,obs]
127
+
128
+ # Apply it to the last obs block in attn_logits
129
+ attn_logits[:, :, -obs:, -obs:] += local_mask
130
+
131
+ # 3) Softmax
132
+ attn_probs = F.softmax(attn_logits, dim=-1, dtype=torch.float32).to(query_states.dtype)
133
+ # shape [bsz, num_heads, obs, kv_seq_len]
134
+
135
+ # 4) Sum across the obs dimension => [bsz, num_heads, kv_seq_len]
136
+ attn_sum = attn_probs.sum(dim=-2)
137
+
138
+ # 5) Optional pooling => must pass [N, C, L] to max_pool1d/avg_pool1d
139
+ # attn_sum is shape [bsz, num_heads, kv_seq_len]. We can flatten bsz*num_heads
140
+ bnh = bsz * num_heads
141
+ L = key_states.shape[-2] # kv_seq_len
142
+ x = attn_sum.view(bnh, 1, L) # [bnh, 1, kv_seq_len]
143
+
144
+ if self.pooling == 'avgpool':
145
+ pooled = F.avg_pool1d(
146
+ x,
147
+ kernel_size=self.kernel_size,
148
+ stride=1,
149
+ padding=self.kernel_size // 2
150
+ )
151
+ elif self.pooling == 'maxpool':
152
+ pooled = F.max_pool1d(
153
+ x,
154
+ kernel_size=self.kernel_size,
155
+ stride=1,
156
+ padding=self.kernel_size // 2
157
+ )
158
+ else:
159
+ raise ValueError("Unsupported pooling method")
160
+
161
+ # Now pooled is shape [bnh, 1, L']
162
+ # Usually, L' = L if stride=1, but let's just keep it so it lines up with kv_seq_len
163
+ pooled = pooled.view(bsz, num_heads, -1) # [bsz, num_heads, kv_seq_len]
164
+
165
+ # 6) topk
166
+ top_prefix_to_keep = self.max_capacity_prompt - obs
167
+ prefix_indices = pooled.topk(top_prefix_to_keep, dim=-1).indices # [bsz, num_heads, top_prefix_to_keep]
168
+
169
+ # 7) Build single-step mask => shape [bsz, num_heads, 1, kv_seq_len]
170
+ single_mask = torch.full(
171
+ (bsz, num_heads, 1, L),
172
+ float('-inf'),
173
+ device=query_states.device,
174
+ dtype=query_states.dtype
175
+ )
176
+
177
+ # unmask the top prefix positions
178
+ row_idx = torch.arange(bsz, device=query_states.device).view(bsz, 1, 1)
179
+ head_idx = torch.arange(num_heads, device=query_states.device).view(1, num_heads, 1)
180
+ row_idx = row_idx.expand(bsz, num_heads, top_prefix_to_keep)
181
+ head_idx = head_idx.expand(bsz, num_heads, top_prefix_to_keep)
182
+ single_mask[row_idx, head_idx, 0, prefix_indices] = 0.0
183
+
184
+ # unmask the last obs tokens
185
+ single_mask[:, :, 0, -obs:] = 0.0
186
+
187
+ return single_mask
188
+
189
+ def calculate_effective_sparsity(final_mask, attention_mask):
190
+ true_mask = final_mask + attention_mask
191
+ num_deact = true_mask.bool().sum(dim=-1) # Number of tokens disabled.
192
+ causally_deact = (attention_mask.bool()).sum(dim=-1).expand_as(num_deact) # Number of tokens disabled causally anyway
193
+ additional_deact = (num_deact - causally_deact)
194
+ num_active = (~attention_mask.bool()).sum(dim=-1).expand_as(num_deact) # Number of tokens active at this position if zero-sparsity
195
+ effective_sparsity = 100 * (additional_deact.float() / num_active.float()).mean().item()
196
+ return effective_sparsity
197
+
198
+ # def sorted_index_to_mask(sorted_indices, attention_mask, min_sparse_index, bsz, q_len, key_len, sparse_aggression):
199
+ # device = sorted_indices.device
200
+ # dtype = sorted_indices.dtype
201
+ # # Step 1: Calculate the number of keys to keep for each query position
202
+ # # K = ceil((query_position + 1) * sparse_aggression)
203
+ # # Shape: [1, 1, Lq, 1]
204
+ # query_positions = torch.arange(q_len, device=device).view(1, 1, q_len, 1).float() + 1.0 # 1-based indexing
205
+ # K = torch.ceil(query_positions * sparse_aggression).to(dtype=torch.long) # [1, 1, Lq, 1]
206
+ # # Ensure K does not exceed key_len
207
+ # K = torch.clamp(K, max=key_len)
208
+ # # Step 2: Expand attention_mask to match sorted_indices shape
209
+ # # attention_mask: [1, 1, Lq, Lk] -> [B, H, Lq, Lk]
210
+ # attention_mask_expanded = attention_mask.expand(bsz, -1, -1, -1) # [B, 1, Lq, Lk]
211
+ # attention_mask_expanded = attention_mask_expanded.expand(-1, sorted_indices.size(1), -1, -1) # [B, H, Lq, Lk]
212
+ # attention_mask_expanded = (~attention_mask_expanded.bool()).int()
213
+ # # Step 3: Gather attention_mask based on sorted_indices
214
+ # # This rearranges the attention_mask to follow the sorted order of keys
215
+ # gathered_mask = torch.gather(attention_mask_expanded, dim=-1, index=sorted_indices) # [B, H, Lq, Lk]
216
+ # # Step 4: Convert gathered_mask to float for cumulative sum
217
+ # gathered_mask_float = gathered_mask.float()
218
+ # # Step 5: Compute cumulative sum along the sorted key dimension
219
+ # cum_sum = torch.cumsum(gathered_mask_float, dim=-1) # [B, H, Lq, Lk]
220
+ # # Step 6: Create a mask where cumulative sum <= K_i
221
+ # # K has shape [1, 1, Lq, 1], expand it to [B, H, Lq, Lk] for comparison
222
+ # K_broadcast = K.view(1, 1, q_len, 1).expand(bsz, sorted_indices.size(1), q_len, key_len) # [B, H, Lq, Lk]
223
+ # selected_mask = cum_sum <= K_broadcast # [B, H, Lq, Lk], boolean
224
+ # # Step 7: Initialize mask_tensor with -inf
225
+ # mask_tensor = torch.full_like(attention_mask_expanded.to(torch.float32), float('-inf')) # [B, H, Lq, Lk]
226
+ # # Step 8: Prepare values to scatter: 0 where selected_mask is True, else -inf
227
+ # # This ensures that only the selected positions are allowed (0), others remain disallowed (-inf)
228
+ # scatter_values = torch.zeros_like(gathered_mask_float)
229
+ # scatter_values = scatter_values.masked_fill(~selected_mask, float('-inf')) # [B, H, Lq, Lk]
230
+ # # Step 9: Scatter the values back to the original key positions
231
+ # # For each sorted key position, place the corresponding scatter_value at the original key index
232
+ # mask_tensor.scatter_(-1, sorted_indices, scatter_values)
233
+ # # Step 10: Ensure mask_tensor has mask_tensor[:, :, :, :min_sparse_index] = 0
234
+ # mask_tensor[:, :, :, :min_sparse_index] = 0.0
235
+ # return mask_tensor
236
+
237
+ def sorted_index_to_mask(
238
+ sorted_indices,
239
+ attention_mask,
240
+ min_sparse_index,
241
+ bsz,
242
+ q_len,
243
+ key_len,
244
+ sparse_aggression,
245
+ sliding_window=None
246
+ ):
247
+ """
248
+ sorted_indices: [B, H, q_len, key_len]
249
+ attention_mask: [1, 1, q_len, key_len]; nonzero/True marks positions already masked out (e.g. causally), zero/False marks positions that may be kept
250
+ min_sparse_index: guaranteed front region to keep
251
+ sliding_window: guaranteed trailing region (for each query) to keep
252
+ sparse_aggression: float in [0,1], fraction of keys each query may keep (budget K = ceil(position * sparse_aggression))
253
+ """
254
+ device = sorted_indices.device
255
+ dtype = sorted_indices.dtype
256
+
257
+ # Step 1: Compute base K
258
+ if q_len == 1:
259
+ query_positions = torch.arange(q_len, device=device).view(1, 1, q_len, 1).float()
260
+ query_positions[0] = key_len + 1
261
+ else:
262
+ query_positions = torch.arange(q_len, device=device).view(1, 1, q_len, 1).float() + 1.0
263
+ K_original = torch.ceil(query_positions * sparse_aggression).long() # [1,1,q_len,1]
264
+ K_original = torch.clamp(K_original, max=key_len)
265
+
266
+ # Step 1b: Incorporate guaranteed region
267
+ guaranteed = min_sparse_index
268
+ if sliding_window is not None:
269
+ guaranteed += sliding_window
270
+ # Subtract guaranteed from the original K
271
+ K_adjusted = K_original - guaranteed
272
+ # Ensure K_adjusted is at least 0
273
+ K_adjusted = torch.clamp(K_adjusted, min=0, max=key_len)
274
+
275
+ # Step 2: Expand attention_mask to [B,H,q_len,key_len]
276
+ attention_mask_expanded = attention_mask.expand(bsz, -1, -1, -1)
277
+ attention_mask_expanded = attention_mask_expanded.expand(-1, sorted_indices.size(1), -1, -1)
278
+ # Convert True -> 1, False -> 0
279
+ attention_mask_expanded = (~attention_mask_expanded.bool()).int()
280
+
281
+ # Step 3: Gather (reorder) mask by sorted_indices
282
+ gathered_mask = torch.gather(attention_mask_expanded, dim=-1, index=sorted_indices)
283
+
284
+ # Step 4: cumsum along sorted dimension
285
+ gathered_mask_float = gathered_mask.float()
286
+ cum_sum = torch.cumsum(gathered_mask_float, dim=-1) # [B,H,q_len,key_len]
287
+
288
+ # Step 5: Compare cumsum <= K_adjusted
289
+ # Expand K_adjusted to [B,H,q_len,key_len] for broadcast
290
+ K_broadcast = K_adjusted.view(1, 1, q_len, 1).expand_as(cum_sum)
291
+ selected_mask = (cum_sum <= K_broadcast)
292
+
293
+ # Step 6: Prepare final mask_tensor with -inf by default
294
+ mask_tensor = torch.full_like(attention_mask_expanded.float(), float('-inf'))
295
+
296
+ # Step 7: Scatter 0 where selected, -inf otherwise
297
+ scatter_values = torch.zeros_like(gathered_mask_float)
298
+ scatter_values = scatter_values.masked_fill(~selected_mask, float('-inf'))
299
+ mask_tensor.scatter_(-1, sorted_indices, scatter_values)
300
+
301
+ # Step 8: Force the guaranteed front region unmasked
302
+ mask_tensor[:, :, :, :min_sparse_index] = 0.0
303
+
304
+ # We do NOT forcibly unmask the trailing `sliding_window` here,
305
+ # because we typically do it with a separate function that
306
+ # ensures the last `sliding_window` positions are unmasked for each query.
307
+ # Replace with self.sliding_window where referenced
308
+ # Where not referenced, reduce budget in calculation.
309
+
310
+ return mask_tensor
311
+
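To make the budget logic concrete, a toy example (a sketch; the shapes, random scores and 50% budget are illustrative only):

import torch

B, H, q_len, key_len = 1, 1, 4, 4
scores = torch.rand(B, H, q_len, key_len)
sorted_indices = scores.argsort(dim=-1, descending=True)   # most important keys first
attention_mask = torch.triu(torch.full((q_len, key_len), float("-inf")), diagonal=1)[None, None]
mask = sorted_index_to_mask(sorted_indices, attention_mask, min_sparse_index=1,
                            bsz=B, q_len=q_len, key_len=key_len, sparse_aggression=0.5)
# Each query row i keeps roughly ceil((i + 1) * 0.5) of its causally visible keys (0.0)
# and sets the rest to -inf; the first min_sparse_index positions are always kept.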
312
+ def threshold_to_mask(unadj_importance_mask, perhead_thresholds, min_sparse_index, bsz, q_len, key_len):
313
+ """
314
+ Create a mask tensor based on per-head thresholds, setting values below the threshold to -inf.
315
+
316
+ Args:
317
+ - unadj_importance_mask: torch.Tensor of shape [B, H, Lq, Lk].
318
+ - perhead_thresholds: torch.Tensor of shape [H], per-head thresholds.
319
+ - min_sparse_index: Minimum index for sparsity; values below this index will not be masked.
320
+ - bsz: Batch size.
321
+ - q_len: Query length (Lq).
322
+ - key_len: Key length (Lk).
323
+
324
+ Returns:
325
+ - mask_tensor: torch.Tensor of shape [B, H, Lq, Lk], with values below threshold as -inf.
326
+ """
327
+ # Ensure perhead_thresholds is in the correct shape for broadcasting
328
+ thresholds_broadcast = perhead_thresholds.view(1, -1, 1, 1) # [1, H, 1, 1]
329
+
330
+ # Compare unadj_importance_mask with thresholds to create a mask
331
+ mask_tensor = torch.where(
332
+ unadj_importance_mask >= thresholds_broadcast,
333
+ torch.zeros_like(unadj_importance_mask),
334
+ torch.full_like(unadj_importance_mask, float('-inf'))
335
+ ) # [B, H, Lq, Lk]
336
+
337
+ # Ensure mask_tensor has mask_tensor[:, :, :, :min_sparse_index] = 0
338
+ mask_tensor[:, :, :, :min_sparse_index] = 0.0
339
+
340
+ return mask_tensor
341
+
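A minimal illustration of the per-head thresholding (a sketch; the tensors and thresholds are random placeholders):

import torch

B, H, Lq, Lk = 1, 2, 3, 5
scores = torch.rand(B, H, Lq, Lk)
thresholds = torch.tensor([0.3, 0.7])                 # one threshold per head
m = threshold_to_mask(scores, thresholds, min_sparse_index=1, bsz=B, q_len=Lq, key_len=Lk)
# m is 0.0 wherever scores >= that head's threshold (and always in the first column), -inf elsewhere.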
342
+ def calculate_hit_metrics(estimated_importance: torch.Tensor,
343
+ true_importance: torch.Tensor,
344
+ top_k_ratio: float = 0.5) -> Tuple[float, float, float]:
345
+ """
346
+ Calculate hit accuracy, mean, and max rank correlation between estimated and true importance tensors.
347
+ We compute metrics along the last dimension of the input tensors.
348
+
349
+ Shapes:
350
+ - 4D token-importance: [B, H, L, L]. We slice the last query (index -1) => [B, H, L].
351
+ - 3D head-importance: [B, L, H]. We use all of it as-is => [B, L, H].
352
+
353
+ Args:
354
+ estimated_importance (torch.Tensor): [B, H, L, L] or [B, L, H]
355
+ true_importance (torch.Tensor): [B, H, L, L] or [B, L, H]
356
+ top_k_ratio (float): Fraction of top-k elements to consider for hit accuracy (default=0.5).
357
+
358
+ Returns:
359
+ (hit_accuracy, mean_corr, max_corr):
360
+ hit_accuracy (float): Intersection ratio of top-k sets (0..1).
361
+ mean_corr (float): Average Spearman rank correlation over all [B, ...].
362
+ max_corr (float): Maximum Spearman rank correlation among all [B, ...].
363
+ """
364
+
365
+ # 1) Standardize shapes so the last dimension is what we rank over.
366
+ if estimated_importance.dim() == 4:
367
+ # Shape is [B, H, L, L] => slice to keep only the last query => [B, H, L]
368
+ estimated_importance = estimated_importance[:, :, -1, :]
369
+ true_importance = true_importance[:, :, -1, :]
370
+ # after slicing: [B, H, L]
371
+ # For intersection denominator => top_k * B * H
372
+ denom_for_hits = estimated_importance.size(0) * estimated_importance.size(1)
373
+ elif estimated_importance.dim() == 3:
374
+ # Shape is [B, L, H], the last dimension is H
375
+ # For intersection denominator => top_k * B * L
376
+ denom_for_hits = estimated_importance.size(0) * estimated_importance.size(1)
377
+ else:
378
+ raise ValueError("Tensors must be either 4D [B,H,L,L] or 3D [B,L,H].")
379
+
380
+ # 2) Compute Spearman rank correlation along the last dimension.
381
+ # Sort indices in descending order => get 'ranks' for correlation.
382
+ _, sorted_esti = torch.sort(estimated_importance, dim=-1, descending=True)
383
+ _, sorted_true = torch.sort(true_importance, dim=-1, descending=True)
384
+
385
+ # Spearman's rho = 1 - 6 * sum(d^2) / [n*(n^2 - 1)]
386
+ n = sorted_esti.shape[-1]
387
+ d = sorted_esti.float() - sorted_true.float()
388
+ d_squared = d ** 2
389
+ sum_d_squared = d_squared.sum(dim=-1)
390
+ rank_corr = 1 - (6 * sum_d_squared) / (n * (n**2 - 1)) # shape: [B,H] or [B,L]
391
+
392
+ mean_corr = rank_corr.mean().item()
393
+ max_corr = rank_corr.max().item()
394
+
395
+ # 3) Compute top-k hit accuracy along the last dimension.
396
+ top_k = max(1, int(n * top_k_ratio))
397
+ _, top_esti_indices = torch.topk(estimated_importance, top_k, dim=-1)
398
+ _, top_true_indices = torch.topk(true_importance, top_k, dim=-1)
399
+
400
+ # top_esti_indices => [B,H,top_k] or [B,L,top_k]
401
+ # top_true_indices => [B,H,top_k] or [B,L,top_k]
402
+ # matches => [B,H,top_k,top_k] or [B,L,top_k,top_k]
403
+ matches = (top_esti_indices.unsqueeze(-1) == top_true_indices.unsqueeze(-2))
404
+ intersection = matches.any(dim=-1).sum(dim=-1) # => [B,H] or [B,L]
405
+
406
+ # Each [B,H] or [B,L] element can have at most 'top_k' matches, so total is top_k * denom_for_hits.
407
+ total_possible = top_k * denom_for_hits
408
+ hit_accuracy = intersection.sum().item() / total_possible # => 0..1
409
+
410
+ return hit_accuracy, mean_corr, max_corr
411
+
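A quick self-check of the metric (a sketch; random scores, with identical estimated and true importance):

import torch

est = torch.rand(2, 4, 16, 16)                        # [B, H, L, L] token-importance scores
hit, mean_rho, max_rho = calculate_hit_metrics(est, est.clone(), top_k_ratio=0.5)
# With identical inputs, hit accuracy and both correlations come out as 1.0;
# perturbing one input should lower all three numbers.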
412
+ def plot_thresholds(threshold_tensor, true_threshold_tensor, fpath_base, fpath_specific):
413
+ """
414
+ Plots mean and error regions for random layers and heads, showing threshold changes across decode steps.
415
+
416
+ Args:
417
+ - threshold_tensor: torch.Tensor of shape [163, 31, 32, 1024].
418
+ - true_threshold_tensor: torch.Tensor of shape [163, 31, 32, 1024].
419
+ """
420
+ def create_plot(tensor, title, filename):
421
+ """
422
+ Helper function to generate the plot.
423
+ """
424
+ # Choose 3 random layers
425
+ layers = np.random.choice(tensor.shape[1], 3, replace=False)
426
+ # layers = np.array([0, 15, 30])
427
+
428
+ # Create subplots
429
+ fig, axs = plt.subplots(1, 3, figsize=(18, 5), sharey=True)
430
+ x = np.arange(tensor.shape[3]) # Decode steps (1024)
431
+
432
+ for i, layer in enumerate(layers):
433
+ # Choose 5 random heads for this layer
434
+ heads = np.random.choice(tensor.shape[2], 5, replace=False)
435
+
436
+ for head in heads:
437
+ try:
438
+ # Extract data for the selected layer and head
439
+ data = tensor[:, layer, head, :].numpy() # Shape [163, 1024]
440
+
441
+ # Compute mean and standard deviation across samples (dim=0)
442
+ mean = np.mean(data, axis=0)
443
+ std = np.std(data, axis=0)
444
+
445
+ # Plot mean and shaded error region for the head
446
+ axs[i].plot(x, mean, label=f"Head {head}")
447
+ axs[i].fill_between(x, mean - std, mean + std, alpha=0.3)
448
+ except:
449
+ import pdb; pdb.set_trace()
450
+
451
+ # Customize subplot
452
+ axs[i].set_title(f"Layer {layer}")
453
+ axs[i].set_xlabel("Decode Step")
454
+ axs[i].grid(True)
455
+ axs[i].legend(fontsize=8) # Adjust legend size for multiple heads
456
+
457
+ # Common Y-axis label and adjustments
458
+ axs[0].set_ylabel("Threshold")
459
+ fig.suptitle(title, fontsize=16)
460
+ plt.tight_layout(rect=[0, 0.03, 1, 0.95])
461
+
462
+ # Save the plot
463
+ plt.savefig(filename)
464
+ plt.close()
465
+
466
+ def compute_mean_threshold(tensor):
467
+ """
468
+ Computes the mean threshold value for each head and layer, excluding the first 32 tokens.
469
+ """
470
+ # Exclude the first 32 tokens (dimension 1024)
471
+ tensor_excluded = tensor[:, :, :, 32:] # Shape [163, 31, 32, 992]
472
+
473
+ # Compute the mean along the first (samples) and last (remaining tokens) dimensions
474
+ mean_threshold = tensor_excluded.mean(dim=(0, -1)) # Shape [31, 32]
475
+
476
+ return mean_threshold
477
+
478
+ # create folder fpath_base if it does not exist
479
+ if not os.path.exists(f"threshold_plots"):
480
+ os.makedirs(f"threshold_plots")
481
+ if not os.path.exists(f"threshold_plots/{fpath_base}"):
482
+ os.makedirs(f"threshold_plots/{fpath_base}")
483
+ # Plot for threshold_tensor
484
+ create_plot(threshold_tensor, "Post-Attention Thresholds", f"threshold_plots/{fpath_base}/{fpath_specific}_postattn_threshold.pdf")
485
+
486
+ # Plot for true_threshold_tensor
487
+ create_plot(true_threshold_tensor, "Predicted Pre-SM Thresholds", f"threshold_plots/{fpath_base}/{fpath_specific}_pred_presm_threshold.pdf")
488
+
489
+ # Compute mean thresholds
490
+ mean_threshold_postattn = compute_mean_threshold(threshold_tensor)
491
+ mean_threshold_predpresm = compute_mean_threshold(true_threshold_tensor)
492
+
493
+ return mean_threshold_postattn, mean_threshold_predpresm
494
+
495
+
496
+
497
+ # def plot_thresholds(threshold_tensor, true_threshold_tensor):
498
+ # """
499
+ # Plots mean and error regions for random layers and heads, showing threshold changes across decode steps.
500
+
501
+ # Args:
502
+ # - threshold_tensor: torch.Tensor of shape [163, 31, 32, 1024].
503
+ # - true_threshold_tensor: torch.Tensor of shape [163, 31, 32, 1024].
504
+ # """
505
+ # def create_plot(tensor, title, filename):
506
+ # """
507
+ # Helper function to generate the plot.
508
+ # """
509
+ # # Choose 3 random layers and heads
510
+ # layers = np.random.choice(tensor.shape[1], 3, replace=False)
511
+ # heads = np.random.choice(tensor.shape[2], 3, replace=False)
512
+
513
+ # # Create subplots
514
+ # fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True)
515
+ # x = np.arange(tensor.shape[3]) # Decode steps (1024)
516
+
517
+ # for i, (layer, head) in enumerate(zip(layers, heads)):
518
+ # # Extract data for the selected layer and head
519
+ # data = tensor[:, layer, head, :].numpy() # Shape [163, 1024]
520
+
521
+ # # Compute mean and standard deviation across samples (dim=0)
522
+ # mean = np.mean(data, axis=0)
523
+ # std = np.std(data, axis=0)
524
+
525
+ # # Plot with shaded error regions
526
+ # axs[i].plot(x, mean, label=f"Layer {layer}, Head {head}")
527
+ # axs[i].fill_between(x, mean - std, mean + std, alpha=0.3)
528
+ # axs[i].set_title(f"Layer {layer}, Head {head}")
529
+ # axs[i].set_xlabel("Decode Step")
530
+ # axs[i].grid(True)
531
+
532
+ # # Common Y-axis label and adjustments
533
+ # axs[0].set_ylabel("Threshold")
534
+ # fig.suptitle(title, fontsize=16)
535
+ # plt.tight_layout(rect=[0, 0.03, 1, 0.95])
536
+
537
+ # # Save the plot
538
+ # plt.savefig(filename)
539
+ # plt.close()
540
+
541
+ # def compute_mean_threshold(tensor):
542
+ # """
543
+ # Computes the mean threshold value for each head and layer, excluding the first 32 tokens.
544
+ # """
545
+ # # Exclude the first 32 tokens (dimension 1024)
546
+ # tensor_excluded = tensor[:, :, :, 32:] # Shape [163, 31, 32, 992]
547
+
548
+ # # Compute the mean along the first (samples) and last (remaining tokens) dimensions
549
+ # mean_threshold = tensor_excluded.mean(dim=(0, -1)) # Shape [31, 32]
550
+
551
+ # return mean_threshold
552
+
553
+ # # Plot for threshold_tensor
554
+ # create_plot(threshold_tensor, "Post-Attention Thresholds", "postattn_threshold.pdf")
555
+
556
+ # # Plot for true_threshold_tensor
557
+ # create_plot(true_threshold_tensor, "Predicted Pre-SM Thresholds", "pred_presm_threshold.pdf")
558
+
559
+
560
+ # # Compute mean thresholds
561
+ # mean_threshold_postattn = compute_mean_threshold(threshold_tensor)
562
+ # mean_threshold_predpresm = compute_mean_threshold(true_threshold_tensor)
563
+
564
+ # return mean_threshold_postattn, mean_threshold_predpresm
565
+
566
+
567
+ # def calculate_hit_metrics(estimated_importance: torch.Tensor,
568
+ # true_importance: torch.Tensor,
569
+ # top_k_ratio: float = 0.5) -> Tuple[float, float, float]:
570
+ # """
571
+ # Calculate hit accuracy, mean, and max rank correlation between estimated and true importance tensors.
572
+
573
+ # Args:
574
+ # estimated_importance (torch.Tensor): Tensor of estimated importance values [B, H, L, L] OR [B, L, H] (Token / Head Prediction)
575
+ # true_importance (torch.Tensor): Tensor of true importance values [B, H, L, L] OR [B, L, H] (Token / Head Prediction)
576
+ # top_k_ratio (float): Fraction of top-k elements to consider for hit accuracy (default: 0.5).
577
+
578
+ # Returns:
579
+ # Tuple[float, float, float]: Hit accuracy, mean rank correlation, max rank correlation.
580
+ # """
581
+ # if len(estimated_importance.shape) == 4:
582
+ # # Only take the final, fully unmasked next-word prediction problem.
583
+ # import pdb; pdb.set_trace()
584
+ # estimated_importance = estimated_importance[:, :, -1, :].unsqueeze(-2)
585
+ # true_importance = true_importance[:, :, -1, :].unsqueeze(-2)
586
+ # estisize = estimated_importance.size(0) * estimated_importance.size(1) * estimated_importance.size(2)
587
+ # elif len(estimated_importance.shape) == 3:
588
+ # # Take all next-word head magnitudes, but note that its not L2
589
+ # import pdb; pdb.set_trace()
590
+ # estisize = estimated_importance.size(0) * estimated_importance.size(1)
591
+ # else:
592
+ # raise ValueError("Invalid shape for estimated_importance tensor. Must be 3 or 4 dimensional.")
593
+ # # Sort indices to get ranks
594
+ # _, sorted_esti_indices = torch.sort(estimated_importance, dim=-1, descending=True)
595
+ # _, sorted_true_indices = torch.sort(true_importance, dim=-1, descending=True)
596
+
597
+ # # Compute rank correlation
598
+ # n = sorted_esti_indices.shape[-1]
599
+ # d = sorted_esti_indices - sorted_true_indices
600
+ # d_squared = d ** 2
601
+ # sum_d_squared = torch.sum(d_squared, dim=-1)
602
+ # rank_correlation = 1 - (6 * sum_d_squared) / (n * (n**2 - 1))
603
+
604
+ # # Compute top-k hit accuracy
605
+ # top_k = max(1, int(n * top_k_ratio))
606
+ # _, top_esti_indices = torch.topk(estimated_importance, top_k, dim=-1)
607
+ # _, top_true_indices = torch.topk(true_importance, top_k, dim=-1)
608
+ # intersection = (top_esti_indices.unsqueeze(-1) == top_true_indices.unsqueeze(-2)).any(dim=-1).sum(dim=-1)
609
+ # hit_accuracy = intersection.sum().item() / (top_k * estisize)
610
+
611
+ # # Return metrics
612
+ # mean_rank_corr = rank_correlation.mean().item()
613
+ # max_rank_corr = rank_correlation.max().item()
614
+ # return hit_accuracy, mean_rank_corr, max_rank_corr
615
+
616
+
617
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
618
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
619
+
620
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, config=None):
621
+ self.scaling_factor = scaling_factor
622
+ super().__init__(config)
623
+
624
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
625
+ self.max_seq_len_cached = seq_len
626
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
627
+ t = t / self.scaling_factor
628
+
629
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
630
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
631
+ emb = torch.cat((freqs, freqs), dim=-1)
632
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
633
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
634
+
635
+
636
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
637
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
638
+
639
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, config=None):
640
+ self.scaling_factor = scaling_factor
641
+ super().__init__(config)
642
+
643
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
644
+ self.max_seq_len_cached = seq_len
645
+
646
+ if seq_len > self.max_position_embeddings:
647
+ base = self.base * (
648
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
649
+ ) ** (self.dim / (self.dim - 2))
650
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
651
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
652
+
653
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
654
+
655
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
656
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
657
+ emb = torch.cat((freqs, freqs), dim=-1)
658
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
659
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
660
+
661
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
662
+ """
663
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
664
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
665
+ """
666
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
667
+ if n_rep == 1:
668
+ return hidden_states
669
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
670
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
671
+
672
+
673
+ class FlattenedDataset(Dataset):
674
+ def __init__(self, dataset, max_seq_len, max_repeat_fraction=0.7):
675
+ self.max_seq_len = max_seq_len
676
+ self.max_repeat_fraction = max_repeat_fraction
677
+
678
+ # Extract and flatten the input_ids column
679
+ all_tokens = torch.cat([torch.tensor(ids) for ids in dataset["input_ids"]], dim=0)
680
+
681
+ # Calculate the number of full chunks
682
+ num_full_chunks = len(all_tokens) // self.max_seq_len
683
+ all_chunks = all_tokens[:num_full_chunks * self.max_seq_len].view(-1, self.max_seq_len)
684
+
685
+ # Filter out chunks with excessive repeated tokens
686
+ self.chunks = []
687
+ for chunk in all_chunks:
688
+ unique_tokens, counts = torch.unique(chunk, return_counts=True)
689
+ max_repeats = counts.max().item()
690
+ if max_repeats <= self.max_repeat_fraction * chunk.numel():
691
+ self.chunks.append(chunk)
692
+
693
+ self.chunks = torch.stack(self.chunks) # Stack the remaining chunks
694
+
695
+ def __len__(self):
696
+ return len(self.chunks)
697
+
698
+ def __getitem__(self, idx):
699
+ return self.chunks[idx]
700
+
701
+
702
+ def compute_js_divergence(p, q, epsilon=1e-12):
703
+ """
704
+ Compute the Jensen-Shannon Divergence between two probability distributions.
705
+
706
+ Args:
707
+ p (torch.Tensor): Shape [..., D]
708
+ q (torch.Tensor): Shape [..., D]
709
+ epsilon (float): Small value to avoid log(0)
710
+
711
+ Returns:
712
+ torch.Tensor: JS Divergence values per pair (Shape: [...])
713
+ """
714
+ # Add epsilon to avoid log(0)
715
+ p = p + epsilon
716
+ q = q + epsilon
717
+
718
+ # Normalize to ensure they are valid probability distributions
719
+ p = p / p.sum(dim=-1, keepdim=True)
720
+ q = q / q.sum(dim=-1, keepdim=True)
721
+
722
+ # Compute the average distribution
723
+ m = 0.5 * (p + q)
724
+
725
+ # Compute KL Divergences
726
+ kl_p = F.kl_div(m.log(), p, reduction='none').sum(dim=-1)
727
+ kl_q = F.kl_div(m.log(), q, reduction='none').sum(dim=-1)
728
+
729
+ # Compute JS Divergence
730
+ js = 0.5 * kl_p + 0.5 * kl_q
731
+
732
+ return js
733
+
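Two properties worth sanity-checking (a sketch; assumes the function above is in scope): the divergence is ~0 for identical distributions and symmetric in its arguments.

import torch

p = torch.tensor([[0.7, 0.2, 0.1]])
q = torch.tensor([[0.1, 0.2, 0.7]])
assert compute_js_divergence(p, p).abs().item() < 1e-6                            # identical -> ~0
assert torch.allclose(compute_js_divergence(p, q), compute_js_divergence(q, p))   # symmetric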
734
+ def compute_head_consistency_js(head_data):
735
+ """
736
+ Compute the consistency of a head's probability distributions across examples using JS Divergence.
737
+
738
+ Args:
739
+ head_data (torch.Tensor): Shape [163, 1024], probability distributions for one head.
740
+
741
+ Returns:
742
+ float: Mean pairwise JS Divergence across examples.
743
+ """
744
+ num_examples = head_data.size(0)
745
+
746
+ # Ensure head_data is float for division and normalization
747
+ head_data = head_data.float()
748
+
749
+ # Normalize each distribution to ensure it sums to 1
750
+ head_data = head_data / head_data.sum(dim=-1, keepdim=True)
751
+
752
+ # Initialize variables to accumulate JS divergences
753
+ total_js = 0.0
754
+ count = 0
755
+
756
+ # Iterate over all unique pairs without redundancy
757
+ for i in tqdm(range(num_examples), desc="Computing JS Divergence"):
758
+ # Select the i-th distribution and expand its dimensions
759
+ p = head_data[i].unsqueeze(0) # Shape: [1, D]
760
+
761
+ # Select all distributions after the i-th to avoid duplicate pairs
762
+ q = head_data[i+1:] # Shape: [N-i-1, D]
763
+
764
+ if q.size(0) == 0:
765
+ continue # No more pairs left
766
+
767
+ # Compute JS Divergence between p and q
768
+ js = compute_js_divergence(p.repeat(q.size(0), 1), q) # Shape: [N-i-1]
769
+
770
+ # Accumulate the sum of JS Divergence
771
+ total_js += js.sum().item()
772
+ count += js.size(0)
773
+
774
+ # Compute the mean JS Divergence
775
+ mean_js = total_js / count if count > 0 else 0.0
776
+
777
+ return mean_js
778
+
779
+
780
+ def compute_token_consistency_js(head_data):
781
+ """
782
+ Compute token consistency for all heads in a layer using JS Divergence.
783
+
784
+ Args:
785
+ head_data (torch.Tensor): Shape [163, 24, 1024], layer's head data.
786
+
787
+ Returns:
788
+ np.ndarray: Consistency values for all 24 heads.
789
+ """
790
+ num_heads = head_data.shape[1]
791
+ consistency_metrics = []
792
+
793
+ for head in tqdm(range(num_heads), desc="Processing Heads"):
794
+ head_consistency = compute_head_consistency_js(head_data[:, head, :]) # Consistency for one head
795
+ consistency_metrics.append(head_consistency)
796
+
797
+ return np.array(consistency_metrics)
798
+
799
+
800
+ def graph_headtok_pos_affinity(head_tokpos_affinity, args):
801
+ """
802
+ Generate a violin plot for Token Access Consistency Across Layers using JS Divergence.
803
+
804
+ Args:
805
+ head_tokpos_affinity (dict): Dictionary where keys are layer identifiers and values are
806
+ torch.Tensor of shape [163, 24, 1024].
807
+ args (argparse.Namespace): Arguments containing at least 'model_path'.
808
+ """
809
+ # Process the data into a format suitable for plotting
810
+ layer_ids = []
811
+ consistency_values = []
812
+
813
+ for layer, tensor in tqdm(head_tokpos_affinity.items(), desc="Processing Layers"):
814
+ layer_consistency = compute_token_consistency_js(tensor) # Shape: [24]
815
+ layer_ids.extend([layer] * len(layer_consistency))
816
+ consistency_values.extend(layer_consistency)
817
+
818
+ # Create directory structure: ablation_plots/traces/tok_js_div
819
+ trace_dir = f"ablation_plots/traces/tok_js_div"
820
+ os.makedirs(trace_dir, exist_ok=True)
821
+ mpath = args.model_path.replace("/", "_")
822
+ # Save data to a .npy file (NumPy array format)
823
+ trace_path = os.path.join(trace_dir, f"layer_consistency_{mpath}.npy")
824
+ os.makedirs(os.path.dirname(trace_path), exist_ok=True)
825
+ np.save(trace_path, {"Layer": layer_ids, "JS_Divergence": consistency_values})
826
+ print(f"Consistency data saved to {trace_path}")
827
+
828
+
829
+ # Prepare data for Seaborn violin plot
830
+ data = {"Layer": layer_ids, "JS_Divergence": consistency_values}
831
+
832
+ # Create the violin plot
833
+ plt.figure(figsize=(10, 6))
834
+ sns.violinplot(x=data["Layer"], y=data["JS_Divergence"], scale="width", inner="quartile", palette="viridis")
835
+
836
+ # Formatting the plot
837
+ plt.title(f"Token Access Consistency Across Layers for {args.model_path}", fontsize=16)
838
+ plt.xlabel("Layer", fontsize=14)
839
+ plt.ylabel("Token Consistency Metric (Mean JS Divergence)", fontsize=14)
840
+ plt.xticks(fontsize=12)
841
+ plt.yticks(fontsize=12)
842
+
843
+ # Enhance layout
844
+ plt.tight_layout()
845
+
846
+ # Create ablation_plots directory if it doesn't exist
847
+ os.makedirs("ablation_plots", exist_ok=True)
848
+
849
+ # Construct the full file path
850
+ file_path = f"ablation_plots/{mpath}_headtok_consistency_js_divergence.pdf"
851
+
852
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
853
+ # Save the plot
854
+ plt.savefig(file_path)
855
+ plt.close()
856
+
857
+
858
+ def compute_head_agreement_js(head_data):
859
+ """
860
+ Compute head agreement for a single example using JS Divergence.
861
+
862
+ Args:
863
+ head_data (torch.Tensor): Shape [num_heads, num_tokens], token distributions for all heads.
864
+
865
+ Returns:
866
+ float: Mean pairwise JS Divergence for the heads.
867
+ """
868
+ num_heads = head_data.size(0)
869
+ js_divergences = []
870
+
871
+ for i in range(num_heads):
872
+ p = head_data[i].unsqueeze(0) # Shape: [1, num_tokens]
873
+ q = head_data[i + 1 :] # Remaining heads, Shape: [num_heads - i - 1, num_tokens]
874
+
875
+ if q.size(0) == 0:
876
+ continue
877
+
878
+ js = compute_js_divergence(p.repeat(q.size(0), 1), q) # Pairwise JS Div
879
+ js_divergences.extend(js.tolist())
880
+
881
+ # Return the mean of the upper triangular JS Divergence matrix
882
+ return np.mean(js_divergences)
883
+
884
+ def compute_head_agreement_all_examples(head_tokpos_affinity):
885
+ """
886
+ Compute head agreement for all examples across all layers.
887
+
888
+ Args:
889
+ head_tokpos_affinity (dict): Dictionary where keys are layer identifiers and values are
890
+ torch.Tensor of shape [num_examples, num_heads, num_tokens].
891
+
892
+ Returns:
893
+ np.ndarray: Head agreement values for all examples.
894
+ """
895
+ agreement_values = []
896
+
897
+ for layer, tensor in tqdm(head_tokpos_affinity.items(), desc="Processing Layers"):
898
+ for example_idx in range(tensor.shape[0]): # Iterate over examples
899
+ head_data = tensor[example_idx] # Shape: [num_heads, num_tokens]
900
+ agreement = compute_head_agreement_js(head_data)
901
+ agreement_values.append(agreement)
902
+
903
+ return np.array(agreement_values)
904
+
905
+ def plot_and_save_head_agreement(agreement_values, args):
906
+ """
907
+ Save head agreement values and plot them as a violin plot.
908
+
909
+ Args:
910
+ agreement_values (np.ndarray): Head agreement values for all examples.
911
+ args (argparse.Namespace): Arguments containing at least 'model_path'.
912
+ """
913
+ # Save agreement values to a .npy file
914
+ trace_dir = "ablation_plots/traces/headagreement_js_div"
915
+ os.makedirs(trace_dir, exist_ok=True)
916
+ mpath = args.model_path.replace("/", "_")
917
+ trace_path = os.path.join(trace_dir, f"head_agreement_{mpath}.npy")
918
+ np.save(trace_path, {"HeadAgreement": agreement_values})
919
+ print(f"Head agreement values saved to {trace_path}")
920
+
921
+ # Plot the violin plot
922
+ plt.figure(figsize=(10, 6))
923
+ sns.violinplot(data=[agreement_values], scale="width", inner="quartile", palette="viridis")
924
+
925
+ # Formatting
926
+ plt.title(f"Head Agreement Across Examples for {mpath}", fontsize=16)
927
+ plt.xlabel("Model", fontsize=14)
928
+ plt.ylabel("Mean JS Divergence Between Heads", fontsize=14)
929
+ plt.xticks([0], [mpath], fontsize=12) # Single violin
930
+ plt.yticks(fontsize=12)
931
+
932
+ # Enhance layout
933
+ plt.tight_layout()
934
+
935
+ # Save the plot
936
+ plot_path = f"ablation_plots/{mpath}_head_agreement_js_divergence.pdf"
937
+ plt.savefig(plot_path)
938
+ print(f"Violin plot saved to {plot_path}")
939
+ plt.close()
940
+
941
+
942
+ def compute_jsd_over_decode_steps(decode_probs):
943
+ """
944
+ Compute average JSD over decode steps for a single head.
945
+
946
+ Args:
947
+ decode_probs (torch.Tensor): Shape [50, 974], softmaxed token importances for 50 decode steps.
948
+
949
+ Returns:
950
+ float: Mean JSD across the upper diagonal of the pairwise JSD matrix.
951
+ """
952
+ # Expand dims to create pairwise matrices
953
+ p = decode_probs.unsqueeze(0) # Shape: [1, 50, 974]
954
+ q = decode_probs.unsqueeze(1) # Shape: [50, 1, 974]
955
+
956
+ # Compute pairwise JSD (broadcasting handles the pairwise combinations)
957
+ jsd_matrix = compute_js_divergence(p, q) # Shape: [50, 50]
958
+
959
+ # Extract upper diagonal without the diagonal itself
960
+ triu_indices = torch.triu_indices(jsd_matrix.size(0), jsd_matrix.size(1), offset=1)
961
+ jsd_upper = jsd_matrix[triu_indices[0], triu_indices[1]] # Shape: [N*(N-1)/2]
962
+
963
+ return jsd_upper.mean().item()
964
+
965
+ def compute_layer_jsd(decode_tokpos_affinity):
966
+ """
967
+ Compute average JSD over decode steps for all heads and layers.
968
+
969
+ Args:
970
+ decode_tokpos_affinity (dict): Dictionary where keys are layer indices and values are
971
+ torch.Tensor of shape [163, 24, 50, 974].
972
+
973
+ Returns:
974
+ dict: Average JSD values for each head in each layer, keyed by layer index.
975
+ """
976
+ layer_jsd = {}
977
+
978
+ for layer, tensor in tqdm(decode_tokpos_affinity.items(), desc="Processing Layers"):
979
+ # tensor shape: [163, 24, 50, 974]
980
+ num_examples, num_heads, num_decode_steps, num_tokens = tensor.shape
981
+
982
+ # Reshape for batch processing
983
+ decode_probs = tensor.view(-1, num_decode_steps, num_tokens) # Shape: [163 * 24, 50, 974]
984
+
985
+ # Compute pairwise JSD for all heads and examples
986
+ jsd_values = []
987
+ for decode_head in decode_probs:
988
+ jsd_values.append(compute_jsd_over_decode_steps(decode_head))
989
+
990
+ # Reshape back to per-head per-layer values
991
+ jsd_values = torch.tensor(jsd_values).view(num_examples, num_heads) # Shape: [163, 24]
992
+ layer_jsd[layer] = jsd_values.mean(dim=0).tolist() # Average across examples per head
993
+
994
+ return layer_jsd
995
+
996
+ def plot_decode_jsd_violin(layer_jsd, args):
997
+ """
998
+ Plot per-layer JSD values as violins.
999
+
1000
+ Args:
1001
+ layer_jsd (dict): Dictionary of average JSD values for each head in each layer.
1002
+ args (argparse.Namespace): Arguments containing at least 'model_path'.
1003
+ """
1004
+ # Prepare data for per-layer violin plot
1005
+ layers = []
1006
+ jsd_values = []
1007
+
1008
+ for layer, values in layer_jsd.items():
1009
+ layers.extend([layer] * len(values))
1010
+ jsd_values.extend(values)
1011
+
1012
+ # Save JSD values to a file
1013
+ trace_dir = "ablation_plots/traces/decode_jsd"
1014
+ os.makedirs(trace_dir, exist_ok=True)
1015
+ mpath = args.model_path.replace("/", "_")
1016
+ trace_path = os.path.join(trace_dir, f"decode_jsd_{mpath}.npy")
1017
+ np.save(trace_path, {"Layer": layers, "JSD": jsd_values})
1018
+ print(f"Decode JSD data saved to {trace_path}")
1019
+
1020
+ # Plot the violin plot
1021
+ plt.figure(figsize=(10, 6))
1022
+ sns.violinplot(x=layers, y=jsd_values, scale="width", inner="quartile", palette="viridis")
1023
+
1024
+ # Formatting
1025
+ plt.title(f"Per-Layer Decode JSD for {args.model_path}", fontsize=16)
1026
+ plt.xlabel("Layer", fontsize=14)
1027
+ plt.ylabel("Average JSD Over Decode Steps", fontsize=14)
1028
+ plt.xticks(fontsize=12)
1029
+ plt.yticks(fontsize=12)
1030
+
1031
+ plt.tight_layout()
1032
+
1033
+ # Save the plot
1034
+ plot_path = f"ablation_plots/{mpath}_decode_jsd_per_layer.pdf"
1035
+ plt.savefig(plot_path)
1036
+ print(f"Violin plot saved to {plot_path}")
1037
+ plt.close()
1038
+
1039
+ def compute_percentage_match_vectorized(decode_probs, top_k=0.1):
1040
+ """
1041
+ Compute the average percentage match of top-k token indices across 50 decode steps for a single head.
1042
+
1043
+ Args:
1044
+ decode_probs (torch.Tensor): Shape [50, 974], softmaxed token importances for 50 decode steps.
1045
+ top_k (float): Percentage of top tokens to consider (e.g., 0.1 for top 10%).
1046
+
1047
+ Returns:
1048
+ float: Average percentage match of token indices across the 50 decode steps.
1049
+ """
1050
+ num_steps, num_tokens = decode_probs.shape
1051
+ k = int(num_tokens * top_k) # Number of top tokens to consider
1052
+
1053
+ # Get top-k indices for all steps
1054
+ top_indices = torch.topk(decode_probs, k, dim=-1).indices # Shape: [50, k]
1055
+
1056
+ # Create a binary mask for top-k tokens
1057
+ binary_mask = torch.zeros(num_steps, num_tokens, device=decode_probs.device)
1058
+ binary_mask.scatter_(1, top_indices, 1) # Shape: [50, 974]
1059
+
1060
+ # Compute overlap between all pairs of steps
1061
+ overlap_matrix = torch.matmul(binary_mask, binary_mask.T) # Shape: [50, 50]
1062
+ overlap_per_pair = overlap_matrix / k # Normalize by k to get match percentages
1063
+
1064
+ # Extract upper triangular without diagonal
1065
+ triu_indices = torch.triu_indices(num_steps, num_steps, offset=1)
1066
+ upper_triangle = overlap_per_pair[triu_indices[0], triu_indices[1]] # Shape: [N*(N-1)/2]
1067
+
1068
+ return torch.tensor([upper_triangle.mean().item()], device=decode_probs.device)
1069
+
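As a sanity check on the overlap metric (a sketch): if every decode step selects the same top-k tokens, the returned match is 1.0.

import torch

probs = torch.rand(1, 100).repeat(50, 1)              # 50 identical decode steps over 100 tokens
match = compute_percentage_match_vectorized(probs, top_k=0.1)
assert abs(match.item() - 1.0) < 1e-6                 # identical steps -> full overlap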
1070
+
1071
+
1072
+ def compute_layer_percentage_match_vectorized(decode_tokpos_affinity, top_k=0.1):
1073
+ """
1074
+ Compute average percentage match for top-k token indices across decode steps for all heads and layers.
1075
+
1076
+ Args:
1077
+ decode_tokpos_affinity (dict): Dictionary where keys are layer indices and values are
1078
+ torch.Tensor of shape [163, 24, 50, 974].
1079
+ top_k (float): Percentage of top tokens to consider (e.g., 0.1 for top 10%).
1080
+
1081
+ Returns:
1082
+ dict: Average percentage match values for each head in each layer, keyed by layer index.
1083
+ """
1084
+ layer_match = {}
1085
+
1086
+ for layer, tensor in tqdm(decode_tokpos_affinity.items(), desc="Processing Layers"):
1087
+ # tensor shape: [163, 24, 50, 974]
1088
+ num_examples, num_heads, num_decode_steps, num_tokens = tensor.shape
1089
+
1090
+ # Flatten examples and heads for batch processing
1091
+ decode_probs = tensor.view(-1, num_decode_steps, num_tokens) # Shape: [163 * 24, 50, 974]
1092
+
1093
+ # Vectorized computation for all decode heads
1094
+ match_values = torch.cat([
1095
+ compute_percentage_match_vectorized(decode_head, top_k=top_k) for decode_head in decode_probs
1096
+ ])
1097
+
1098
+ # Reshape back to per-head per-layer values
1099
+ match_values = match_values.view(num_examples, num_heads) # Shape: [163, 24]
1100
+ layer_match[layer] = match_values.mean(dim=0).tolist() # Average across examples per head
1101
+
1102
+ return layer_match
1103
+
1104
+ def plot_decode_percdrift_vectorized(layer_match, args):
1105
+ """
1106
+ Plot per-layer average percentage match of top-k token indices as violins.
1107
+
1108
+ Args:
1109
+ layer_match (dict): Dictionary of average percentage match values for each head in each layer.
1110
+ args (argparse.Namespace): Arguments containing at least 'model_path'.
1111
+ """
1112
+ # Prepare data for per-layer violin plot
1113
+ layers = []
1114
+ match_values = []
1115
+
1116
+ for layer, values in layer_match.items():
1117
+ layers.extend([layer] * len(values))
1118
+ match_values.extend(values)
1119
+
1120
+ # Save match values to a file
1121
+ trace_dir = "ablation_plots/traces/percdrift"
1122
+ os.makedirs(trace_dir, exist_ok=True)
1123
+ mpath = args.model_path.replace("/", "_")
1124
+ trace_path = os.path.join(trace_dir, f"decode_percdrift_{mpath}.npy")
1125
+ np.save(trace_path, {"Layer": layers, "Match": match_values})
1126
+ print(f"Decode percentage drift data saved to {trace_path}")
1127
+
1128
+ # Plot the violin plot
1129
+ plt.figure(figsize=(10, 6))
1130
+ sns.violinplot(x=layers, y=match_values, scale="width", inner="quartile", palette="viridis")
1131
+
1132
+ # Formatting
1133
+ plt.title(f"Per-Layer Percentage Match for {args.model_path}", fontsize=16)
1134
+ plt.xlabel("Layer", fontsize=14)
1135
+ plt.ylabel("Average Percentage Match (Top 10%)", fontsize=14)
1136
+ plt.xticks(fontsize=12)
1137
+ plt.yticks(fontsize=12)
1138
+
1139
+ plt.tight_layout()
1140
+
1141
+ # Save the plot
1142
+ plot_path = f"ablation_plots/{mpath}_decode_percdrift_per_layer.pdf"
1143
+ plt.savefig(plot_path)
1144
+ print(f"Violin plot saved to {plot_path}")
1145
+ plt.close()
1146
+
1147
+
1148
+ def plot_decode_drift_trajectory(decode_tokpos_affinity, top_k=0.1, args=None):
1149
+ """
1150
+ Plot the trajectory of top-k token overlaps for each decode step, compared to the first decode step.
1151
+
1152
+ Args:
1153
+ decode_tokpos_affinity (dict): Dictionary where keys are layer indices and values are
1154
+ torch.Tensor of shape [num_examples, num_heads, num_decode_steps, num_tokens].
1155
+ top_k (float): Percentage of top tokens to consider (e.g., 0.1 for top 10%).
1156
+ args (argparse.Namespace): Arguments containing at least 'model_path'.
1157
+ """
1158
+
1159
+
1160
+ trajectories = [] # To store trajectories for all layers and heads
1161
+ num_decode_steps = None # To infer decode step count from the first layer
1162
+ for layer, tensor in tqdm(decode_tokpos_affinity.items(), desc="Processing Layers"):
1163
+ num_examples, num_heads, num_decode_steps, num_tokens = tensor.shape
1164
+ k = int(num_tokens * top_k) # Top-k token count
1165
+
1166
+ # Flatten examples and heads for batch processing
1167
+ decode_probs = tensor.view(-1, num_decode_steps, num_tokens) # Shape: [163 * 24, 50, 974]
1168
+
1169
+ # Get top-k indices for the first decode step
1170
+ initial_top_k = torch.topk(decode_probs[:, 0, :], k, dim=-1).indices # Shape: [163 * 24, k]
1171
+
1172
+ # Create binary masks for the first decode step
1173
+ initial_masks = torch.zeros(decode_probs.size(0), num_tokens, device=decode_probs.device)
1174
+ initial_masks.scatter_(1, initial_top_k, 1) # Shape: [163 * 24, num_tokens]
1175
+
1176
+ # Get top-k indices for all decode steps
1177
+ top_k_indices = torch.topk(decode_probs, k, dim=-1).indices # Shape: [163 * 24, 50, k]
1178
+
1179
+ # Create binary masks for all decode steps
1180
+ step_masks = torch.zeros(decode_probs.size(0), num_decode_steps, num_tokens, device=decode_probs.device)
1181
+ step_masks.scatter_(2, top_k_indices, 1) # Shape: [163 * 24, 50, num_tokens]
1182
+
1183
+ # Compute overlap with the first decode step for all steps
1184
+ overlaps = (step_masks * initial_masks.unsqueeze(1)).sum(dim=-1) / k # Shape: [163 * 24, 50]
1185
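+ # Toy example of the mask-overlap arithmetic above: with num_tokens=4 and k=2, an initial
+ # top-k of {0, 2} and a later step's top-k of {2, 3} give masks [1,0,1,0] and [0,0,1,1],
+ # so (step_mask * initial_mask).sum(-1) / k = 1/2 = 0.5 overlap for that step.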
+
1186
+ # Append all per-head overlap trajectories for this layer
1189
+ trajectories.extend(overlaps.cpu().numpy()) # Shape: [163 * 24, 50]
1190
+
1191
+ # Plot all trajectories with a colormap
1192
+ # Convert trajectories to NumPy for easier processing
1193
+ trajectories = np.array(trajectories) # Shape: [num_layers * num_examples * num_heads, num_decode_steps]
1194
+ plt.figure(figsize=(10, 6))
1195
+ colormap = cm.get_cmap("viridis", trajectories.shape[0]) # Viridis colormap
1196
+ for i, trajectory in enumerate(trajectories):
1197
+ plt.plot(range(num_decode_steps), trajectory, color=colormap(i), alpha=0.5, linewidth=0.8)
1198
+
1199
+ # Add labels and grid
1200
+ plt.title("Decode Drift Trajectories for All Heads", fontsize=16)
1201
+ plt.xlabel("Decode Step", fontsize=14)
1202
+ plt.ylabel("Top-k Overlap with Initial Step", fontsize=14)
1203
+ plt.grid(True)
1204
+
1205
+ # Save the plot
1206
+ mpath = args.model_path.replace("/", "_")
1207
+ output_path = f"ablation_plots/{mpath}_drift_trajectories.png"
1208
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
1209
+ plt.tight_layout()
1210
+ plt.savefig(output_path, dpi=600)
1211
+ plt.close()
1212
+ bins = 10
1213
+ # Create Y-axis edges for histogram bins
1214
+ y_edges = np.linspace(0, 1, bins + 1) # Divide Y-axis into bins (0 to 1 overlap values)
1215
+
1216
+ # Digitize trajectories to find bin indices for all overlap values
1217
+ bin_indices = np.digitize(trajectories, y_edges, right=True) - 1 # Shape: [num_trajectories, num_decode_steps]
1218
+
1219
+ # Clip bin indices to avoid out-of-bound indices
1220
+ bin_indices = np.clip(bin_indices, 0, bins - 1) # Ensure indices are within [0, bins-1]
1221
+
1222
+ # Create the density map using bincount for each step
1223
+ density_map = np.zeros((bins, num_decode_steps), dtype=np.float32)
1224
+ for step in tqdm(range(num_decode_steps)):
1225
+ # Count occurrences in each bin for the current step
1226
+ counts = np.bincount(bin_indices[:, step], minlength=bins)
1227
+ density_map[:, step] = counts
1228
+ # Normalize density map for better visualization
1229
+ density_map /= density_map.max()
1230
+ # # Create a 2D density histogram (heatmap data)
1231
+ # density_map = np.zeros((bins, num_decode_steps))
1232
+ # y_edges = np.linspace(0, 1, bins + 1) # Divide Y-axis into bins (0 to 1 overlap values)
1233
+
1234
+ # for step in range(num_decode_steps):
1235
+ # # Histogram for overlap values at each decode step
1236
+ # hist, _ = np.histogram(trajectories[:, step], bins=y_edges)
1237
+ # density_map[:, step] = hist
1238
+
1239
+ # # Normalize density map for better visualization
1240
+ # density_map /= density_map.max()
1241
+
1242
+ # Plot the density heatmap
1243
+ plt.figure(figsize=(12, 8))
1244
+ sns.heatmap(
1245
+ density_map,
1246
+ cmap="viridis",
1247
+ xticklabels=10, # Optional: Adjust frequency of X-ticks
1248
+ yticklabels=np.round(np.linspace(0, 1, bins), decimals=2), # Show overlap bins
1249
+ cbar_kws={"label": "Density"}
1250
+ )
1251
+ plt.title("Density of Decode Drift Trajectories", fontsize=16)
1252
+ plt.xlabel("Decode Step", fontsize=14)
1253
+ plt.ylabel("Top-k Overlap with Initial Step", fontsize=14)
1254
+ plt.tight_layout()
1255
+
1256
+ # Save the plot
1257
+ mpath = args.model_path.replace("/", "_")
1258
+ output_path = f"ablation_plots/{mpath}_drift_density_heatmap.png"
1259
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
1260
+ plt.savefig(output_path, dpi=600)
1261
+ plt.close()
1262
+
1263
+ print(f"Drift trajectory plot saved to {output_path}")
1264
+ # trajectories is already a NumPy array of shape [num_trajectories, num_decode_steps];
+ # average across all heads to get a single mean trajectory for saving
+ trajectories_to_save = np.mean(trajectories, axis=0) # Shape: [num_decode_steps]
1268
+
1269
+ # Save rank agreement values to a .npy file
1270
+ trace_dir = "ablation_plots/traces/decode_drift_trajectory"
1271
+ os.makedirs(trace_dir, exist_ok=True)
1272
+ mpath = args.model_path.replace("/", "_")
1273
+ trace_path = os.path.join(trace_dir, f"drift_traj_{mpath}.npy")
1274
+ np.save(trace_path, {"Trajectory": trajectories_to_save})
1275
+
1276
+ # Compute mean and standard deviation across trajectories
1277
+ mean_trajectory = np.mean(trajectories, axis=0) # Shape: [50]
1278
+ std_trajectory = np.std(trajectories, axis=0) # Shape: [50]
1279
+
1280
+ # Plot the mean trajectory with shaded error region
1281
+ plt.figure(figsize=(10, 6))
1282
+ plt.plot(range(num_decode_steps), mean_trajectory, label="Mean Drift Trajectory", color="blue")
1283
+ plt.fill_between(
1284
+ range(num_decode_steps),
1285
+ mean_trajectory - std_trajectory,
1286
+ mean_trajectory + std_trajectory,
1287
+ color="blue",
1288
+ alpha=0.2,
1289
+ label="±1 Std Dev"
1290
+ )
1291
+ plt.axhline(y=1.0, color="red", linestyle="--", label="Initial (100% Overlap)")
1292
+ plt.title("Decode Drift Trajectory with Error Region", fontsize=16)
1293
+ plt.xlabel("Decode Step", fontsize=14)
1294
+ plt.ylabel("Top-k Overlap with Initial Step", fontsize=14)
1295
+ plt.legend(fontsize=12)
1296
+ plt.grid(True)
1297
+
1298
+ # Save the plot
1299
+ mpath = args.model_path.replace("/", "_")
1300
+ output_path = f"ablation_plots/{mpath}_drift_trajectory.png"
1301
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
1302
+ plt.tight_layout()
1303
+ plt.savefig(output_path, dpi=600)
1304
+ plt.close()
1305
+
1306
+ print(f"Drift trajectory plot saved to {output_path}")
1307
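+
+ # Minimal smoke test for the two decode-drift plotting helpers above; all shapes and the
+ # model_path are placeholders, not real profiling output (argparse is imported locally in
+ # case it is not already imported at the top of this file).
+ def _demo_decode_drift_plots():
+     import argparse
+     fake_args = argparse.Namespace(model_path="dummy/model")
+     fake_match = {l: torch.rand(8).tolist() for l in range(2)} # per-head match values
+     plot_decode_percdrift_vectorized(fake_match, fake_args)
+     fake_affinity = {l: torch.rand(4, 8, 50, 256) for l in range(2)} # [examples, heads, steps, tokens]
+     plot_decode_drift_trajectory(fake_affinity, top_k=0.1, args=fake_args)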
+
1308
+
1309
+ def compute_rank_agreement_all_examples(head_tokpos_affinity, args):
1310
+ """
1311
+ Compute rank agreement (mean, min, max) for all examples across all heads and layers.
1312
+
1313
+ Args:
1314
+ head_tokpos_affinity (dict): Keys are layers; values are torch.Tensor of shape [num_examples, num_heads, num_tokens].
+ args (argparse.Namespace): Arguments containing at least 'model_path', used to name the saved heatmap.
+
1316
+ Returns:
1317
+ np.ndarray: Shape [num_examples, 3], where 3 corresponds to mean, min, and max rank correlation per example.
1318
+ """
1319
+ num_examples = next(iter(head_tokpos_affinity.values())).shape[0]
1320
+
1321
+ # Flatten heads across layers
1322
+ all_heads = []
1323
+ for layer, tensor in head_tokpos_affinity.items():
1324
+ all_heads.append(tensor) # Shape: [num_examples, num_heads, num_tokens]
1325
+ all_heads = torch.cat(all_heads, dim=1) # Shape: [num_examples, total_heads, num_tokens]
1326
+
1327
+ # Rank tokens per head (Spearman's correlation requires ranks); a double argsort turns
+ # raw affinity scores into integer ranks before computing correlations
+ ranks = torch.argsort(torch.argsort(all_heads, dim=-1), dim=-1).float() # Shape: [num_examples, total_heads, num_tokens]
1329
+
1330
+ # Compute rank correlation metrics
1331
+ results = []
1332
+ total_corr_matrix = None
1333
+
1334
+ for example_idx in tqdm(range(num_examples), desc="Computing Rank Correlations"):
1335
+ example_ranks = ranks[example_idx] # Shape: [total_heads, num_tokens]
1336
+
1337
+ # Compute rank correlation matrix (Spearman)
1338
+ corr_matrix = np.corrcoef(example_ranks.numpy()) # Shape: [total_heads, total_heads]
1339
+
1340
+ # Accumulate the correlation matrix
1341
+ if total_corr_matrix is None:
1342
+ total_corr_matrix = corr_matrix
1343
+ else:
1344
+ total_corr_matrix += corr_matrix
1345
+ # Extract upper triangle
1346
+ triu_indices = np.triu_indices(corr_matrix.shape[0], k=1)
1347
+ upper_triangle = corr_matrix[triu_indices]
1348
+
1349
+ # Compute mean, min, and max of rank correlations
1350
+ mean_corr = np.mean(upper_triangle)
1351
+ min_corr = np.min(upper_triangle)
1352
+ max_corr = np.max(upper_triangle)
1353
+
1354
+ results.append([mean_corr, min_corr, max_corr])
1355
+
1356
+ mean_corr_matrix = total_corr_matrix / num_examples
1357
+ # Plot the heatmap
1358
+ plt.figure(figsize=(12, 10)) # Adjust size as needed
1359
+ sns.heatmap(mean_corr_matrix, square=True, cbar=True, xticklabels=False, yticklabels=False, cmap="viridis")
1360
+
1361
+ # Set title
1362
+ plt.title("Mean Rank Correlation Matrix", fontsize=16)
1363
+
1364
+ # Construct the file path
1365
+ mpath = args.model_path.replace("/", "_")
1366
+ heatmap_path = f"ablation_plots/{mpath}_rankcorr_heatmap.png"
1367
+
1368
+ # Save the heatmap
1369
+ plt.tight_layout()
1370
+ plt.savefig(heatmap_path, dpi=600)
1371
+ plt.close()
1372
+
1373
+ print(f"Mean rank correlation heatmap saved to {heatmap_path}")
1374
+ return np.array(results) # Shape: [num_examples, 3]
1375
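+
+ # Illustrative sanity check that the rank-based Pearson correlation above matches
+ # Spearman's rho for tie-free scores; h1 and h2 are 1-D torch tensors of per-token
+ # affinities (scipy is assumed to be available, though nothing above requires it).
+ def _spearman_rank_sanity_check(h1, h2):
+     from scipy.stats import spearmanr
+     def to_ranks(x):
+         return torch.argsort(torch.argsort(x)).numpy() # integer ranks, ties broken arbitrarily
+     rho, _ = spearmanr(h1.numpy(), h2.numpy())
+     assert np.isclose(np.corrcoef(to_ranks(h1), to_ranks(h2))[0, 1], rho)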
+
1376
+ def plot_and_save_rank_agreement(rank_agreement, args):
1377
+ """
1378
+ Save rank agreement values and plot their distribution as a violin plot.
1379
+
1380
+ Args:
1381
+ rank_agreement (np.ndarray): Shape [num_examples, 3], where columns represent mean, min, and max rank correlations.
1382
+ args (argparse.Namespace): Arguments containing at least 'model_path'.
1383
+ """
1384
+ # Save rank agreement values to a .npy file
1385
+ trace_dir = "ablation_plots/traces/rankagreement_allheads"
1386
+ os.makedirs(trace_dir, exist_ok=True)
1387
+ mpath = args.model_path.replace("/", "_")
1388
+ trace_path = os.path.join(trace_dir, f"rank_agreement_{mpath}.npy")
1389
+ np.save(trace_path, {"RankAgreement": rank_agreement})
1390
+ print(f"Rank agreement data saved to {trace_path}")
1391
+
1392
+ # Prepare data for violin plot
1393
+ categories = ["Mean", "Min", "Max"]
1394
+ values = [rank_agreement[:, i] for i in range(3)] # Separate columns
1395
+
1396
+ # Create the violin plot
1397
+ plt.figure(figsize=(10, 6))
1398
+ sns.violinplot(data=values, scale="width", inner="quartile", palette="viridis")
1399
+
1400
+ # Formatting
1401
+ plt.title(f"Rank Agreement Distribution for {mpath}", fontsize=16)
1402
+ plt.xlabel("Metric", fontsize=14)
1403
+ plt.ylabel("Rank Correlation", fontsize=14)
1404
+ plt.xticks(range(3), categories, fontsize=12)
1405
+ plt.yticks(fontsize=12)
1406
+
1407
+ # Enhance layout and save the plot
1408
+ plt.tight_layout()
1409
+ # plot_path = os.path.join(trace_dir, f"rank_agreement_violin_{mpath}.pdf")
1410
+
1411
+ plot_path = f"ablation_plots/{mpath}_rank_agreement_violin.pdf"
1412
+ plt.savefig(plot_path)
1413
+ print(f"Violin plot saved to {plot_path}")
1414
+ plt.close()
1415
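+
+ # Illustrative wiring of the two rank-agreement helpers above; the shapes and model_path
+ # are placeholders (the output directory is created up front because the heatmap in
+ # compute_rank_agreement_all_examples is written into ablation_plots/).
+ def _demo_rank_agreement():
+     import argparse
+     os.makedirs("ablation_plots", exist_ok=True)
+     fake_affinity = {l: torch.rand(16, 8, 512) for l in range(4)} # [examples, heads, tokens]
+     fake_args = argparse.Namespace(model_path="dummy/model")
+     rank_agreement = compute_rank_agreement_all_examples(fake_affinity, fake_args) # Shape: [16, 3]
+     plot_and_save_rank_agreement(rank_agreement, fake_args)
+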
+ # def compute_head_consistency(head_data):
1416
+ # """
1417
+ # Compute the consistency of a head's probability distributions across examples.
1418
+
1419
+ # Args:
1420
+ # head_data (torch.Tensor): Shape [163, 1024], probability distributions for one head.
1421
+
1422
+ # Returns:
1423
+ # float: Mean pairwise cosine similarity across examples.
1424
+ # """
1425
+ # # Ensure head_data is float for division and normalization
1426
+ # head_data = head_data.float()
1427
+
1428
+ # # Normalize each distribution to ensure it sums to 1
1429
+ # head_data = head_data / head_data.sum(dim=-1, keepdim=True)
1430
+
1431
+ # # Normalize vectors to unit norm to prepare for cosine similarity
1432
+ # head_data = head_data / head_data.norm(dim=-1, keepdim=True)
1433
+
1434
+ # # Compute cosine similarity matrix: [163, 163]
1435
+ # similarity_matrix = torch.matmul(head_data, head_data.T)
1436
+
1437
+ # # Extract the upper triangular part of the similarity matrix, excluding the diagonal
1438
+ # num_examples = head_data.size(0)
1439
+ # triu_indices = torch.triu_indices(num_examples, num_examples, offset=1)
1440
+ # pairwise_similarities = similarity_matrix[triu_indices[0], triu_indices[1]]
1441
+
1442
+ # # Compute and return the mean similarity
1443
+ # return pairwise_similarities.mean().item()
1444
+
1445
+ # def compute_token_consistency(head_data):
1446
+ # """
1447
+ # Compute token consistency for all heads in a layer.
1448
+
1449
+ # Args:
1450
+ # head_data (torch.Tensor): Shape [163, 24, 1024], layer's head data.
1451
+
1452
+ # Returns:
1453
+ # np.ndarray: Consistency values for all 24 heads.
1454
+ # """
1455
+ # # Ensure head_data is float for division and normalization
1456
+ # head_data = head_data.float()
1457
+
1458
+ # # Normalize each distribution to ensure it sums to 1
1459
+ # head_data = head_data / head_data.sum(dim=-1, keepdim=True)
1460
+
1461
+ # # Normalize vectors to unit norm to prepare for cosine similarity
1462
+ # head_data = head_data / head_data.norm(dim=-1, keepdim=True)
1463
+
1464
+ # # Reshape to [24, 163, 1024] for batch processing
1465
+ # head_data = head_data.permute(1, 0, 2) # Now shape is [24, 163, 1024]
1466
+
1467
+ # # Compute cosine similarity matrices for all heads: [24, 163, 163]
1468
+ # similarity_matrices = torch.bmm(head_data, head_data.transpose(1, 2))
1469
+
1470
+ # # Extract upper triangular indices, excluding the diagonal
1471
+ # num_examples = head_data.size(1)
1472
+ # triu_indices = torch.triu_indices(num_examples, num_examples, offset=1)
1473
+
1474
+ # # Gather pairwise similarities for all heads: [24, num_pairs]
1475
+ # pairwise_similarities = similarity_matrices[:, triu_indices[0], triu_indices[1]]
1476
+
1477
+ # # Compute the mean similarity for each head: [24]
1478
+ # consistency_metrics = pairwise_similarities.mean(dim=1).cpu().numpy()
1479
+
1480
+ # return consistency_metrics
1481
+
1482
+
1483
+
1484
+ # def graph_headtok_pos_affinity(head_tokpos_affinity, args):
1485
+ # # Process the data into a format suitable for plotting
1486
+ # layer_ids = []
1487
+ # consistency_values = []
1488
+
1489
+ # for layer, tensor in tqdm(head_tokpos_affinity.items()):
1490
+ # layer_consistency = compute_token_consistency(tensor) # Shape: [24]
1491
+ # layer_ids.extend([layer] * len(layer_consistency))
1492
+ # consistency_values.extend(layer_consistency)
1493
+
1494
+ # # Prepare data for Seaborn violin plot
1495
+ # data = {"Layer": layer_ids, "Consistency": consistency_values}
1496
+
1497
+ # # Create the violin plot
1498
+ # plt.figure(figsize=(10, 6))
1499
+ # sns.violinplot(x=data["Layer"], y=data["Consistency"], scale="width", inner="quartile", palette="viridis")
1500
+
1501
+ # # Formatting the plot
1502
+ # plt.title(f"Token Access Consistency Across Layers For {args.model_path}", fontsize=16)
1503
+ # plt.xlabel("Layer", fontsize=14)
1504
+ # plt.ylabel("Token Consistency Metric (Mean Cosine Similarity)", fontsize=14)
1505
+ # plt.xticks(fontsize=12)
1506
+ # plt.yticks(fontsize=12)
1507
+
1508
+ # # Show the plot
1509
+ # plt.tight_layout()
1510
+
1511
+ # # Create ablation_plots directory if it doesn't exist
1512
+ # os.makedirs("ablation_plots", exist_ok=True)
1513
+
1514
+ # # Construct the full file path
1515
+ # file_path = f"ablation_plots/{args.model_path}_headtok_consistency.pdf"
1516
+
1517
+ # # Create the ablation_plots and intermediate directories if they don't exist
1518
+ # os.makedirs(os.path.dirname(file_path), exist_ok=True)
1519
+ # # Save the plot
1520
+ # plt.savefig(file_path)
1521
+ # plt.close()