kfoughali committed
Commit 52797d8 · verified · 1 Parent(s): 0129ff5

Update compression.py

Files changed (1):
  1. compression.py +248 -96
compression.py CHANGED
@@ -1,6 +1,8 @@
 """
 Enhanced SPG compression algorithms with RocketKV-style 450x compression.
 NO ESTIMATIONS - only measured values. FAIL FAST on errors.
 """
 
 import torch
@@ -17,10 +19,50 @@ from config import (
 
 logger = logging.getLogger(__name__)
 
 
 class EnhancedSlidingPrecisionGradient:
     """
     Research-grade Enhanced SPG with RocketKV-style 450x compression capability.
     NO ESTIMATIONS OR HARDCODED VALUES - all parameters from validated config.
     """
 
     def __init__(self, config: EnhancedSPGConfig):
@@ -160,34 +202,50 @@ class EnhancedSlidingPrecisionGradient:
 
     def snapkv_plus_plus(self, keys: torch.Tensor, values: torch.Tensor,
                          compression_ratio: float) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
-        """SnapKV++ with GQA support and adaptive pooling - no hardcoded values."""
         batch_size, n_heads, seq_len, head_dim = keys.shape
 
         # Adaptive kernel size based on sequence length (from config)
         kernel_size = self.config.get_adaptive_kernel_size(seq_len)
 
         # Compute importance scores with adaptive pooling
-        key_norms = keys.norm(dim=-1)  # [batch, heads, seq]
-        value_norms = values.norm(dim=-1)
-        combined_importance = (key_norms + value_norms) / 2.0
-
-        # Multi-head aggregation with adaptive pooling
-        if kernel_size > 1:
-            # Apply 1D pooling along sequence dimension
-            pooled_importance = F.avg_pool1d(
-                combined_importance.mean(dim=1).unsqueeze(1),  # [batch, 1, seq]
-                kernel_size=kernel_size,
-                stride=1,
-                padding=kernel_size // 2
-            ).squeeze(1)  # [batch, seq]
-            # Ensure pooled output matches original sequence length
-            if pooled_importance.shape[-1] != seq_len:
-                pooled_importance = pooled_importance[:, :seq_len]
-        else:
-            pooled_importance = combined_importance.mean(dim=1)
-
-        # Aggregate across batch
-        final_importance = pooled_importance.mean(dim=0)  # [seq]
 
         # Ensure importance tensor matches sequence length
         if final_importance.shape[0] != seq_len:
@@ -195,14 +253,18 @@ class EnhancedSlidingPrecisionGradient:
 
         # Preserve sink and recent tokens
         preserve_mask = torch.zeros(seq_len, dtype=torch.bool, device=keys.device)
-        preserve_mask[:min(self.config.sink_tokens, seq_len)] = True
-        preserve_mask[-min(self.config.recent_window, seq_len):] = True
 
-        # Top-k selection for remaining tokens
-        n_keep = max(self.config.sink_tokens + self.config.recent_window,
-                     int(seq_len / compression_ratio))
-        n_keep = min(n_keep, seq_len)  # Ensure we don't exceed sequence length
-        remaining_slots = n_keep - preserve_mask.sum().item()
 
         if remaining_slots > 0:
             masked_importance = final_importance.clone()
@@ -212,36 +274,58 @@ class EnhancedSlidingPrecisionGradient:
 
         if len(available_indices) > 0:
             k = min(remaining_slots, len(available_indices))
             if k > 0:
-                _, relative_top_indices = torch.topk(masked_importance[available_indices], k)
-                absolute_top_indices = available_indices[relative_top_indices]
-                preserve_mask[absolute_top_indices] = True
 
-        # Extract retained tokens with bounds checking
-        retained_indices = torch.where(preserve_mask)[0]
-        retained_indices = retained_indices[retained_indices < seq_len]  # Safety check
 
-        keys_compressed = keys[:, :, retained_indices, :]
-        values_compressed = values[:, :, retained_indices, :]
 
-        actual_ratio = seq_len / len(retained_indices) if len(retained_indices) > 0 else float('inf')
-        logger.debug(f"SnapKV++: {seq_len} → {len(retained_indices)} tokens ({actual_ratio:.1f}x)")
 
         return keys_compressed, values_compressed, retained_indices.tolist()
 
     def hybrid_sparse_attention(self, keys: torch.Tensor, values: torch.Tensor,
                                 head_budget: int, seq_budget: int) -> Dict[str, Any]:
-        """RocketKV-style Hybrid Sparse Attention for Stage 2 - no hardcoded values."""
         batch_size, n_heads, seq_len, head_dim = keys.shape
 
-        # 1. Head-wise importance scoring
-        head_importance = (
-            keys.float().pow(2).sum(dim=(-1, -2)).sum(dim=0) +  # Sum over batch, seq, hidden
-            values.float().pow(2).sum(dim=(-1, -2)).sum(dim=0)
-        )  # [n_heads]
 
-        # Select top heads
-        actual_head_budget = min(head_budget, n_heads)
-        _, top_head_indices = torch.topk(head_importance, actual_head_budget)
 
         compressed_data = {
             'keys': {},
@@ -255,33 +339,49 @@ class EnhancedSlidingPrecisionGradient:
 
         # 2. Sequence-wise top-k selection per selected head
         for head_idx in top_head_indices:
-            head_keys = keys[:, head_idx:head_idx+1, :, :]  # Keep head dimension
-            head_values = values[:, head_idx:head_idx+1, :, :]
 
             # Compute sequence importance for this head
-            seq_importance = (
-                head_keys.norm(dim=-1).squeeze(1).mean(dim=0) +  # [seq]
-                head_values.norm(dim=-1).squeeze(1).mean(dim=0)
-            ) / 2.0
 
             # Apply position-based boost (from research constants)
             position_boost = torch.ones_like(seq_importance)
-            position_boost[:self.config.sink_tokens] *= self.constants.POSITION_BOOST_SINK
-            position_boost[-self.config.recent_window:] *= self.constants.POSITION_BOOST_RECENT
 
             boosted_importance = seq_importance * position_boost
 
             # Select top tokens for this head
-            actual_seq_budget = min(seq_budget, seq_len)
-            _, top_token_indices = torch.topk(boosted_importance, actual_seq_budget)
 
             # Store compressed data
-            head_key = f'head_{head_idx.item()}'
             compressed_data['keys'][head_key] = {
-                'data': head_keys[:, :, top_token_indices, :].clone(),
                 'indices': top_token_indices.tolist()
             }
             compressed_data['values'][head_key] = {
-                'data': head_values[:, :, top_token_indices, :].clone(),
                 'indices': top_token_indices.tolist()
             }
@@ -315,7 +415,7 @@ class EnhancedSlidingPrecisionGradient:
 
         # Calculate retention based on compression ratio
         retention_ratio = 1.0 / compression_ratio
-        min_retain = self.config.sink_tokens + self.config.recent_window
         n_retain = max(min_retain, int(seq_len * retention_ratio))
 
         # Apply layer-specific constraints (from research constants)
@@ -325,7 +425,7 @@ class EnhancedSlidingPrecisionGradient:
 
         else:  # Late layers
             max_retain = int(seq_len * self.constants.LATE_LAYER_MAX_RETENTION)
 
-        n_retain = min(n_retain, max_retain)
 
         # Compute magnitude-based importance
         importance_scores = self.compute_magnitude_importance(keys, values)
@@ -333,13 +433,18 @@ class EnhancedSlidingPrecisionGradient:
 
         # Quality preservation: boost recent tokens (explicit formula from config)
         recent_boost = torch.zeros_like(importance_scores)
         if self.config.recent_window > 0:
-            recent_boost[-self.config.recent_window:] = importance_scores.max() * self.config.recent_boost_factor
         importance_scores = importance_scores + recent_boost
 
         # Initialize preservation mask
         preserve_mask = torch.zeros(seq_len, dtype=torch.bool, device=keys.device)
-        preserve_mask[:self.config.sink_tokens] = True
-        preserve_mask[-self.config.recent_window:] = True
 
         # Select additional tokens based on importance
         remaining_slots = n_retain - preserve_mask.sum().item()
@@ -359,15 +464,22 @@ class EnhancedSlidingPrecisionGradient:
 
         available = (masked_importance > -float('inf')).sum().item()
         k = min(remaining_slots, available)
         if k > 0:
-            _, top_indices = torch.topk(masked_importance, k)
-            preserve_mask[top_indices] = True
 
         # Extract retained tokens
-        retained_indices = torch.where(preserve_mask)[0]
-        keys_stage1 = keys[:, :, retained_indices, :]
-        values_stage1 = values[:, :, retained_indices, :]
 
-        actual_ratio = seq_len / len(retained_indices) if len(retained_indices) > 0 else float('inf')
         logger.debug(f"Stage 1 Layer {layer_idx}: {seq_len} → {len(retained_indices)} tokens ({actual_ratio:.1f}x)")
 
         return keys_stage1, values_stage1, retained_indices.tolist()
@@ -382,7 +494,10 @@ class EnhancedSlidingPrecisionGradient:
 
         if self.use_hybrid_sparse_attention:
             # RocketKV-style compression with adaptive budgets
-            sparsity = self.estimate_attention_sparsity(keys, values)  # May raise if fails
 
             if self.use_adaptive_decomposition:
                 _, stage2_ratio = self.adaptive_stage_split(
@@ -462,7 +577,11 @@ class EnhancedSlidingPrecisionGradient:
 
             values.float().pow(2).sum(dim=(-1, -2)).sum(dim=0)
         )
 
-        _, important_head_indices = torch.topk(head_importance, n_important_heads)
 
         other_head_indices = torch.tensor(
             [h for h in range(n_heads) if h not in important_head_indices.tolist()],
             device=keys.device, dtype=torch.long
@@ -470,19 +589,19 @@ class EnhancedSlidingPrecisionGradient:
 
             # Store important heads at full precision
             compressed_data['keys']['heads_fp16'] = {
-                'data': keys[:, important_head_indices, :, :].clone(),
                 'indices': important_head_indices.tolist()
             }
             compressed_data['values']['heads_fp16'] = {
-                'data': values[:, important_head_indices, :, :].clone(),
                 'indices': important_head_indices.tolist()
             }
 
             if other_head_indices.numel() == 0:
                 return compressed_data
 
-            seq_keys = keys[:, other_head_indices, :, :]
-            seq_values = values[:, other_head_indices, :, :]
         else:
             seq_keys = keys
             seq_values = values
@@ -492,10 +611,13 @@ class EnhancedSlidingPrecisionGradient:
 
         # Explicit top-K selection for FP16
         keep_fp16 = max(0, int(seq_len * self.config.sequence_compression_ratio))
-        top_fp16 = torch.topk(combined_importance, k=keep_fp16).indices if keep_fp16 > 0 else torch.empty(0, dtype=torch.long, device=keys.device)
-        is_fp16 = torch.zeros(seq_len, dtype=torch.bool, device=keys.device)
         if keep_fp16 > 0:
-            is_fp16[top_fp16] = True
 
         # Vectorized token binning
         thresh = torch.tensor([pl.threshold for pl in levels], device=keys.device)
@@ -542,8 +664,8 @@ class EnhancedSlidingPrecisionGradient:
 
                 continue
 
             idx_tensor = torch.tensor(indices, device=keys.device, dtype=torch.long)
-            k_slice = seq_keys.index_select(2, idx_tensor)
-            v_slice = seq_values.index_select(2, idx_tensor)
 
             # Store with aggressive precision - only FP16 for ultra-selective tokens
             compressed_data['keys'][precision_key]['data'] = k_slice.clone()
@@ -589,7 +711,8 @@ class EnhancedSlidingPrecisionGradient:
 
         except Exception as e:
             logger.error(f"Error in enhanced compression for layer {layer_idx}: {e}")
-            raise
 
     def _fallback_to_original_spg(self, keys: torch.Tensor, values: torch.Tensor,
                                   layer_idx: int, current_position: Optional[int]) -> Dict[str, Any]:
@@ -668,8 +791,8 @@ class EnhancedSlidingPrecisionGradient:
 
                 continue
 
             level_indices = torch.tensor(indices, device=device, dtype=torch.long)
-            k_slice = keys.index_select(2, level_indices)
-            v_slice = values.index_select(2, level_indices)
 
             # Store with FP16 precision (simplified for original SPG)
             compressed_data['keys'][precision_key]['data'] = k_slice.clone()
@@ -750,8 +873,16 @@ class EnhancedSlidingPrecisionGradient:
 
         if 'heads_fp16' in compressed_data['keys']:
             head_indices = compressed_data['keys']['heads_fp16']['indices']
            head_idx_tensor = torch.tensor(head_indices, device=device, dtype=torch.long)
-            keys_full[:, head_idx_tensor, :, :] = compressed_data['keys']['heads_fp16']['data']
-            values_full[:, head_idx_tensor, :, :] = compressed_data['values']['heads_fp16']['data']
 
         if self.config.enable_head_compression:
             n_heads = original_shape[1]
@@ -768,13 +899,22 @@ class EnhancedSlidingPrecisionGradient:
 
                 continue
 
             indices = compressed_data['keys'][precision_key]['indices']
             idx_tensor = torch.tensor(indices, device=device, dtype=torch.long)
 
             # All data stored as FP16 in this simplified version
-            keys_full[:, other_head_indices, :, :].index_copy_(2, idx_tensor,
-                                                               compressed_data['keys'][precision_key]['data'])
-            values_full[:, other_head_indices, :, :].index_copy_(2, idx_tensor,
-                                                                 compressed_data['values'][precision_key]['data'])
 
         return keys_full, values_full
@@ -806,8 +946,11 @@ class EnhancedSlidingPrecisionGradient:
 
             token_indices = head_data_k['indices']
 
             # Place data in the correct head and token positions
-            keys_full[:, head_idx:head_idx+1, token_indices, :] = head_data_k['data']
-            values_full[:, head_idx:head_idx+1, token_indices, :] = head_data_v['data']
 
         return keys_full, values_full
@@ -825,11 +968,20 @@ class EnhancedSlidingPrecisionGradient:
 
             data_dict = compressed_data['keys'][precision_key]
             if 'data' in data_dict and 'indices' in data_dict:
                 indices = data_dict['indices']
                 idx_tensor = torch.tensor(indices, device=device, dtype=torch.long)
 
                 # All data stored as original precision
-                keys_full.index_copy_(2, idx_tensor, data_dict['data'])
-                values_full.index_copy_(2, idx_tensor, compressed_data['values'][precision_key]['data'])
 
         return keys_full, values_full
 
+# compression.py
 """
 Enhanced SPG compression algorithms with RocketKV-style 450x compression.
 NO ESTIMATIONS - only measured values. FAIL FAST on errors.
+FIXED: CUDA assert errors, safe tensor operations, bounds checking.
 """
 
 import torch
 
 
 logger = logging.getLogger(__name__)
 
+
+def safe_topk(tensor, k, dim=-1):
+    """Safe version of topk that handles edge cases; returns (values, indices) like torch.topk."""
+    if tensor.numel() == 0:
+        logger.warning("Empty tensor in topk operation")
+        return torch.empty(0, device=tensor.device), torch.empty(0, dtype=torch.long, device=tensor.device)
+
+    # Ensure k doesn't exceed tensor size
+    max_k = tensor.shape[dim]
+    actual_k = min(k, max_k)
+
+    if actual_k <= 0:
+        logger.warning(f"Invalid k={k} for tensor with shape {tensor.shape}")
+        return torch.empty(0, device=tensor.device), torch.empty(0, dtype=torch.long, device=tensor.device)
+
+    return torch.topk(tensor, actual_k, dim=dim)
+
+
+def safe_index_select(tensor, dim, indices):
+    """Safe version of index_select that validates indices."""
+    if indices.numel() == 0:
+        # Return an empty tensor with the correct shape
+        shape = list(tensor.shape)
+        shape[dim] = 0
+        return torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
+
+    # Validate indices are within bounds
+    max_idx = tensor.shape[dim] - 1
+    if indices.max() > max_idx:
+        logger.warning(f"Index {indices.max()} exceeds max {max_idx}, clamping")
+        indices = indices.clamp(0, max_idx)
+
+    if indices.min() < 0:
+        logger.warning(f"Negative index {indices.min()}, clamping to 0")
+        indices = indices.clamp(0, max_idx)
+
+    return tensor.index_select(dim, indices)
+
+
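A quick sanity check of the two helpers above (a minimal sketch, not part of the commit; it assumes safe_topk keeps torch.topk's (values, indices) return order):

    import torch

    scores = torch.tensor([0.3, 0.9, 0.1])

    # k larger than the tensor is clamped to 3 instead of raising
    values, indices = safe_topk(scores, k=10)
    assert indices.tolist() == [1, 0, 2]

    # out-of-range indices are clamped rather than triggering a CUDA assert
    data = torch.randn(2, 4, 8, 16)    # [batch, heads, seq, head_dim]
    idx = torch.tensor([0, 7, 99])     # 99 > seq_len - 1, clamped to 7
    sliced = safe_index_select(data, 2, idx)
    assert sliced.shape == (2, 4, 3, 16)
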
 class EnhancedSlidingPrecisionGradient:
     """
     Research-grade Enhanced SPG with RocketKV-style 450x compression capability.
     NO ESTIMATIONS OR HARDCODED VALUES - all parameters from validated config.
+    FIXED: Safe tensor operations with bounds checking.
     """
 
     def __init__(self, config: EnhancedSPGConfig):
 
     def snapkv_plus_plus(self, keys: torch.Tensor, values: torch.Tensor,
                          compression_ratio: float) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
+        """SnapKV++ with GQA support and adaptive pooling - FIXED with safe operations."""
         batch_size, n_heads, seq_len, head_dim = keys.shape
 
+        # CRITICAL: Ensure minimum tokens retained
+        min_tokens = max(8, self.config.min_tokens_for_stability)  # At least 8 tokens
+        n_keep = max(min_tokens, int(seq_len / compression_ratio))
+        n_keep = min(n_keep, seq_len)  # Can't keep more than we have
+
+        logger.debug(f"SnapKV++: seq_len={seq_len}, compression_ratio={compression_ratio:.1f}, n_keep={n_keep}")
+
+        if n_keep >= seq_len:
+            # No compression needed
+            return keys, values, list(range(seq_len))
+
         # Adaptive kernel size based on sequence length (from config)
         kernel_size = self.config.get_adaptive_kernel_size(seq_len)
 
         # Compute importance scores with adaptive pooling
+        try:
+            key_norms = keys.norm(dim=-1)  # [batch, heads, seq]
+            value_norms = values.norm(dim=-1)
+            combined_importance = (key_norms + value_norms) / 2.0
+
+            # Multi-head aggregation with adaptive pooling
+            if kernel_size > 1 and seq_len > kernel_size:
+                # Apply 1D pooling along sequence dimension
+                pooled_importance = F.avg_pool1d(
+                    combined_importance.mean(dim=1).unsqueeze(1),  # [batch, 1, seq]
+                    kernel_size=kernel_size,
+                    stride=1,
+                    padding=kernel_size // 2
+                ).squeeze(1)  # [batch, seq]
+                # Ensure pooled output matches original sequence length
+                if pooled_importance.shape[-1] != seq_len:
+                    pooled_importance = pooled_importance[:, :seq_len]
+            else:
+                pooled_importance = combined_importance.mean(dim=1)
+
+            # Aggregate across batch
+            final_importance = pooled_importance.mean(dim=0)  # [seq]
+        except Exception as e:
+            logger.error(f"Error computing importance: {e}")
+            # Fallback to uniform importance
+            final_importance = torch.ones(seq_len, device=keys.device)
 
         # Ensure importance tensor matches sequence length
         if final_importance.shape[0] != seq_len:
 
         # Preserve sink and recent tokens
         preserve_mask = torch.zeros(seq_len, dtype=torch.bool, device=keys.device)
 
+        # Recent tokens
+        recent_window = min(self.config.recent_window, seq_len // 2)  # Don't preserve more than half
+        preserve_mask[-recent_window:] = True
+
+        # Sink tokens
+        if self.config.sink_tokens > 0:
+            sink_count = min(self.config.sink_tokens, seq_len // 4)  # Don't preserve more than a quarter
+            preserve_mask[:sink_count] = True
+
+        preserved_count = preserve_mask.sum().item()
+        remaining_slots = max(0, n_keep - preserved_count)
 
         if remaining_slots > 0:
             masked_importance = final_importance.clone()
 
         if len(available_indices) > 0:
             k = min(remaining_slots, len(available_indices))
             if k > 0:
+                available_importance = masked_importance[available_indices]
+                _, relative_top_indices = safe_topk(available_importance, k)
+
+                if relative_top_indices.numel() > 0:
+                    absolute_indices = available_indices[relative_top_indices]
+                    preserve_mask[absolute_indices] = True
+
+        # Get final retained indices
+        retained_indices = preserve_mask.nonzero(as_tuple=True)[0]
 
+        if retained_indices.numel() == 0:
+            logger.error("No indices retained! Keeping at least recent tokens")
+            # Emergency fallback - keep last few tokens
+            retained_indices = torch.arange(max(0, seq_len - min_tokens), seq_len,
+                                            device=keys.device, dtype=torch.long)
 
+        # Safe indexing
+        keys_compressed = safe_index_select(keys, 2, retained_indices)
+        values_compressed = safe_index_select(values, 2, retained_indices)
 
+        actual_ratio = seq_len / len(retained_indices) if len(retained_indices) > 0 else 1.0
+        logger.debug(f"SnapKV++ compressed: {seq_len} → {len(retained_indices)} tokens ({actual_ratio:.1f}x)")
 
         return keys_compressed, values_compressed, retained_indices.tolist()
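Worked through with concrete numbers, the n_keep floor above caps how far SnapKV++ alone can compress (a sketch with illustrative values; min_tokens_for_stability comes from the config and is assumed to be 16 here):

    seq_len, compression_ratio = 4096, 450.0
    min_tokens = max(8, 16)                                       # min_tokens_for_stability = 16 (assumed)
    n_keep = max(min_tokens, int(seq_len / compression_ratio))    # max(16, 9) = 16
    n_keep = min(n_keep, seq_len)                                 # 16 tokens kept
    print(seq_len / n_keep)                                       # 256.0x, not the requested 450x

So the per-call ratio is bounded by seq_len / min_tokens; the 450x target relies on Stage 2 compressing the survivors further.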
 
     def hybrid_sparse_attention(self, keys: torch.Tensor, values: torch.Tensor,
                                 head_budget: int, seq_budget: int) -> Dict[str, Any]:
+        """RocketKV-style Hybrid Sparse Attention for Stage 2 - FIXED with safe operations."""
         batch_size, n_heads, seq_len, head_dim = keys.shape
 
+        # Ensure minimum budgets
+        head_budget = max(1, min(head_budget, n_heads))
+        seq_budget = max(self.config.min_tokens_for_stability, min(seq_budget, seq_len))
 
+        logger.debug(f"HSA: n_heads={n_heads}, seq_len={seq_len}, head_budget={head_budget}, seq_budget={seq_budget}")
+
+        # 1. Head-wise importance scoring with safe computation
+        try:
+            head_importance = (
+                keys.float().pow(2).sum(dim=(-1, -2)).mean(dim=0) +  # Average over batch
+                values.float().pow(2).sum(dim=(-1, -2)).mean(dim=0)
+            )  # [n_heads]
+        except Exception as e:
+            logger.error(f"Error computing head importance: {e}")
+            head_importance = torch.ones(n_heads, device=keys.device)
+
+        # Select top heads safely
+        _, top_head_indices = safe_topk(head_importance, head_budget)
+
+        if top_head_indices.numel() == 0:
+            # Fallback - keep first head
+            top_head_indices = torch.tensor([0], device=keys.device, dtype=torch.long)
 
         compressed_data = {
             'keys': {},
 
         # 2. Sequence-wise top-k selection per selected head
         for head_idx in top_head_indices:
+            head_idx_int = head_idx.item()
+
+            # Extract head data safely
+            head_keys = keys[:, head_idx_int:head_idx_int+1, :, :]
+            head_values = values[:, head_idx_int:head_idx_int+1, :, :]
 
             # Compute sequence importance for this head
+            try:
+                seq_importance = (
+                    head_keys.norm(dim=-1).squeeze(1).mean(dim=0) +
+                    head_values.norm(dim=-1).squeeze(1).mean(dim=0)
+                ) / 2.0
+            except Exception as e:
+                logger.error(f"Error computing seq importance for head {head_idx_int}: {e}")
+                seq_importance = torch.ones(seq_len, device=keys.device)
 
             # Apply position-based boost (from research constants)
             position_boost = torch.ones_like(seq_importance)
+            if self.config.sink_tokens > 0:
+                sink_count = min(self.config.sink_tokens, seq_len // 4)
+                position_boost[:sink_count] *= self.constants.POSITION_BOOST_SINK
+            if self.config.recent_window > 0:
+                recent_count = min(self.config.recent_window, seq_len // 2)
+                position_boost[-recent_count:] *= self.constants.POSITION_BOOST_RECENT
+
             boosted_importance = seq_importance * position_boost
 
             # Select top tokens for this head
+            _, top_token_indices = safe_topk(boosted_importance, seq_budget)
+
+            if top_token_indices.numel() == 0:
+                # Fallback - keep last few tokens
+                top_token_indices = torch.arange(max(0, seq_len - seq_budget), seq_len,
+                                                 device=keys.device, dtype=torch.long)
 
             # Store compressed data
+            head_key = f'head_{head_idx_int}'
             compressed_data['keys'][head_key] = {
+                'data': safe_index_select(head_keys, 2, top_token_indices),
                 'indices': top_token_indices.tolist()
             }
             compressed_data['values'][head_key] = {
+                'data': safe_index_select(head_values, 2, top_token_indices),
                 'indices': top_token_indices.tolist()
             }
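The two budgets above compound: only head_budget of the n_heads heads keep any tokens, and each keeps at most seq_budget positions. A rough sketch of the resulting reduction (illustrative numbers only):

    n_heads, seq_len = 32, 4096
    head_budget, seq_budget = 4, 64     # example budgets from the adaptive split

    kept = head_budget * seq_budget     # 256 (head, token) slots survive
    total = n_heads * seq_len           # 131072 slots before Stage 2
    print(total / kept)                 # 512.0x reduction from HSA alone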
 
 
         # Calculate retention based on compression ratio
         retention_ratio = 1.0 / compression_ratio
+        min_retain = max(8, self.config.sink_tokens + self.config.recent_window, self.config.min_tokens_for_stability)
         n_retain = max(min_retain, int(seq_len * retention_ratio))
 
         # Apply layer-specific constraints (from research constants)
 
         else:  # Late layers
             max_retain = int(seq_len * self.constants.LATE_LAYER_MAX_RETENTION)
 
+        n_retain = min(n_retain, max_retain, seq_len)
 
         # Compute magnitude-based importance
         importance_scores = self.compute_magnitude_importance(keys, values)
 
         # Quality preservation: boost recent tokens (explicit formula from config)
         recent_boost = torch.zeros_like(importance_scores)
         if self.config.recent_window > 0:
+            recent_window = min(self.config.recent_window, seq_len // 2)
+            recent_boost[-recent_window:] = importance_scores.max() * self.config.recent_boost_factor
         importance_scores = importance_scores + recent_boost
 
         # Initialize preservation mask
         preserve_mask = torch.zeros(seq_len, dtype=torch.bool, device=keys.device)
+        if self.config.sink_tokens > 0:
+            sink_count = min(self.config.sink_tokens, seq_len // 4)
+            preserve_mask[:sink_count] = True
+        if self.config.recent_window > 0:
+            recent_count = min(self.config.recent_window, seq_len // 2)
+            preserve_mask[-recent_count:] = True
 
         # Select additional tokens based on importance
         remaining_slots = n_retain - preserve_mask.sum().item()
 
         available = (masked_importance > -float('inf')).sum().item()
         k = min(remaining_slots, available)
         if k > 0:
+            _, top_indices = safe_topk(masked_importance, k)
+            if top_indices.numel() > 0:
+                preserve_mask[top_indices] = True
 
         # Extract retained tokens
+        retained_indices = preserve_mask.nonzero(as_tuple=True)[0]
+
+        if retained_indices.numel() == 0:
+            logger.error(f"No tokens retained in stage 1 layer {layer_idx}! Using fallback")
+            min_keep = max(8, self.config.min_tokens_for_stability)
+            retained_indices = torch.arange(seq_len - min_keep, seq_len, device=keys.device, dtype=torch.long)
 
+        keys_stage1 = safe_index_select(keys, 2, retained_indices)
+        values_stage1 = safe_index_select(values, 2, retained_indices)
+
+        actual_ratio = seq_len / len(retained_indices) if len(retained_indices) > 0 else 1.0
         logger.debug(f"Stage 1 Layer {layer_idx}: {seq_len} → {len(retained_indices)} tokens ({actual_ratio:.1f}x)")
 
         return keys_stage1, values_stage1, retained_indices.tolist()
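For Stage 1, the retention formula works out as follows (a sketch with assumed config values: sink_tokens=4, recent_window=48, min_tokens_for_stability=16, and a late-layer cap of 10%):

    seq_len, compression_ratio = 2048, 20.0
    min_retain = max(8, 4 + 48, 16)                                       # = 52
    n_retain = max(min_retain, int(seq_len * (1.0 / compression_ratio)))  # max(52, 102) = 102
    max_retain = int(seq_len * 0.10)                                      # late-layer cap -> 204
    n_retain = min(n_retain, max_retain, seq_len)                         # 102 tokens retained (~20x)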
 
         if self.use_hybrid_sparse_attention:
             # RocketKV-style compression with adaptive budgets
+            try:
+                sparsity = self.estimate_attention_sparsity(keys, values)
+            except Exception as e:
+                logger.warning(f"Sparsity estimation failed ({e}); using default")
+                sparsity = 0.5  # Default if estimation fails
 
             if self.use_adaptive_decomposition:
                 _, stage2_ratio = self.adaptive_stage_split(
 
             values.float().pow(2).sum(dim=(-1, -2)).sum(dim=0)
         )
 
+        _, important_head_indices = safe_topk(head_importance, n_important_heads)
+
+        if important_head_indices.numel() == 0:
+            important_head_indices = torch.tensor([0], device=keys.device, dtype=torch.long)
+
         other_head_indices = torch.tensor(
             [h for h in range(n_heads) if h not in important_head_indices.tolist()],
             device=keys.device, dtype=torch.long
 
             # Store important heads at full precision
             compressed_data['keys']['heads_fp16'] = {
+                'data': safe_index_select(keys, 1, important_head_indices).clone(),
                 'indices': important_head_indices.tolist()
             }
             compressed_data['values']['heads_fp16'] = {
+                'data': safe_index_select(values, 1, important_head_indices).clone(),
                 'indices': important_head_indices.tolist()
             }
 
             if other_head_indices.numel() == 0:
                 return compressed_data
 
+            seq_keys = safe_index_select(keys, 1, other_head_indices)
+            seq_values = safe_index_select(values, 1, other_head_indices)
         else:
             seq_keys = keys
             seq_values = values
611
 
612
  # Explicit top-K selection for FP16
613
  keep_fp16 = max(0, int(seq_len * self.config.sequence_compression_ratio))
 
 
614
  if keep_fp16 > 0:
615
+ top_fp16, _ = safe_topk(combined_importance, k=keep_fp16)
616
+ is_fp16 = torch.zeros(seq_len, dtype=torch.bool, device=keys.device)
617
+ if top_fp16.numel() > 0:
618
+ is_fp16[top_fp16] = True
619
+ else:
620
+ is_fp16 = torch.zeros(seq_len, dtype=torch.bool, device=keys.device)
621
 
622
  # Vectorized token binning
623
  thresh = torch.tensor([pl.threshold for pl in levels], device=keys.device)
 
664
  continue
665
 
666
  idx_tensor = torch.tensor(indices, device=keys.device, dtype=torch.long)
667
+ k_slice = safe_index_select(seq_keys, 2, idx_tensor)
668
+ v_slice = safe_index_select(seq_values, 2, idx_tensor)
669
 
670
  # Store with aggressive precision - only FP16 for ultra-selective tokens
671
  compressed_data['keys'][precision_key]['data'] = k_slice.clone()
 
711
 
712
  except Exception as e:
713
  logger.error(f"Error in enhanced compression for layer {layer_idx}: {e}")
714
+ # Fallback to original SPG on error
715
+ return self._fallback_to_original_spg(keys, values, layer_idx, current_position)
716
 
717
  def _fallback_to_original_spg(self, keys: torch.Tensor, values: torch.Tensor,
718
  layer_idx: int, current_position: Optional[int]) -> Dict[str, Any]:
 
791
  continue
792
 
793
  level_indices = torch.tensor(indices, device=device, dtype=torch.long)
794
+ k_slice = safe_index_select(keys, 2, level_indices)
795
+ v_slice = safe_index_select(values, 2, level_indices)
796
 
797
  # Store with FP16 precision (simplified for original SPG)
798
  compressed_data['keys'][precision_key]['data'] = k_slice.clone()
 
873
  if 'heads_fp16' in compressed_data['keys']:
874
  head_indices = compressed_data['keys']['heads_fp16']['indices']
875
  head_idx_tensor = torch.tensor(head_indices, device=device, dtype=torch.long)
876
+
877
+ # Safe assignment
878
+ head_data_k = compressed_data['keys']['heads_fp16']['data']
879
+ head_data_v = compressed_data['values']['heads_fp16']['data']
880
+
881
+ if head_data_k is not None and head_data_v is not None:
882
+ for i, idx in enumerate(head_indices):
883
+ if idx < keys_full.shape[1]:
884
+ keys_full[:, idx, :, :] = head_data_k[:, i, :, :]
885
+ values_full[:, idx, :, :] = head_data_v[:, i, :, :]
886
 
887
  if self.config.enable_head_compression:
888
  n_heads = original_shape[1]
 
899
  continue
900
 
901
  indices = compressed_data['keys'][precision_key]['indices']
902
+ if not indices:
903
+ continue
904
+
905
  idx_tensor = torch.tensor(indices, device=device, dtype=torch.long)
906
 
907
  # All data stored as FP16 in this simplified version
908
+ k_data = compressed_data['keys'][precision_key]['data']
909
+ v_data = compressed_data['values'][precision_key]['data']
910
+
911
+ if k_data is not None and v_data is not None:
912
+ for head_idx in other_head_indices:
913
+ if head_idx < keys_full.shape[1]:
914
+ for i, seq_idx in enumerate(indices):
915
+ if seq_idx < keys_full.shape[2]:
916
+ keys_full[:, head_idx, seq_idx, :] = k_data[:, :, i, :].squeeze(1)
917
+ values_full[:, head_idx, seq_idx, :] = v_data[:, :, i, :].squeeze(1)
918
 
919
  return keys_full, values_full
920
 
 
             token_indices = head_data_k['indices']
 
             # Place data in the correct head and token positions
+            if head_idx < keys_full.shape[1]:
+                for i, token_idx in enumerate(token_indices):
+                    if token_idx < keys_full.shape[2]:
+                        keys_full[:, head_idx, token_idx, :] = head_data_k['data'][:, 0, i, :]
+                        values_full[:, head_idx, token_idx, :] = head_data_v['data'][:, 0, i, :]
 
         return keys_full, values_full
 
 
             data_dict = compressed_data['keys'][precision_key]
             if 'data' in data_dict and 'indices' in data_dict:
                 indices = data_dict['indices']
+                if not indices:
+                    continue
+
                 idx_tensor = torch.tensor(indices, device=device, dtype=torch.long)
 
                 # All data stored as original precision
+                k_data = data_dict['data']
+                v_data = compressed_data['values'][precision_key]['data']
+
+                if k_data is not None and v_data is not None:
+                    for i, seq_idx in enumerate(indices):
+                        if seq_idx < keys_full.shape[2]:
+                            keys_full[:, :, seq_idx, :] = k_data[:, :, i, :]
+                            values_full[:, :, seq_idx, :] = v_data[:, :, i, :]
 
         return keys_full, values_full
987