kfoughali committed
Commit 9c6e956 · verified · 1 Parent(s): b3bb89e

Update benchmark.py

Files changed (1)
  1. benchmark.py +424 -191
benchmark.py CHANGED
@@ -2,6 +2,7 @@
2
  Benchmarking, metrics, and proof generation for Enhanced SPG.
3
  Supports LongBench, NIAH, RULER, SCBench benchmarks.
4
  MEASURED VALUES ONLY - no estimations. FAIL FAST on errors.
 
5
  """
6
 
7
  import torch
@@ -234,6 +235,113 @@ class BenchmarkMetrics:
234
 
235
  return (0.0, 0.0)
236
237
  def create_niah_haystack(context_length: int, needle: str, depth_percent: float) -> str:
238
  """Create Needle-in-a-Haystack test context - NO HARDCODING."""
239
  # Generate haystack text
@@ -255,8 +363,9 @@ def create_niah_haystack(context_length: int, needle: str, depth_percent: float)
255
 
256
  return haystack_with_needle
257
 
258
- def evaluate_niah(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> float:
259
- """Evaluate Needle-in-a-Haystack performance - MEASURED ONLY."""
 
260
  context = create_niah_haystack(
261
  config.prefill_length,
262
  config.niah_needle,
@@ -267,46 +376,32 @@ def evaluate_niah(model, tokenizer, config: CompressionConfig, cache_manager: Op
267
 
268
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=config.prefill_length)
269
  input_ids = inputs.input_ids.to(model.device)
270
271
  with torch.inference_mode():
272
- if cache_manager:
273
- # Compress KV cache
274
- outputs = model(input_ids, use_cache=True, return_dict=True)
275
- past_key_values = outputs.past_key_values
276
-
277
- # Store compressed
278
- kv_tuple = past_key_values.to_legacy_cache() if hasattr(past_key_values, 'to_legacy_cache') else past_key_values
279
- for layer_idx, (keys, values) in enumerate(kv_tuple):
280
- cache_manager.compress_and_store(layer_idx, keys, values)
281
-
282
- # Reconstruct for generation
283
- reconstructed_kv = []
284
- for layer_idx in range(len(kv_tuple)):
285
- dec_keys, dec_values = cache_manager.get_decompressed(layer_idx)
286
- if dec_keys is not None and dec_values is not None:
287
- reconstructed_kv.append((dec_keys, dec_values))
288
-
289
- if hasattr(DynamicCache, 'from_legacy_cache'):
290
- past_key_values = DynamicCache.from_legacy_cache(tuple(reconstructed_kv))
291
- else:
292
- past_key_values = tuple(reconstructed_kv)
293
-
294
- # Generate with compressed cache
295
- output = model.generate(
296
- input_ids,
297
- past_key_values=past_key_values,
298
- max_new_tokens=20,
299
- temperature=0.0,
300
- do_sample=False
301
- )
302
- else:
303
- # Generate without compression
304
- output = model.generate(
305
- input_ids,
306
- max_new_tokens=20,
307
- temperature=0.0,
308
- do_sample=False
309
- )
310
 
311
  generated_text = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
312
 
@@ -314,59 +409,20 @@ def evaluate_niah(model, tokenizer, config: CompressionConfig, cache_manager: Op
314
  accuracy = 1.0 if config.niah_needle.split()[-1] in generated_text else 0.0
315
 
316
  logger.info(f"NIAH accuracy: {accuracy}, Generated: {generated_text[:50]}")
317
- return accuracy
318
 
319
- def evaluate_longbench_task(model, tokenizer, config: CompressionConfig,
320
- task: str, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, float]:
321
- """Evaluate LongBench task - MEASURED METRICS ONLY."""
322
- try:
323
- dataset = load_dataset("THUDM/LongBench", task, split="test")
324
-
325
- # Sample evaluation examples
326
- n_samples = min(config.eval_samples, len(dataset))
327
- samples = dataset.select(range(n_samples))
328
-
329
- scores = []
330
- for sample in samples:
331
- context = sample.get("context", "")
332
- question = sample.get("input", sample.get("question", ""))
333
- answer = sample.get("answers", [sample.get("answer", "")])
334
-
335
- if isinstance(answer, list) and answer:
336
- answer = answer[0]
337
-
338
- prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
339
-
340
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
341
- max_length=config.prefill_length)
342
- input_ids = inputs.input_ids.to(model.device)
343
-
344
- with torch.inference_mode():
345
- output = model.generate(
346
- input_ids,
347
- max_new_tokens=50,
348
- temperature=0.0,
349
- do_sample=False
350
- )
351
-
352
- generated = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
353
-
354
- # Simple accuracy metric - check if answer appears in generation
355
- score = 1.0 if str(answer).lower() in generated.lower() else 0.0
356
- scores.append(score)
357
-
358
- return {
359
- "accuracy": float(np.mean(scores)),
360
- "n_samples": n_samples
361
- }
362
-
363
- except Exception as e:
364
- logger.error(f"Error evaluating LongBench task {task}: {e}")
365
- return {"accuracy": 0.0, "n_samples": 0}
366
 
367
- def evaluate_ruler(model, tokenizer, config: CompressionConfig,
368
- cache_manager: Optional[QuantizedKVCache] = None) -> float:
369
- """Evaluate RULER benchmark - MEASURED ONLY."""
370
  # Create synthetic RULER-like task
371
  seq_len = min(config.ruler_max_seq_length, config.prefill_length)
372
 
@@ -383,14 +439,31 @@ def evaluate_ruler(model, tokenizer, config: CompressionConfig,
383
 
384
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=seq_len)
385
  input_ids = inputs.input_ids.to(model.device)
386
387
  with torch.inference_mode():
388
  output = model.generate(
389
  input_ids,
 
390
  max_new_tokens=10,
391
  temperature=0.0,
392
- do_sample=False
 
393
  )
394
 
395
  generated = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
396
 
@@ -399,11 +472,20 @@ def evaluate_ruler(model, tokenizer, config: CompressionConfig,
399
  exact_match = 1.0 if expected in generated else 0.0
400
 
401
  logger.info(f"RULER exact match: {exact_match}, Generated: {generated[:50]}")
402
- return exact_match
403
 
404
- def evaluate_scbench(model, tokenizer, config: CompressionConfig,
405
- cache_manager: Optional[QuantizedKVCache] = None) -> float:
406
- """Evaluate SCBench multi-turn conversation - MEASURED ONLY."""
407
  # Create multi-turn conversation
408
  conversation = []
409
  facts = {}
@@ -428,14 +510,31 @@ def evaluate_scbench(model, tokenizer, config: CompressionConfig,
428
  inputs = tokenizer(full_conversation, return_tensors="pt", truncation=True,
429
  max_length=config.prefill_length)
430
  input_ids = inputs.input_ids.to(model.device)
431
432
  with torch.inference_mode():
433
  output = model.generate(
434
  input_ids,
 
435
  max_new_tokens=20,
436
  temperature=0.0,
437
- do_sample=False
 
438
  )
439
 
440
  generated = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
441
 
@@ -444,7 +543,107 @@ def evaluate_scbench(model, tokenizer, config: CompressionConfig,
444
  accuracy = 1.0 if expected_value in generated else 0.0
445
 
446
  logger.info(f"SCBench accuracy: {accuracy}, Generated: {generated[:50]}")
447
- return accuracy
448
 
449
  def load_model_and_tokenizer(model_name: str, config: CompressionConfig):
450
  """Load model and tokenizer with proper configuration - NO HARDCODING."""
@@ -496,11 +695,12 @@ def load_model_and_tokenizer(model_name: str, config: CompressionConfig):
496
 
497
  return model, tokenizer
498
 
 
499
  def load_real_dataset_samples(config: CompressionConfig, tokenizer) -> List[str]:
500
  """Load dataset samples based on benchmark type - NO HARDCODING."""
501
  logger.info(f"Loading samples for benchmark: {config.benchmark_type}")
502
 
503
- if config.benchmark_type == "perplexity":
504
  # Original WikiText loading
505
  texts = []
506
  min_tokens = config.prefill_length + config.generation_length
@@ -568,8 +768,9 @@ def load_real_dataset_samples(config: CompressionConfig, tokenizer) -> List[str]
568
  logger.info(f"Loaded {len(texts)} text samples")
569
  return texts
570
 
 
571
  def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_texts: Optional[List[str]] = None) -> Tuple[BenchmarkMetrics, Dict, List[Dict], List[Dict]]:
572
- """Research-grade benchmark with support for multiple benchmarks."""
573
  logger.info(f"Starting benchmark: {model_name} with {config.compression_type.value}")
574
  logger.info(f"Benchmark type: {config.benchmark_type}")
575
  logger.info(f"Config hash: {config.get_hash()}")
@@ -611,57 +812,117 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
611
 
612
  metrics = BenchmarkMetrics()
613
 
614
- # Run benchmark-specific evaluation
615
  if config.benchmark_type == "niah":
616
- # NIAH evaluation
617
  for depth in BENCHMARK_CONFIGS["niah"]["depths"]:
618
  config.niah_depth_percent = depth
619
  for idx in range(min(config.eval_samples, 10)):
620
  cache_manager = QuantizedKVCache(config)
621
  cache_manager.n_layers = n_layers
622
 
623
- accuracy = evaluate_niah(model, tokenizer, config, cache_manager)
624
- metrics.niah_retrieval_accuracy.append(accuracy)
 
626
- compressed_size = cache_manager.get_memory_footprint()
627
- metrics.kv_cache_memory_samples_mb.append(compressed_size / (1024 * 1024))
 
629
  elif config.benchmark_type == "ruler":
630
- # RULER evaluation
631
  for idx in range(config.eval_samples):
632
  cache_manager = QuantizedKVCache(config)
633
  cache_manager.n_layers = n_layers
634
 
635
- exact_match = evaluate_ruler(model, tokenizer, config, cache_manager)
636
- metrics.ruler_exact_match.append(exact_match)
637
 
638
- compressed_size = cache_manager.get_memory_footprint()
639
- metrics.kv_cache_memory_samples_mb.append(compressed_size / (1024 * 1024))
640
 
641
  elif config.benchmark_type == "scbench":
642
- # SCBench evaluation
643
  for idx in range(config.eval_samples):
644
  cache_manager = QuantizedKVCache(config)
645
  cache_manager.n_layers = n_layers
646
 
647
- accuracy = evaluate_scbench(model, tokenizer, config, cache_manager)
648
- metrics.scbench_turn_accuracy.append(accuracy)
649
 
650
- compressed_size = cache_manager.get_memory_footprint()
651
- metrics.kv_cache_memory_samples_mb.append(compressed_size / (1024 * 1024))
652
 
653
  elif config.benchmark_type == "longbench":
654
- # LongBench evaluation
655
  if config.benchmark_subset:
656
  cache_manager = QuantizedKVCache(config)
657
  cache_manager.n_layers = n_layers
658
 
659
- scores = evaluate_longbench_task(model, tokenizer, config,
660
  config.benchmark_subset, cache_manager)
661
- metrics.longbench_scores.append(scores)
662
 
663
  else:
664
- # Standard perplexity evaluation
665
  for idx in range(config.eval_samples):
666
  logger.info(f"Sample {idx+1}/{config.eval_samples}")
667
 
@@ -682,68 +943,27 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
682
  input_ids = inputs.input_ids.to(device)
683
  attention_mask = inputs.attention_mask.to(device)
684
 
685
- if torch.cuda.is_available():
686
- torch.cuda.empty_cache()
687
- torch.cuda.reset_peak_memory_stats()
688
- torch.cuda.synchronize()
689
-
690
- # Prefill
691
- if torch.cuda.is_available():
692
- torch.cuda.synchronize()
693
- start_time_sample = time.perf_counter()
694
-
695
- with torch.inference_mode():
696
- outputs = model(
697
- input_ids,
698
- attention_mask=attention_mask,
699
- use_cache=True,
700
- return_dict=True
701
- )
702
- past_key_values = outputs.past_key_values
703
-
704
- if torch.cuda.is_available():
705
- torch.cuda.synchronize()
706
-
707
- prefill_time = time.perf_counter() - start_time_sample
708
 
709
- if torch.cuda.is_available():
710
- prefill_peak_mem = _peak_mem_bytes_all_gpus()
711
- metrics.prefill_peak_memories.append(prefill_peak_mem)
712
-
713
- metrics.prefill_times.append(prefill_time)
714
-
715
- # Compression
716
- original_cache_size = 0
717
- if past_key_values:
718
- kv_tuple = past_key_values.to_legacy_cache() if hasattr(past_key_values, 'to_legacy_cache') else past_key_values
719
- for layer_idx, (keys, values) in enumerate(kv_tuple):
720
- original_cache_size += keys.nelement() * keys.element_size()
721
- original_cache_size += values.nelement() * values.element_size()
722
- if config.compression_type != CompressionType.NONE:
723
- cache_manager.compress_and_store(layer_idx, keys, values)
724
-
725
- if config.compression_type != CompressionType.NONE:
726
- reconstructed_kv = []
727
- for layer_idx in range(len(kv_tuple)):
728
- dec_keys, dec_values = cache_manager.get_decompressed(layer_idx)
729
- if dec_keys is not None and dec_values is not None:
730
- reconstructed_kv.append((dec_keys, dec_values))
731
-
732
- if hasattr(DynamicCache, 'from_legacy_cache'):
733
- past_key_values = DynamicCache.from_legacy_cache(tuple(reconstructed_kv))
734
- else:
735
- past_key_values = tuple(reconstructed_kv)
736
 
737
- compressed_size = original_cache_size if config.compression_type == CompressionType.NONE else cache_manager.get_memory_footprint()
738
- comp_ratio = original_cache_size / compressed_size if compressed_size > 0 else 1.0
739
 
740
- metrics.compression_ratios.append(comp_ratio)
741
- metrics.kv_cache_memory_samples_mb.append(compressed_size / (1024 * 1024))
 
742
 
743
- # Generation
744
  generated_ids = input_ids.clone()
745
  decode_times = []
746
  generation_losses = []
 
747
 
748
  for gen_step in range(config.generation_length):
749
  if torch.cuda.is_available():
@@ -778,11 +998,21 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
778
  if generation_losses:
779
  generation_perplexity = np.exp(np.mean(generation_losses))
780
  metrics.generation_perplexities.append(min(generation_perplexity, 1000))
781
 
782
  metrics.calculate_statistics(config)
783
  all_metrics.append(metrics)
784
 
785
- # Aggregate results
786
  final_metrics = BenchmarkMetrics()
787
  for m in all_metrics:
788
  final_metrics.prefill_times.extend(m.prefill_times)
@@ -826,15 +1056,18 @@ def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_t
826
  else:
827
  summary['prefill_perplexity'] = final_metrics.prefill_perplexity_mean
828
  summary['generation_perplexity'] = final_metrics.generation_perplexity_mean
829
- summary['prefill_time_ms'] = final_metrics.prefill_time_mean * 1000
830
- summary['decode_time_ms'] = final_metrics.decode_time_per_token_mean_ms
831
- summary['throughput_tokens_sec'] = final_metrics.decode_tokens_per_sec
832
- summary['end_to_end_throughput'] = final_metrics.end_to_end_throughput
833
- summary['end_to_end_latency_ms'] = final_metrics.end_to_end_latency_ms
834
- summary['peak_memory_mb'] = final_metrics.prefill_peak_memory_mean_mb
 
 
835
 
836
  return final_metrics, summary, per_sample_records, per_layer_fingerprints
837
 
 
838
  def export_proof_bundle(bundle_dir: str, config: CompressionConfig,
839
  metrics: BenchmarkMetrics, summary: Dict[str, Any],
840
  per_sample_records: List[Dict[str, Any]],
@@ -889,6 +1122,7 @@ def export_proof_bundle(bundle_dir: str, config: CompressionConfig,
889
  logger.info(f"Proof bundle exported: {zip_path}")
890
  return zip_path
891
 
 
892
  def verify_proof_bundle(bundle_root: str, config: CompressionConfig, proving: ProvingConfig) -> Dict[str, Any]:
893
  """Verify proof bundle - recompute metrics and check tolerances."""
894
  try:
@@ -924,27 +1158,26 @@ def verify_proof_bundle(bundle_root: str, config: CompressionConfig, proving: Pr
924
  # Verify based on benchmark type
925
  if config.benchmark_type == "niah":
926
  if "niah_accuracy" in summary:
927
- recomputed["niah_accuracy"] = mean_of("niah_accuracy")
928
  elif config.benchmark_type == "ruler":
929
  if "ruler_exact_match" in summary:
930
- recomputed["ruler_exact_match"] = mean_of("ruler_exact_match")
931
  elif config.benchmark_type == "scbench":
932
  if "scbench_accuracy" in summary:
933
- recomputed["scbench_accuracy"] = mean_of("scbench_accuracy")
934
  elif config.benchmark_type == "longbench":
935
  if "longbench_accuracy" in summary:
936
- recomputed["longbench_accuracy"] = mean_of("longbench_accuracy")
937
  elif config.benchmark_type == "wikitext":
938
  # WikiText benchmark metrics
939
- recomputed["compression_ratio"] = mean_of("compression_ratio")
940
- recomputed["kv_cache_memory_mb"] = mean_of("kv_cache_memory_mb")
941
  if "prefill_perplexity" in summary:
942
  recomputed["prefill_perplexity"] = mean_of("prefill_perplexity")
943
  if "generation_perplexity" in summary:
944
  recomputed["generation_perplexity"] = mean_of("generation_perplexity")
945
- else:
946
- recomputed["compression_ratio"] = mean_of("compression_ratio")
947
- recomputed["kv_cache_memory_mb"] = mean_of("kv_cache_memory_mb")
 
948
 
949
  for k, v in recomputed.items():
950
  s = summary.get(k)
 
2
  Benchmarking, metrics, and proof generation for Enhanced SPG.
3
  Supports LongBench, NIAH, RULER, SCBench benchmarks.
4
  MEASURED VALUES ONLY - no estimations. FAIL FAST on errors.
5
+ ALL BENCHMARKS USE SAME COMPRESSION PIPELINE AS WIKITEXT.
6
  """
7
 
8
  import torch
 
235
 
236
  return (0.0, 0.0)
237
 
238
+
239
+ def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
240
+ cache_manager: QuantizedKVCache, config: CompressionConfig,
241
+ measure_memory: bool = True) -> Dict[str, Any]:
242
+ """
243
+ Unified compression pipeline for ALL benchmarks.
244
+ Returns compressed cache, metrics, and reconstructed KV pairs.
245
+ """
246
+ device = input_ids.device
247
+
248
+ # Clear GPU cache if requested
249
+ if torch.cuda.is_available() and measure_memory:
250
+ torch.cuda.empty_cache()
251
+ torch.cuda.reset_peak_memory_stats()
252
+ torch.cuda.synchronize()
253
+
254
+ # Measure prefill time
255
+ if torch.cuda.is_available():
256
+ torch.cuda.synchronize()
257
+ start_time = time.perf_counter()
258
+
259
+ # Prefill phase
260
+ with torch.inference_mode():
261
+ outputs = model(
262
+ input_ids,
263
+ attention_mask=attention_mask,
264
+ use_cache=True,
265
+ return_dict=True
266
+ )
267
+ past_key_values = outputs.past_key_values
268
+ logits = outputs.logits
269
+
270
+ if torch.cuda.is_available():
271
+ torch.cuda.synchronize()
272
+
273
+ prefill_time = time.perf_counter() - start_time
274
+
275
+ # Measure peak memory
276
+ prefill_peak_mem = 0
277
+ if torch.cuda.is_available() and measure_memory:
278
+ prefill_peak_mem = _peak_mem_bytes_all_gpus()
279
+
280
+ # Calculate prefill perplexity if we have logits
281
+ prefill_loss = None
282
+ if logits is not None and input_ids.shape[1] > 1:
283
+ shift_logits = logits[..., :-1, :].contiguous()
284
+ shift_labels = input_ids[..., 1:].contiguous()
285
+ loss = F.cross_entropy(
286
+ shift_logits.view(-1, shift_logits.size(-1)),
287
+ shift_labels.view(-1),
288
+ reduction='mean',
289
+ ignore_index=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100
290
+ )
291
+ prefill_loss = loss.item()
292
+
293
+ # Compression phase - same as WikiText
294
+ original_cache_size = 0
295
+ compressed_cache_size = 0
296
+ compression_ratio = 1.0
297
+
298
+ if past_key_values:
299
+ # Convert to legacy format for processing
300
+ kv_tuple = past_key_values.to_legacy_cache() if hasattr(past_key_values, 'to_legacy_cache') else past_key_values
301
+
302
+ # Calculate original size
303
+ for layer_idx, (keys, values) in enumerate(kv_tuple):
304
+ original_cache_size += keys.nelement() * keys.element_size()
305
+ original_cache_size += values.nelement() * values.element_size()
306
+
307
+ # Apply compression if enabled
308
+ if config.compression_type != CompressionType.NONE:
309
+ cache_manager.compress_and_store(layer_idx, keys, values)
310
+
311
+ # Reconstruct compressed cache
312
+ if config.compression_type != CompressionType.NONE:
313
+ reconstructed_kv = []
314
+ for layer_idx in range(len(kv_tuple)):
315
+ dec_keys, dec_values = cache_manager.get_decompressed(layer_idx)
316
+ if dec_keys is not None and dec_values is not None:
317
+ reconstructed_kv.append((dec_keys, dec_values))
318
+
319
+ # Convert back to DynamicCache format
320
+ if hasattr(DynamicCache, 'from_legacy_cache'):
321
+ past_key_values = DynamicCache.from_legacy_cache(tuple(reconstructed_kv))
322
+ else:
323
+ past_key_values = tuple(reconstructed_kv)
324
+
325
+ # Measure compressed size
326
+ compressed_cache_size = cache_manager.get_memory_footprint()
327
+ else:
328
+ compressed_cache_size = original_cache_size
329
+
330
+ # Calculate compression ratio
331
+ compression_ratio = original_cache_size / compressed_cache_size if compressed_cache_size > 0 else 1.0
332
+
333
+ return {
334
+ 'past_key_values': past_key_values,
335
+ 'prefill_time': prefill_time,
336
+ 'prefill_peak_mem': prefill_peak_mem,
337
+ 'prefill_loss': prefill_loss,
338
+ 'original_cache_size': original_cache_size,
339
+ 'compressed_cache_size': compressed_cache_size,
340
+ 'compression_ratio': compression_ratio,
341
+ 'logits': logits
342
+ }
343
+
344
+
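For reference, a minimal usage sketch of the new unified pipeline outside the benchmark loop. It only uses calls that appear in this diff (QuantizedKVCache, apply_compression_pipeline, model.generate with the reconstructed cache); the helper name run_single_prompt, the prompt argument, and reading the layer count from model.config are illustrative assumptions, not part of the commit.

```python
# Usage sketch (not part of the commit). Assumes `model`, `tokenizer`, and
# `config` come from load_model_and_tokenizer / CompressionConfig in this module.
import torch

def run_single_prompt(model, tokenizer, config, prompt: str):
    inputs = tokenizer(prompt, return_tensors="pt",
                       truncation=True, max_length=config.prefill_length)
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)

    cache_manager = QuantizedKVCache(config)
    cache_manager.n_layers = model.config.num_hidden_layers  # assumption: layer count taken from model config

    # Prefill, compress, and reconstruct the KV cache exactly as the evaluators do.
    result = apply_compression_pipeline(
        model, tokenizer, input_ids, attention_mask, cache_manager, config
    )

    with torch.inference_mode():
        output = model.generate(
            input_ids,
            past_key_values=result['past_key_values'],
            attention_mask=attention_mask,
            max_new_tokens=20,
            do_sample=False,
        )
    text = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
    return text, result['compression_ratio'], result['compressed_cache_size'] / (1024 * 1024)
```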
345
  def create_niah_haystack(context_length: int, needle: str, depth_percent: float) -> str:
346
  """Create Needle-in-a-Haystack test context - NO HARDCODING."""
347
  # Generate haystack text
 
363
 
364
  return haystack_with_needle
365
 
366
+
367
+ def evaluate_niah(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
368
+ """Evaluate NIAH with SAME compression pipeline as WikiText."""
369
  context = create_niah_haystack(
370
  config.prefill_length,
371
  config.niah_needle,
 
376
 
377
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=config.prefill_length)
378
  input_ids = inputs.input_ids.to(model.device)
379
+ attention_mask = inputs.attention_mask.to(model.device)
380
 
381
+ # Apply SAME compression pipeline as WikiText
382
+ compression_result = apply_compression_pipeline(
383
+ model, tokenizer, input_ids, attention_mask, cache_manager, config
384
+ )
385
+
386
+ # Generate with compressed cache
387
  with torch.inference_mode():
388
+ # Measure generation time
389
+ if torch.cuda.is_available():
390
+ torch.cuda.synchronize()
391
+ gen_start = time.perf_counter()
392
+
393
+ output = model.generate(
394
+ input_ids,
395
+ past_key_values=compression_result['past_key_values'],
396
+ max_new_tokens=20,
397
+ temperature=0.0,
398
+ do_sample=False,
399
+ attention_mask=attention_mask
400
+ )
401
+
402
+ if torch.cuda.is_available():
403
+ torch.cuda.synchronize()
404
+ gen_time = time.perf_counter() - gen_start
405
 
406
  generated_text = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
407
 
 
409
  accuracy = 1.0 if config.niah_needle.split()[-1] in generated_text else 0.0
410
 
411
  logger.info(f"NIAH accuracy: {accuracy}, Generated: {generated_text[:50]}")
412
+ logger.info(f"NIAH compression ratio: {compression_result['compression_ratio']:.1f}x")
413
+
414
+ return {
415
+ 'accuracy': accuracy,
416
+ 'compression_ratio': compression_result['compression_ratio'],
417
+ 'kv_cache_memory_mb': compression_result['compressed_cache_size'] / (1024 * 1024),
418
+ 'prefill_time': compression_result['prefill_time'],
419
+ 'generation_time': gen_time,
420
+ 'prefill_peak_mem': compression_result['prefill_peak_mem']
421
+ }
422
423
 
424
+ def evaluate_ruler(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
425
+ """Evaluate RULER with SAME compression pipeline as WikiText."""
 
426
  # Create synthetic RULER-like task
427
  seq_len = min(config.ruler_max_seq_length, config.prefill_length)
428
 
 
439
 
440
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=seq_len)
441
  input_ids = inputs.input_ids.to(model.device)
442
+ attention_mask = inputs.attention_mask.to(model.device)
443
 
444
+ # Apply SAME compression pipeline as WikiText
445
+ compression_result = apply_compression_pipeline(
446
+ model, tokenizer, input_ids, attention_mask, cache_manager, config
447
+ )
448
+
449
+ # Generate with compressed cache
450
  with torch.inference_mode():
451
+ if torch.cuda.is_available():
452
+ torch.cuda.synchronize()
453
+ gen_start = time.perf_counter()
454
+
455
  output = model.generate(
456
  input_ids,
457
+ past_key_values=compression_result['past_key_values'],
458
  max_new_tokens=10,
459
  temperature=0.0,
460
+ do_sample=False,
461
+ attention_mask=attention_mask
462
  )
463
+
464
+ if torch.cuda.is_available():
465
+ torch.cuda.synchronize()
466
+ gen_time = time.perf_counter() - gen_start
467
 
468
  generated = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
469
 
 
472
  exact_match = 1.0 if expected in generated else 0.0
473
 
474
  logger.info(f"RULER exact match: {exact_match}, Generated: {generated[:50]}")
475
+ logger.info(f"RULER compression ratio: {compression_result['compression_ratio']:.1f}x")
476
+
477
+ return {
478
+ 'exact_match': exact_match,
479
+ 'compression_ratio': compression_result['compression_ratio'],
480
+ 'kv_cache_memory_mb': compression_result['compressed_cache_size'] / (1024 * 1024),
481
+ 'prefill_time': compression_result['prefill_time'],
482
+ 'generation_time': gen_time,
483
+ 'prefill_peak_mem': compression_result['prefill_peak_mem']
484
+ }
485
+
486
 
487
+ def evaluate_scbench(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
488
+ """Evaluate SCBench with SAME compression pipeline as WikiText."""
 
489
  # Create multi-turn conversation
490
  conversation = []
491
  facts = {}
 
510
  inputs = tokenizer(full_conversation, return_tensors="pt", truncation=True,
511
  max_length=config.prefill_length)
512
  input_ids = inputs.input_ids.to(model.device)
513
+ attention_mask = inputs.attention_mask.to(model.device)
514
 
515
+ # Apply SAME compression pipeline as WikiText
516
+ compression_result = apply_compression_pipeline(
517
+ model, tokenizer, input_ids, attention_mask, cache_manager, config
518
+ )
519
+
520
+ # Generate with compressed cache
521
  with torch.inference_mode():
522
+ if torch.cuda.is_available():
523
+ torch.cuda.synchronize()
524
+ gen_start = time.perf_counter()
525
+
526
  output = model.generate(
527
  input_ids,
528
+ past_key_values=compression_result['past_key_values'],
529
  max_new_tokens=20,
530
  temperature=0.0,
531
+ do_sample=False,
532
+ attention_mask=attention_mask
533
  )
534
+
535
+ if torch.cuda.is_available():
536
+ torch.cuda.synchronize()
537
+ gen_time = time.perf_counter() - gen_start
538
 
539
  generated = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
540
 
 
543
  accuracy = 1.0 if expected_value in generated else 0.0
544
 
545
  logger.info(f"SCBench accuracy: {accuracy}, Generated: {generated[:50]}")
546
+ logger.info(f"SCBench compression ratio: {compression_result['compression_ratio']:.1f}x")
547
+
548
+ return {
549
+ 'accuracy': accuracy,
550
+ 'compression_ratio': compression_result['compression_ratio'],
551
+ 'kv_cache_memory_mb': compression_result['compressed_cache_size'] / (1024 * 1024),
552
+ 'prefill_time': compression_result['prefill_time'],
553
+ 'generation_time': gen_time,
554
+ 'prefill_peak_mem': compression_result['prefill_peak_mem']
555
+ }
556
+
557
+
558
+ def evaluate_longbench_task(model, tokenizer, config: CompressionConfig,
559
+ task: str, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
560
+ """Evaluate LongBench with SAME compression pipeline as WikiText."""
561
+ try:
562
+ dataset = load_dataset("THUDM/LongBench", task, split="test")
563
+
564
+ # Sample evaluation examples
565
+ n_samples = min(config.eval_samples, len(dataset))
566
+ samples = dataset.select(range(n_samples))
567
+
568
+ scores = []
569
+ compression_ratios = []
570
+ kv_memories = []
571
+ prefill_times = []
572
+ gen_times = []
573
+
574
+ for sample in samples:
575
+ context = sample.get("context", "")
576
+ question = sample.get("input", sample.get("question", ""))
577
+ answer = sample.get("answers", [sample.get("answer", "")])
578
+
579
+ if isinstance(answer, list) and answer:
580
+ answer = answer[0]
581
+
582
+ prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
583
+
584
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
585
+ max_length=config.prefill_length)
586
+ input_ids = inputs.input_ids.to(model.device)
587
+ attention_mask = inputs.attention_mask.to(model.device)
588
+
589
+ # Apply SAME compression pipeline as WikiText
590
+ compression_result = apply_compression_pipeline(
591
+ model, tokenizer, input_ids, attention_mask, cache_manager, config,
592
+ measure_memory=False # Don't measure memory for each sample
593
+ )
594
+
595
+ # Generate with compressed cache
596
+ with torch.inference_mode():
597
+ if torch.cuda.is_available():
598
+ torch.cuda.synchronize()
599
+ gen_start = time.perf_counter()
600
+
601
+ output = model.generate(
602
+ input_ids,
603
+ past_key_values=compression_result['past_key_values'],
604
+ max_new_tokens=50,
605
+ temperature=0.0,
606
+ do_sample=False,
607
+ attention_mask=attention_mask
608
+ )
609
+
610
+ if torch.cuda.is_available():
611
+ torch.cuda.synchronize()
612
+ gen_time = time.perf_counter() - gen_start
613
+
614
+ generated = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
615
+
616
+ # Simple accuracy metric
617
+ score = 1.0 if str(answer).lower() in generated.lower() else 0.0
618
+ scores.append(score)
619
+ compression_ratios.append(compression_result['compression_ratio'])
620
+ kv_memories.append(compression_result['compressed_cache_size'] / (1024 * 1024))
621
+ prefill_times.append(compression_result['prefill_time'])
622
+ gen_times.append(gen_time)
623
+
624
+ avg_compression = float(np.mean(compression_ratios)) if compression_ratios else 1.0
625
+ logger.info(f"LongBench {task} avg compression: {avg_compression:.1f}x")
626
+
627
+ return {
628
+ 'accuracy': float(np.mean(scores)),
629
+ 'n_samples': n_samples,
630
+ 'compression_ratio': avg_compression,
631
+ 'kv_cache_memory_mb': float(np.mean(kv_memories)) if kv_memories else 0.0,
632
+ 'prefill_time': float(np.mean(prefill_times)) if prefill_times else 0.0,
633
+ 'generation_time': float(np.mean(gen_times)) if gen_times else 0.0
634
+ }
635
+
636
+ except Exception as e:
637
+ logger.error(f"Error evaluating LongBench task {task}: {e}")
638
+ return {
639
+ 'accuracy': 0.0,
640
+ 'n_samples': 0,
641
+ 'compression_ratio': 1.0,
642
+ 'kv_cache_memory_mb': 0.0,
643
+ 'prefill_time': 0.0,
644
+ 'generation_time': 0.0
645
+ }
646
+
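All four evaluators now return the same flat dictionary, which is what lets run_research_benchmark below append metrics uniformly. A hypothetical schema (not part of the diff) summarising that shared shape:

```python
# Hypothetical summary of the result dict returned by evaluate_niah / evaluate_ruler /
# evaluate_scbench / evaluate_longbench_task; the TypedDict itself is not in the diff.
from typing import TypedDict

class EvalResult(TypedDict, total=False):
    accuracy: float            # NIAH / SCBench / LongBench
    exact_match: float         # RULER
    n_samples: int             # LongBench only
    compression_ratio: float   # original KV bytes / compressed KV bytes
    kv_cache_memory_mb: float  # compressed cache footprint in MB
    prefill_time: float        # seconds
    generation_time: float     # seconds spent in generate()
    prefill_peak_mem: int      # bytes; 0 when CUDA memory is not measured
```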
647
 
648
  def load_model_and_tokenizer(model_name: str, config: CompressionConfig):
649
  """Load model and tokenizer with proper configuration - NO HARDCODING."""
 
695
 
696
  return model, tokenizer
697
 
698
+
699
  def load_real_dataset_samples(config: CompressionConfig, tokenizer) -> List[str]:
700
  """Load dataset samples based on benchmark type - NO HARDCODING."""
701
  logger.info(f"Loading samples for benchmark: {config.benchmark_type}")
702
 
703
+ if config.benchmark_type == "wikitext":
704
  # Original WikiText loading
705
  texts = []
706
  min_tokens = config.prefill_length + config.generation_length
 
768
  logger.info(f"Loaded {len(texts)} text samples")
769
  return texts
770
 
771
+
772
  def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_texts: Optional[List[str]] = None) -> Tuple[BenchmarkMetrics, Dict, List[Dict], List[Dict]]:
773
+ """Research-grade benchmark with UNIFIED compression for ALL benchmarks."""
774
  logger.info(f"Starting benchmark: {model_name} with {config.compression_type.value}")
775
  logger.info(f"Benchmark type: {config.benchmark_type}")
776
  logger.info(f"Config hash: {config.get_hash()}")
 
812
 
813
  metrics = BenchmarkMetrics()
814
 
815
+ # Run benchmark-specific evaluation with UNIFIED compression
816
  if config.benchmark_type == "niah":
817
+ # NIAH evaluation with unified compression
818
  for depth in BENCHMARK_CONFIGS["niah"]["depths"]:
819
  config.niah_depth_percent = depth
820
  for idx in range(min(config.eval_samples, 10)):
821
  cache_manager = QuantizedKVCache(config)
822
  cache_manager.n_layers = n_layers
823
 
824
+ result = evaluate_niah(model, tokenizer, config, cache_manager)
825
+
826
+ metrics.niah_retrieval_accuracy.append(result['accuracy'])
827
+ metrics.compression_ratios.append(result['compression_ratio'])
828
+ metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
829
+ metrics.prefill_times.append(result['prefill_time'])
830
+ metrics.decode_times.append(result['generation_time'] / 20) # Per token
831
 
832
+ if result['prefill_peak_mem'] > 0:
833
+ metrics.prefill_peak_memories.append(result['prefill_peak_mem'])
834
+
835
+ # Record per-sample data
836
+ per_sample_records.append({
837
+ 'benchmark': 'niah',
838
+ 'depth_percent': depth,
839
+ 'sample_idx': idx,
840
+ 'accuracy': result['accuracy'],
841
+ 'compression_ratio': result['compression_ratio'],
842
+ 'kv_cache_memory_mb': result['kv_cache_memory_mb'],
843
+ 'compression_type': config.compression_type.value
844
+ })
845
 
846
  elif config.benchmark_type == "ruler":
847
+ # RULER evaluation with unified compression
848
  for idx in range(config.eval_samples):
849
  cache_manager = QuantizedKVCache(config)
850
  cache_manager.n_layers = n_layers
851
 
852
+ result = evaluate_ruler(model, tokenizer, config, cache_manager)
853
+
854
+ metrics.ruler_exact_match.append(result['exact_match'])
855
+ metrics.compression_ratios.append(result['compression_ratio'])
856
+ metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
857
+ metrics.prefill_times.append(result['prefill_time'])
858
+ metrics.decode_times.append(result['generation_time'] / 10) # Per token
859
+
860
+ if result['prefill_peak_mem'] > 0:
861
+ metrics.prefill_peak_memories.append(result['prefill_peak_mem'])
862
 
863
+ per_sample_records.append({
864
+ 'benchmark': 'ruler',
865
+ 'sample_idx': idx,
866
+ 'exact_match': result['exact_match'],
867
+ 'compression_ratio': result['compression_ratio'],
868
+ 'kv_cache_memory_mb': result['kv_cache_memory_mb'],
869
+ 'compression_type': config.compression_type.value
870
+ })
871
 
872
  elif config.benchmark_type == "scbench":
873
+ # SCBench evaluation with unified compression
874
  for idx in range(config.eval_samples):
875
  cache_manager = QuantizedKVCache(config)
876
  cache_manager.n_layers = n_layers
877
 
878
+ result = evaluate_scbench(model, tokenizer, config, cache_manager)
 
879
 
880
+ metrics.scbench_turn_accuracy.append(result['accuracy'])
881
+ metrics.compression_ratios.append(result['compression_ratio'])
882
+ metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
883
+ metrics.prefill_times.append(result['prefill_time'])
884
+ metrics.decode_times.append(result['generation_time'] / 20) # Per token
885
+
886
+ if result['prefill_peak_mem'] > 0:
887
+ metrics.prefill_peak_memories.append(result['prefill_peak_mem'])
888
+
889
+ per_sample_records.append({
890
+ 'benchmark': 'scbench',
891
+ 'sample_idx': idx,
892
+ 'accuracy': result['accuracy'],
893
+ 'compression_ratio': result['compression_ratio'],
894
+ 'kv_cache_memory_mb': result['kv_cache_memory_mb'],
895
+ 'compression_type': config.compression_type.value
896
+ })
897
 
898
  elif config.benchmark_type == "longbench":
899
+ # LongBench evaluation with unified compression
900
  if config.benchmark_subset:
901
  cache_manager = QuantizedKVCache(config)
902
  cache_manager.n_layers = n_layers
903
 
904
+ result = evaluate_longbench_task(model, tokenizer, config,
905
  config.benchmark_subset, cache_manager)
906
+
907
+ metrics.longbench_scores.append(result)
908
+ metrics.compression_ratios.append(result['compression_ratio'])
909
+ metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
910
+ metrics.prefill_times.append(result['prefill_time'])
911
+
912
+ if result['generation_time'] > 0:
913
+ metrics.decode_times.append(result['generation_time'] / 50) # Per token
914
+
915
+ per_sample_records.append({
916
+ 'benchmark': 'longbench',
917
+ 'subset': config.benchmark_subset,
918
+ 'accuracy': result['accuracy'],
919
+ 'compression_ratio': result['compression_ratio'],
920
+ 'kv_cache_memory_mb': result['kv_cache_memory_mb'],
921
+ 'compression_type': config.compression_type.value
922
+ })
923
 
924
  else:
925
+ # Standard WikiText perplexity evaluation with existing compression
926
  for idx in range(config.eval_samples):
927
  logger.info(f"Sample {idx+1}/{config.eval_samples}")
928
 
 
943
  input_ids = inputs.input_ids.to(device)
944
  attention_mask = inputs.attention_mask.to(device)
945
 
946
+ # Apply unified compression pipeline
947
+ compression_result = apply_compression_pipeline(
948
+ model, tokenizer, input_ids, attention_mask, cache_manager, config
949
+ )
950
 
951
+ metrics.prefill_times.append(compression_result['prefill_time'])
952
+ metrics.compression_ratios.append(compression_result['compression_ratio'])
953
+ metrics.kv_cache_memory_samples_mb.append(compression_result['compressed_cache_size'] / (1024 * 1024))
954
 
955
+ if compression_result['prefill_peak_mem'] > 0:
956
+ metrics.prefill_peak_memories.append(compression_result['prefill_peak_mem'])
957
 
958
+ if compression_result['prefill_loss'] is not None:
959
+ prefill_perplexity = np.exp(compression_result['prefill_loss'])
960
+ metrics.prefill_perplexities.append(min(prefill_perplexity, 1000))
961
 
962
+ # Generation phase with timing
963
  generated_ids = input_ids.clone()
964
  decode_times = []
965
  generation_losses = []
966
+ past_key_values = compression_result['past_key_values']
967
 
968
  for gen_step in range(config.generation_length):
969
  if torch.cuda.is_available():
 
998
  if generation_losses:
999
  generation_perplexity = np.exp(np.mean(generation_losses))
1000
  metrics.generation_perplexities.append(min(generation_perplexity, 1000))
1001
+
1002
+ per_sample_records.append({
1003
+ 'benchmark': 'wikitext',
1004
+ 'sample_idx': idx,
1005
+ 'prefill_perplexity': metrics.prefill_perplexities[-1] if metrics.prefill_perplexities else None,
1006
+ 'generation_perplexity': metrics.generation_perplexities[-1] if metrics.generation_perplexities else None,
1007
+ 'compression_ratio': compression_result['compression_ratio'],
1008
+ 'kv_cache_memory_mb': compression_result['compressed_cache_size'] / (1024 * 1024),
1009
+ 'compression_type': config.compression_type.value
1010
+ })
1011
 
1012
  metrics.calculate_statistics(config)
1013
  all_metrics.append(metrics)
1014
 
1015
+ # Aggregate results across seeds
1016
  final_metrics = BenchmarkMetrics()
1017
  for m in all_metrics:
1018
  final_metrics.prefill_times.extend(m.prefill_times)
 
1056
  else:
1057
  summary['prefill_perplexity'] = final_metrics.prefill_perplexity_mean
1058
  summary['generation_perplexity'] = final_metrics.generation_perplexity_mean
1059
+
1060
+ # Always add timing and memory metrics
1061
+ summary['prefill_time_ms'] = final_metrics.prefill_time_mean * 1000
1062
+ summary['decode_time_ms'] = final_metrics.decode_time_per_token_mean_ms
1063
+ summary['throughput_tokens_sec'] = final_metrics.decode_tokens_per_sec
1064
+ summary['end_to_end_throughput'] = final_metrics.end_to_end_throughput
1065
+ summary['end_to_end_latency_ms'] = final_metrics.end_to_end_latency_ms
1066
+ summary['peak_memory_mb'] = final_metrics.prefill_peak_memory_mean_mb
1067
 
1068
  return final_metrics, summary, per_sample_records, per_layer_fingerprints
1069
 
1070
+
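Taken together, a hedged end-to-end sketch of how these pieces chain: run a benchmark, export the proof bundle, then verify it. The model name, the constructor keyword arguments, and ProvingConfig's defaults are assumptions; only the function names, the return tuple, and the config attributes come from this diff.

```python
# Illustrative driver (not part of the commit); argument names mirror attributes
# used in the diff, but the actual constructors may differ.
config = CompressionConfig(benchmark_type="niah",
                           compression_type=CompressionType.NONE)  # assumed kwargs
metrics, summary, records, fingerprints = run_research_benchmark("gpt2", config)  # model name is a placeholder

bundle = export_proof_bundle("proof_bundle", config, metrics, summary,
                             records, fingerprints)  # last argument inferred from the return tuple
report = verify_proof_bundle("proof_bundle", config, ProvingConfig())  # assumed default-constructible
print(summary.get("compression_ratio"), report)
```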
1071
  def export_proof_bundle(bundle_dir: str, config: CompressionConfig,
1072
  metrics: BenchmarkMetrics, summary: Dict[str, Any],
1073
  per_sample_records: List[Dict[str, Any]],
 
1122
  logger.info(f"Proof bundle exported: {zip_path}")
1123
  return zip_path
1124
 
1125
+
1126
  def verify_proof_bundle(bundle_root: str, config: CompressionConfig, proving: ProvingConfig) -> Dict[str, Any]:
1127
  """Verify proof bundle - recompute metrics and check tolerances."""
1128
  try:
 
1158
  # Verify based on benchmark type
1159
  if config.benchmark_type == "niah":
1160
  if "niah_accuracy" in summary:
1161
+ recomputed["niah_accuracy"] = mean_of("accuracy")
1162
  elif config.benchmark_type == "ruler":
1163
  if "ruler_exact_match" in summary:
1164
+ recomputed["ruler_exact_match"] = mean_of("exact_match")
1165
  elif config.benchmark_type == "scbench":
1166
  if "scbench_accuracy" in summary:
1167
+ recomputed["scbench_accuracy"] = mean_of("accuracy")
1168
  elif config.benchmark_type == "longbench":
1169
  if "longbench_accuracy" in summary:
1170
+ recomputed["longbench_accuracy"] = mean_of("accuracy")
1171
  elif config.benchmark_type == "wikitext":
1172
  # WikiText benchmark metrics
 
 
1173
  if "prefill_perplexity" in summary:
1174
  recomputed["prefill_perplexity"] = mean_of("prefill_perplexity")
1175
  if "generation_perplexity" in summary:
1176
  recomputed["generation_perplexity"] = mean_of("generation_perplexity")
1177
+
1178
+ # Always verify compression metrics
1179
+ recomputed["compression_ratio"] = mean_of("compression_ratio")
1180
+ recomputed["kv_cache_memory_mb"] = mean_of("kv_cache_memory_mb")
1181
 
1182
  for k, v in recomputed.items():
1183
  s = summary.get(k)
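For orientation, a hedged reconstruction of what the recompute-and-compare step around mean_of amounts to: take the mean of each key over the per-sample records, then check it against the exported summary within some tolerance. Only mean_of and the summary.get(k) pattern appear in the diff; the record handling, the helper names, and the tolerance source (presumably ProvingConfig) are assumptions.

```python
# Conceptual sketch only; not the module's actual implementation.
import numpy as np

def mean_of_records(records, key):
    vals = [r[key] for r in records if r.get(key) is not None]
    return float(np.mean(vals)) if vals else None

def check_tolerances(recomputed, summary, rel_tol=0.05):  # rel_tol stands in for ProvingConfig's setting
    failures = {}
    for k, v in recomputed.items():
        s = summary.get(k)
        if v is None or s is None:
            continue
        if abs(v - s) / max(abs(s), 1e-9) > rel_tol:
            failures[k] = {"summary": s, "recomputed": v}
    return failures
```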