kfoughali committed on
Commit 0eecf30 · verified · 1 Parent(s): 4a8d21d

Update benchmark.py

Files changed (1)
  1. benchmark.py +0 -1195
benchmark.py CHANGED
@@ -1,1195 +0,0 @@
1
- # benchmark.py
2
- """
3
- Benchmarking, metrics, and proof generation for Enhanced SPG.
4
- Supports LongBench, NIAH, RULER, SCBench benchmarks.
5
- MEASURED VALUES ONLY - no estimations. FAIL FAST on errors.
6
- ALL BENCHMARKS USE SAME COMPRESSION PIPELINE AS WIKITEXT.
7
- FIXED: Generation errors, proper fallback handling.
8
- """
9
-
10
- import torch
11
- import torch.nn.functional as F
12
- import numpy as np
13
- from transformers import (
14
- AutoTokenizer, AutoModelForCausalLM,
15
- DynamicCache
16
- )
17
- from datasets import load_dataset
18
- from typing import Tuple, Optional, Dict, Any, List
19
- from dataclasses import dataclass, field
20
- from scipy import stats
21
- import time
22
- import json
23
- import hashlib
24
- import logging
25
- import gc
26
- import os
27
- import sys
28
- import platform
29
- import subprocess
30
- import zipfile
31
- import pathlib
32
- from datetime import datetime
33
- import random
34
-
35
- from config import (
36
- CompressionConfig, CompressionType, ProvingConfig,
37
- ResearchConstants, SUPPORTED_MODELS, BENCHMARK_CONFIGS
38
- )
39
- from compression import QuantizedKVCache, detect_model_layers
40
-
41
- logger = logging.getLogger(__name__)
42
-
43
- def set_seed(seed: int = 42) -> None:
44
- """Set all seeds for reproducibility with explicit validation."""
45
- if not isinstance(seed, int) or seed < 0:
46
- raise ValueError(f"Seed must be non-negative integer, got {seed}")
47
-
48
- random.seed(seed)
49
- np.random.seed(seed)
50
- torch.manual_seed(seed)
51
- if torch.cuda.is_available():
52
- torch.cuda.manual_seed_all(seed)
53
- torch.backends.cudnn.deterministic = True
54
- torch.backends.cudnn.benchmark = False
55
-
56
- logger.info(f"Set all random seeds to {seed}")
57
-
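# Illustrative usage (not part of the original file): set_seed() is meant to be called
# before any model or dataset work so that Python, NumPy and (CU)DA RNGs agree; the
# benchmark re-calls it once per entry in `config.n_seeds`. A minimal sketch:
#
#   set_seed(42)
#   a = torch.randn(4)
#   set_seed(42)
#   b = torch.randn(4)
#   assert torch.equal(a, b)  # same seed -> identical draws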
58
- def _peak_mem_bytes_all_gpus() -> int:
59
- """Get peak memory across all GPUs. FAIL FAST if CUDA unavailable when expected."""
60
- if not torch.cuda.is_available():
61
- raise RuntimeError("CUDA memory tracking requested but CUDA is unavailable")
62
-
63
- torch.cuda.synchronize()
64
- total_mem = sum(torch.cuda.max_memory_allocated(d) for d in range(torch.cuda.device_count()))
65
- logger.debug(f"Peak GPU memory: {total_mem / 1024 / 1024:.1f} MB")
66
- return total_mem
67
-
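# Illustrative sketch (not part of the original file): the benchmark resets the CUDA
# peak-memory counters before a timed forward pass and reads them afterwards, so the
# reported value covers only that pass. `model` and `batch` below are assumed stand-ins.
#
#   torch.cuda.reset_peak_memory_stats()
#   torch.cuda.synchronize()
#   _ = model(**batch)                                   # work being measured
#   peak_mb = _peak_mem_bytes_all_gpus() / (1024 * 1024)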
68
- @dataclass
69
- class BenchmarkMetrics:
70
- """Comprehensive metrics with proper statistical handling - NO ESTIMATES."""
71
- # Prefill metrics
72
- prefill_times: List[float] = field(default_factory=list)
73
- prefill_peak_memories: List[float] = field(default_factory=list)
74
- prefill_time_mean: float = 0.0
75
- prefill_time_std: float = 0.0
76
- prefill_time_ci: Tuple[float, float] = (0.0, 0.0)
77
- prefill_peak_memory_mean_mb: float = 0.0
78
- prefill_peak_memory_std_mb: float = 0.0
79
- prefill_peak_memory_ci_mb: Tuple[float, float] = (0.0, 0.0)
80
- prefill_tokens_per_sec: float = 0.0
81
-
82
- # Decode metrics
83
- decode_times: List[float] = field(default_factory=list)
84
- decode_peak_memories: List[float] = field(default_factory=list)
85
- decode_time_per_token_mean_ms: float = 0.0
86
- decode_time_per_token_std_ms: float = 0.0
87
- decode_time_per_token_ci_ms: Tuple[float, float] = (0.0, 0.0)
88
- decode_time_p50_ms: float = 0.0
89
- decode_time_p95_ms: float = 0.0
90
- decode_peak_memory_mean_mb: float = 0.0
91
- decode_tokens_per_sec: float = 0.0
92
-
93
- # Quality metrics
94
- prefill_perplexities: List[float] = field(default_factory=list)
95
- generation_perplexities: List[float] = field(default_factory=list)
96
- prefill_perplexity_mean: float = 0.0
97
- prefill_perplexity_std: float = 0.0
98
- prefill_perplexity_ci: Tuple[float, float] = (0.0, 0.0)
99
- generation_perplexity_mean: float = 0.0
100
- generation_perplexity_std: float = 0.0
101
- generation_perplexity_ci: Tuple[float, float] = (0.0, 0.0)
102
-
103
- # Benchmark-specific metrics
104
- longbench_scores: List[Dict[str, float]] = field(default_factory=list)
105
- niah_retrieval_accuracy: List[float] = field(default_factory=list)
106
- ruler_exact_match: List[float] = field(default_factory=list)
107
- scbench_turn_accuracy: List[float] = field(default_factory=list)
108
-
109
- # Compression metrics (MEASURED ONLY - no estimates)
110
- compression_ratios: List[float] = field(default_factory=list)
111
- compression_ratio_mean: float = 0.0
112
- compression_ratio_std: float = 0.0
113
- kv_cache_memory_mb: float = 0.0
114
- kv_cache_memory_samples_mb: List[float] = field(default_factory=list)
115
-
116
- # Enhanced SPG metrics (MEASURED ONLY)
117
- enhanced_spg_measured_compression: List[float] = field(default_factory=list)
118
- enhanced_spg_measured_auxiliary_overhead_mb: List[float] = field(default_factory=list)
119
- enhanced_spg_progressive_steps: List[int] = field(default_factory=list)
120
-
121
- # Original SPG metrics
122
- spg_precision_distributions: List[Dict[str, float]] = field(default_factory=list)
123
- spg_effective_bits_per_token: List[float] = field(default_factory=list)
124
- spg_decay_rates_per_layer: List[List[float]] = field(default_factory=list)
125
-
126
- # Statistical comparisons
127
- memory_reduction_ratio: float = 1.0
128
- memory_reduction_pvalue: float = 1.0
129
- speedup_ratio: float = 1.0
130
- speedup_pvalue: float = 1.0
131
- prefill_perplexity_delta: float = 0.0
132
- generation_perplexity_delta: float = 0.0
133
- perplexity_pvalue: float = 1.0
134
-
135
- # End-to-end metrics
136
- end_to_end_throughput: float = 0.0
137
- end_to_end_latency_ms: float = 0.0
138
-
139
- def calculate_statistics(self, config: CompressionConfig) -> None:
140
- """Calculate all statistics with proper error handling."""
141
- try:
142
- if self.prefill_times:
143
- self.prefill_time_mean = float(np.mean(self.prefill_times))
144
- self.prefill_time_std = float(np.std(self.prefill_times))
145
- self.prefill_time_ci = self._bootstrap_ci(self.prefill_times, config)
146
- self.prefill_tokens_per_sec = config.prefill_length / self.prefill_time_mean if self.prefill_time_mean > 0 else 0.0
147
-
148
- if self.prefill_peak_memories:
149
- memories_mb = [m / (1024 * 1024) for m in self.prefill_peak_memories]
150
- self.prefill_peak_memory_mean_mb = float(np.mean(memories_mb))
151
- self.prefill_peak_memory_std_mb = float(np.std(memories_mb))
152
- self.prefill_peak_memory_ci_mb = self._bootstrap_ci(memories_mb, config)
153
-
154
- if self.decode_times:
155
- self.decode_time_per_token_mean_ms = float(np.mean(self.decode_times) * 1000)
156
- self.decode_time_per_token_std_ms = float(np.std(self.decode_times) * 1000)
157
- self.decode_time_per_token_ci_ms = tuple(x * 1000 for x in self._bootstrap_ci(self.decode_times, config))
158
- self.decode_tokens_per_sec = 1.0 / np.mean(self.decode_times) if self.decode_times else 0.0
159
- self.decode_time_p50_ms = float(np.percentile(self.decode_times, 50) * 1000)
160
- self.decode_time_p95_ms = float(np.percentile(self.decode_times, 95) * 1000)
161
-
162
- # Calculate end-to-end throughput
163
- if self.prefill_time_mean > 0 and self.decode_time_per_token_mean_ms > 0:
164
- total_tokens = config.prefill_length + config.generation_length
165
- total_time_sec = self.prefill_time_mean + (self.decode_time_per_token_mean_ms * config.generation_length / 1000)
166
- self.end_to_end_throughput = total_tokens / total_time_sec if total_time_sec > 0 else 0.0
167
- self.end_to_end_latency_ms = total_time_sec * 1000
168
-
169
- if self.decode_peak_memories:
170
- self.decode_peak_memory_mean_mb = float(np.mean(self.decode_peak_memories) / (1024 * 1024))
171
-
172
- if self.prefill_perplexities:
173
- self.prefill_perplexity_mean = float(np.mean(self.prefill_perplexities))
174
- self.prefill_perplexity_std = float(np.std(self.prefill_perplexities))
175
- self.prefill_perplexity_ci = self._bootstrap_ci(self.prefill_perplexities, config)
176
-
177
- if self.generation_perplexities:
178
- self.generation_perplexity_mean = float(np.mean(self.generation_perplexities))
179
- self.generation_perplexity_std = float(np.std(self.generation_perplexities))
180
- self.generation_perplexity_ci = self._bootstrap_ci(self.generation_perplexities, config)
181
-
182
- if self.compression_ratios:
183
- self.compression_ratio_mean = float(np.mean(self.compression_ratios))
184
- self.compression_ratio_std = float(np.std(self.compression_ratios))
185
-
186
- if self.kv_cache_memory_samples_mb:
187
- self.kv_cache_memory_mb = float(np.mean(self.kv_cache_memory_samples_mb))
188
-
189
- except Exception as e:
190
- logger.error(f"Error calculating statistics: {e}")
191
- raise
192
-
193
- def _bootstrap_ci(self, data: List[float], config: CompressionConfig) -> Tuple[float, float]:
194
- """Calculate bootstrap confidence interval with reproducible RNG."""
195
- if not data or len(data) < 2:
196
- return (0.0, 0.0)
197
-
198
- try:
199
- rng = np.random.default_rng(config.seed)
200
- bootstrap_means = []
201
- data_array = np.array(data)
202
-
203
- for _ in range(config.n_bootstrap):
204
- sample = rng.choice(data_array, size=len(data_array), replace=True)
205
- bootstrap_means.append(float(sample.mean()))
206
-
207
- if bootstrap_means:
208
- alpha = 1 - config.confidence_level
209
- lower = float(np.percentile(bootstrap_means, alpha/2 * 100))
210
- upper = float(np.percentile(bootstrap_means, (1 - alpha/2) * 100))
211
- return (lower, upper)
212
-
213
- except Exception as e:
214
- logger.error(f"Error in bootstrap CI calculation: {e}")
215
- raise
216
-
217
- return (0.0, 0.0)
218
-
219
-
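# Illustrative sketch (not part of the original file) of the percentile bootstrap used
# by `_bootstrap_ci`: resample the observations with replacement, take the mean of each
# resample, and read the interval off the percentiles of those means. `cfg` stands in
# for CompressionConfig; only `seed`, `n_bootstrap` and `confidence_level` are assumed.
#
#   rng = np.random.default_rng(cfg.seed)
#   data = np.array([0.41, 0.39, 0.44, 0.40, 0.42])       # e.g. prefill times in seconds
#   means = [rng.choice(data, size=len(data), replace=True).mean()
#            for _ in range(cfg.n_bootstrap)]
#   alpha = 1 - cfg.confidence_level                      # e.g. 0.05 for a 95% CI
#   ci = (np.percentile(means, 100 * alpha / 2),
#         np.percentile(means, 100 * (1 - alpha / 2)))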
220
- def safe_tokenize(tokenizer, text, max_length=512):
221
- """Safe tokenization with proper padding and truncation."""
222
- if tokenizer.pad_token is None:
223
- tokenizer.pad_token = tokenizer.eos_token
224
-
225
- inputs = tokenizer(
226
- text,
227
- return_tensors="pt",
228
- truncation=True,
229
- max_length=max_length,
230
- padding="max_length",
231
- return_attention_mask=True,
232
- add_special_tokens=True
233
- )
234
-
235
- if inputs.input_ids.shape[1] == 0:
236
- raise ValueError("Tokenization produced empty sequence")
237
-
238
- if inputs.input_ids.shape[1] > max_length:
239
- inputs.input_ids = inputs.input_ids[:, :max_length]
240
- inputs.attention_mask = inputs.attention_mask[:, :max_length]
241
-
242
- return inputs
243
-
244
-
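# Illustrative usage (not part of the original file): safe_tokenize always returns a
# fixed-length, attention-masked batch, so downstream code can rely on a stable shape.
#
#   tok = AutoTokenizer.from_pretrained("gpt2")           # small stand-in tokenizer
#   enc = safe_tokenize(tok, "The quick brown fox.", max_length=32)
#   assert enc.input_ids.shape == (1, 32)
#   assert enc.attention_mask.shape == (1, 32)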
245
- def validate_model_inputs(model, input_ids, attention_mask):
246
- """Validate inputs are compatible with model."""
247
- if hasattr(model.config, 'max_position_embeddings'):
248
- max_pos = model.config.max_position_embeddings
249
- if input_ids.shape[1] > max_pos:
250
- input_ids = input_ids[:, :max_pos]
251
- attention_mask = attention_mask[:, :max_pos]
252
-
253
- if hasattr(model.config, 'n_positions'):
254
- n_pos = model.config.n_positions
255
- if input_ids.shape[1] > n_pos:
256
- input_ids = input_ids[:, :n_pos]
257
- attention_mask = attention_mask[:, :n_pos]
258
-
259
- vocab_size = model.config.vocab_size
260
- if input_ids.max() >= vocab_size:
261
- input_ids = input_ids.clamp(0, vocab_size - 1)
262
-
263
- if input_ids.min() < 0:
264
- input_ids = input_ids.clamp(0, vocab_size - 1)
265
-
266
- return input_ids, attention_mask
267
-
268
-
269
- def safe_generate(model, tokenizer, input_ids, attention_mask, past_key_values=None, max_new_tokens=20):
270
- """Safe generation with proper error handling - returns generated text."""
271
- try:
272
- input_ids, attention_mask = validate_model_inputs(model, input_ids, attention_mask)
273
-
274
- gen_config = {
275
- "max_new_tokens": max_new_tokens,
276
- "do_sample": False,
277
- "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
278
- "eos_token_id": tokenizer.eos_token_id,
279
- "attention_mask": attention_mask,
280
- "use_cache": True
281
- }
282
-
283
- if past_key_values is not None:
284
- gen_config["past_key_values"] = past_key_values
285
-
286
- with torch.no_grad():
287
- output = model.generate(input_ids, **gen_config)
288
-
289
- # Decode only the generated part
290
- generated_ids = output[:, input_ids.shape[1]:]
291
- generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
292
- return generated_text
293
-
294
- except Exception as e:
295
- logger.error(f"Generation failed: {e}")
296
- # Return empty string on failure
297
- return ""
298
-
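# Illustrative usage (not part of the original file): greedy decoding of a short
# continuation. On any failure the helper returns "" rather than raising, so callers
# score an empty answer as incorrect instead of aborting the run. `model` and
# `tokenizer` are assumed to come from load_model_and_tokenizer below.
#
#   enc = safe_tokenize(tokenizer, "Question: 2 + 2 = ?\nAnswer:", max_length=64)
#   text = safe_generate(model, tokenizer,
#                        enc.input_ids.to(model.device),
#                        enc.attention_mask.to(model.device),
#                        max_new_tokens=5)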
299
-
300
- def apply_compression_pipeline(model, tokenizer, input_ids, attention_mask,
301
- cache_manager: QuantizedKVCache, config: CompressionConfig,
302
- measure_memory: bool = True) -> Dict[str, Any]:
303
- """
304
- Unified compression pipeline for ALL benchmarks with safety fixes.
305
- Returns compressed cache, metrics, and reconstructed KV pairs.
306
- """
307
- device = input_ids.device
308
-
309
- input_ids, attention_mask = validate_model_inputs(model, input_ids, attention_mask)
310
-
311
- if torch.cuda.is_available() and measure_memory:
312
- torch.cuda.empty_cache()
313
- torch.cuda.reset_peak_memory_stats()
314
- torch.cuda.synchronize()
315
-
316
- if torch.cuda.is_available():
317
- torch.cuda.synchronize()
318
- start_time = time.perf_counter()
319
-
320
- try:
321
- with torch.inference_mode():
322
- outputs = model(
323
- input_ids,
324
- attention_mask=attention_mask,
325
- use_cache=True,
326
- return_dict=True
327
- )
328
- past_key_values = outputs.past_key_values
329
- logits = outputs.logits
330
- except Exception as e:
331
- logger.error(f"Prefill failed: {e}")
332
- return {
333
- 'past_key_values': None,
334
- 'prefill_time': 0,
335
- 'prefill_peak_mem': 0,
336
- 'prefill_loss': None,
337
- 'original_cache_size': 0,
338
- 'compressed_cache_size': 0,
339
- 'compression_ratio': 1.0,
340
- 'logits': None
341
- }
342
-
343
- if torch.cuda.is_available():
344
- torch.cuda.synchronize()
345
-
346
- prefill_time = time.perf_counter() - start_time
347
-
348
- prefill_peak_mem = 0
349
- if torch.cuda.is_available() and measure_memory:
350
- prefill_peak_mem = _peak_mem_bytes_all_gpus()
351
-
352
- prefill_loss = None
353
- if logits is not None and input_ids.shape[1] > 1:
354
- try:
355
- seq_len = min(logits.shape[1], input_ids.shape[1] - 1)
356
- if seq_len > 0:
357
- shift_logits = logits[:, :seq_len, :].contiguous()
358
- shift_labels = input_ids[:, 1:seq_len+1].contiguous()
359
-
360
- loss = F.cross_entropy(
361
- shift_logits.view(-1, shift_logits.size(-1)),
362
- shift_labels.view(-1),
363
- reduction='mean',
364
- ignore_index=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100
365
- )
366
- prefill_loss = loss.item()
367
- except Exception as e:
368
- logger.warning(f"Could not calculate prefill loss: {e}")
369
-
370
- original_cache_size = 0
371
- compressed_cache_size = 0
372
- compression_ratio = 1.0
373
-
374
- if past_key_values:
375
- try:
376
- if hasattr(past_key_values, 'to_legacy_cache'):
377
- kv_tuple = past_key_values.to_legacy_cache()
378
- else:
379
- kv_tuple = past_key_values
380
-
381
- for layer_idx, (keys, values) in enumerate(kv_tuple):
382
- if keys is not None and values is not None:
383
- original_cache_size += keys.nelement() * keys.element_size()
384
- original_cache_size += values.nelement() * values.element_size()
385
-
386
- if config.compression_type != CompressionType.NONE and cache_manager is not None:
387
- try:
388
- cache_manager.compress_and_store(layer_idx, keys, values)
389
- except Exception as e:
390
- logger.error(f"Compression failed for layer {layer_idx}: {e}")
391
-
392
- if config.compression_type != CompressionType.NONE and cache_manager is not None:
393
- reconstructed_kv = []
394
- for layer_idx in range(len(kv_tuple)):
395
- try:
396
- dec_keys, dec_values = cache_manager.get_decompressed(layer_idx)
397
- if dec_keys is not None and dec_values is not None:
398
- reconstructed_kv.append((dec_keys, dec_values))
399
- else:
400
- reconstructed_kv.append(kv_tuple[layer_idx])
401
- except Exception as e:
402
- logger.error(f"Decompression failed for layer {layer_idx}: {e}")
403
- reconstructed_kv.append(kv_tuple[layer_idx])
404
-
405
- if hasattr(DynamicCache, 'from_legacy_cache'):
406
- past_key_values = DynamicCache.from_legacy_cache(tuple(reconstructed_kv))
407
- else:
408
- past_key_values = tuple(reconstructed_kv)
409
-
410
- try:
411
- compressed_cache_size = cache_manager.get_memory_footprint()
412
- except Exception as e:
- logger.warning(f"Could not measure compressed cache size: {e}")
- compressed_cache_size = original_cache_size
414
- else:
415
- compressed_cache_size = original_cache_size
416
-
417
- if compressed_cache_size > 0:
418
- compression_ratio = original_cache_size / compressed_cache_size
419
-
420
- except Exception as e:
421
- logger.error(f"Cache processing failed: {e}")
422
- compressed_cache_size = original_cache_size
423
- compression_ratio = 1.0
424
-
425
- return {
426
- 'past_key_values': past_key_values,
427
- 'prefill_time': prefill_time,
428
- 'prefill_peak_mem': prefill_peak_mem,
429
- 'prefill_loss': prefill_loss,
430
- 'original_cache_size': original_cache_size,
431
- 'compressed_cache_size': compressed_cache_size,
432
- 'compression_ratio': compression_ratio,
433
- 'logits': logits
434
- }
435
-
436
-
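# Illustrative sketch (not part of the original file): every benchmark below funnels its
# prompt through this same pipeline, so the measured compression ratio is comparable
# across WikiText, NIAH, RULER, SCBench and LongBench. A minimal call, assuming `model`,
# `tokenizer` and a `config` with compression enabled:
#
#   cache_mgr = QuantizedKVCache(config)
#   cache_mgr.n_layers = detect_model_layers(model)
#   enc = safe_tokenize(tokenizer, "Some prompt text.", max_length=256)
#   out = apply_compression_pipeline(model, tokenizer,
#                                    enc.input_ids.to(model.device),
#                                    enc.attention_mask.to(model.device),
#                                    cache_mgr, config)
#   ratio = out['compression_ratio']                       # measured, not estimated
#   kv_mb = out['compressed_cache_size'] / (1024 * 1024)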
437
- def create_niah_haystack(context_length: int, needle: str, depth_percent: float) -> str:
438
- """Create Needle-in-a-Haystack test context - NO HARDCODING."""
439
- haystack_template = "The quick brown fox jumps over the lazy dog. " * 20
440
- haystack_chunks = []
441
-
442
- while len(" ".join(haystack_chunks)) < context_length:
443
- haystack_chunks.append(haystack_template)
444
-
445
- haystack = " ".join(haystack_chunks)[:context_length - len(needle) - 10]
446
-
447
- insertion_point = int(len(haystack) * depth_percent / 100)
448
- haystack_with_needle = (
449
- haystack[:insertion_point] +
450
- " " + needle + " " +
451
- haystack[insertion_point:]
452
- )
453
-
454
- return haystack_with_needle
455
-
456
-
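# Illustrative sketch (not part of the original file): the needle is inserted at a
# character offset proportional to `depth_percent`, so depth 0 places it near the start
# of the context and depth 100 near the end.
#
#   needle = "The secret password is 7421."
#   ctx = create_niah_haystack(context_length=2000, needle=needle, depth_percent=50.0)
#   assert needle in ctx and len(ctx) < 2000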
457
- def evaluate_niah(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
458
- """Evaluate NIAH with SAME compression pipeline as WikiText."""
459
- context = create_niah_haystack(
460
- config.prefill_length,
461
- config.niah_needle,
462
- config.niah_depth_percent
463
- )
464
-
465
- prompt = f"{context}\n\nQuestion: What is the secret password?\nAnswer:"
466
-
467
- inputs = safe_tokenize(tokenizer, prompt, max_length=min(config.prefill_length, 1024))
468
- input_ids = inputs.input_ids.to(model.device)
469
- attention_mask = inputs.attention_mask.to(model.device)
470
-
471
- compression_result = apply_compression_pipeline(
472
- model, tokenizer, input_ids, attention_mask, cache_manager, config
473
- )
474
-
475
- gen_start = time.perf_counter()
476
- generated_text = safe_generate(model, tokenizer, input_ids, attention_mask,
477
- compression_result['past_key_values'], max_new_tokens=20)
478
- gen_time = time.perf_counter() - gen_start
479
-
480
- accuracy = 1.0 if config.niah_needle.split()[-1] in generated_text else 0.0
481
-
482
- logger.info(f"NIAH accuracy: {accuracy}, Generated: {generated_text[:50]}")
483
- logger.info(f"NIAH compression ratio: {compression_result['compression_ratio']:.1f}x")
484
-
485
- return {
486
- 'accuracy': accuracy,
487
- 'compression_ratio': compression_result['compression_ratio'],
488
- 'kv_cache_memory_mb': compression_result['compressed_cache_size'] / (1024 * 1024),
489
- 'prefill_time': compression_result['prefill_time'],
490
- 'generation_time': gen_time,
491
- 'prefill_peak_mem': compression_result['prefill_peak_mem']
492
- }
493
-
494
-
495
- def evaluate_ruler(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
496
- """Evaluate RULER with SAME compression pipeline as WikiText."""
497
- seq_len = min(config.ruler_max_seq_length, config.prefill_length, 1024)
498
-
499
- facts = []
500
- for i in range(10):
501
- facts.append(f"Fact {i}: The capital of Country{i} is City{i}.")
502
-
503
- context = " ".join(facts) * (seq_len // (len(" ".join(facts)) + 1))
504
- context = context[:seq_len - 100]
505
-
506
- query_idx = random.randint(0, 9)
507
- prompt = f"{context}\n\nWhat is the capital of Country{query_idx}?"
508
-
509
- inputs = safe_tokenize(tokenizer, prompt, max_length=seq_len)
510
- input_ids = inputs.input_ids.to(model.device)
511
- attention_mask = inputs.attention_mask.to(model.device)
512
-
513
- compression_result = apply_compression_pipeline(
514
- model, tokenizer, input_ids, attention_mask, cache_manager, config
515
- )
516
-
517
- gen_start = time.perf_counter()
518
- generated = safe_generate(model, tokenizer, input_ids, attention_mask,
519
- compression_result['past_key_values'], max_new_tokens=10)
520
- gen_time = time.perf_counter() - gen_start
521
-
522
- expected = f"City{query_idx}"
523
- exact_match = 1.0 if expected in generated else 0.0
524
-
525
- logger.info(f"RULER exact match: {exact_match}, Generated: {generated[:50]}")
526
- logger.info(f"RULER compression ratio: {compression_result['compression_ratio']:.1f}x")
527
-
528
- return {
529
- 'exact_match': exact_match,
530
- 'compression_ratio': compression_result['compression_ratio'],
531
- 'kv_cache_memory_mb': compression_result['compressed_cache_size'] / (1024 * 1024),
532
- 'prefill_time': compression_result['prefill_time'],
533
- 'generation_time': gen_time,
534
- 'prefill_peak_mem': compression_result['prefill_peak_mem']
535
- }
536
-
537
-
538
- def evaluate_scbench(model, tokenizer, config: CompressionConfig, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
539
- """Evaluate SCBench with SAME compression pipeline as WikiText."""
540
- conversation = []
541
- facts = {}
542
-
543
- for turn in range(config.scbench_num_turns):
544
- fact_key = f"item_{turn}"
545
- fact_value = f"value_{turn}_{random.randint(1000, 9999)}"
546
- facts[fact_key] = fact_value
547
-
548
- user_msg = f"Remember that {fact_key} is {fact_value}."
549
- assistant_msg = f"I'll remember that {fact_key} is {fact_value}."
550
-
551
- conversation.append(f"User: {user_msg}")
552
- conversation.append(f"Assistant: {assistant_msg}")
553
-
554
- query_key = random.choice(list(facts.keys()))
555
- conversation.append(f"User: What is {query_key}?")
556
-
557
- full_conversation = "\n".join(conversation) + "\nAssistant:"
558
-
559
- inputs = safe_tokenize(tokenizer, full_conversation, max_length=min(config.prefill_length, 1024))
560
- input_ids = inputs.input_ids.to(model.device)
561
- attention_mask = inputs.attention_mask.to(model.device)
562
-
563
- compression_result = apply_compression_pipeline(
564
- model, tokenizer, input_ids, attention_mask, cache_manager, config
565
- )
566
-
567
- gen_start = time.perf_counter()
568
- generated = safe_generate(model, tokenizer, input_ids, attention_mask,
569
- compression_result['past_key_values'], max_new_tokens=20)
570
- gen_time = time.perf_counter() - gen_start
571
-
572
- expected_value = facts[query_key]
573
- accuracy = 1.0 if expected_value in generated else 0.0
574
-
575
- logger.info(f"SCBench accuracy: {accuracy}, Generated: {generated[:50]}")
576
- logger.info(f"SCBench compression ratio: {compression_result['compression_ratio']:.1f}x")
577
-
578
- return {
579
- 'accuracy': accuracy,
580
- 'compression_ratio': compression_result['compression_ratio'],
581
- 'kv_cache_memory_mb': compression_result['compressed_cache_size'] / (1024 * 1024),
582
- 'prefill_time': compression_result['prefill_time'],
583
- 'generation_time': gen_time,
584
- 'prefill_peak_mem': compression_result['prefill_peak_mem']
585
- }
586
-
587
-
588
- def evaluate_longbench_task(model, tokenizer, config: CompressionConfig,
589
- task: str, cache_manager: Optional[QuantizedKVCache] = None) -> Dict[str, Any]:
590
- """Evaluate LongBench with SAME compression pipeline as WikiText."""
591
- try:
592
- dataset = load_dataset("THUDM/LongBench", task, split="test")
593
-
594
- n_samples = min(config.eval_samples, len(dataset))
595
- samples = dataset.select(range(n_samples))
596
-
597
- scores = []
598
- compression_ratios = []
599
- kv_memories = []
600
- prefill_times = []
601
- gen_times = []
602
-
603
- for sample in samples:
604
- context = sample.get("context", "")
605
- question = sample.get("input", sample.get("question", ""))
606
- answer = sample.get("answers", [sample.get("answer", "")])
607
-
608
- if isinstance(answer, list) and answer:
609
- answer = answer[0]
610
-
611
- prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
612
-
613
- inputs = safe_tokenize(tokenizer, prompt, max_length=min(config.prefill_length, 1024))
614
- input_ids = inputs.input_ids.to(model.device)
615
- attention_mask = inputs.attention_mask.to(model.device)
616
-
617
- compression_result = apply_compression_pipeline(
618
- model, tokenizer, input_ids, attention_mask, cache_manager, config,
619
- measure_memory=False
620
- )
621
-
622
- gen_start = time.perf_counter()
623
- generated = safe_generate(model, tokenizer, input_ids, attention_mask,
624
- compression_result['past_key_values'], max_new_tokens=50)
625
- gen_time = time.perf_counter() - gen_start
626
-
627
- score = 1.0 if str(answer).lower() in generated.lower() else 0.0
628
- scores.append(score)
629
- compression_ratios.append(compression_result['compression_ratio'])
630
- kv_memories.append(compression_result['compressed_cache_size'] / (1024 * 1024))
631
- prefill_times.append(compression_result['prefill_time'])
632
- gen_times.append(gen_time)
633
-
634
- avg_compression = float(np.mean(compression_ratios)) if compression_ratios else 1.0
635
-
636
- return {
637
- 'accuracy': float(np.mean(scores)),
638
- 'n_samples': n_samples,
639
- 'compression_ratio': avg_compression,
640
- 'kv_cache_memory_mb': float(np.mean(kv_memories)) if kv_memories else 0.0,
641
- 'prefill_time': float(np.mean(prefill_times)) if prefill_times else 0.0,
642
- 'generation_time': float(np.mean(gen_times)) if gen_times else 0.0
643
- }
644
-
645
- except Exception as e:
646
- logger.error(f"Error evaluating LongBench task {task}: {e}")
647
- return {
648
- 'accuracy': 0.0,
649
- 'n_samples': 0,
650
- 'compression_ratio': 1.0,
651
- 'kv_cache_memory_mb': 0.0,
652
- 'prefill_time': 0.0,
653
- 'generation_time': 0.0
654
- }
655
-
656
-
657
- def load_model_and_tokenizer(model_name: str, config: CompressionConfig):
658
- """Load model and tokenizer with proper configuration - NO HARDCODING."""
659
- device = "cuda" if torch.cuda.is_available() else "cpu"
660
- dtype = torch.float16 if device == "cuda" else torch.float32
661
-
662
- if config.fail_on_cpu_fallback and device == "cpu":
663
- raise RuntimeError("CUDA required but unavailable (fail_on_cpu_fallback=True)")
664
-
665
- logger.info(f"Loading model: {model_name}")
666
-
667
- tokenizer = AutoTokenizer.from_pretrained(
668
- model_name,
669
- trust_remote_code=True
670
- )
671
-
672
- if tokenizer.pad_token is None:
673
- tokenizer.pad_token = tokenizer.eos_token
674
-
675
- model_kwargs = {
676
- "torch_dtype": dtype,
677
- "device_map": "auto" if device == "cuda" else None,
678
- "low_cpu_mem_usage": True,
679
- "trust_remote_code": True
680
- }
681
-
682
- if config.use_flash_attention and device == "cuda":
683
- try:
684
- model_kwargs["attn_implementation"] = "flash_attention_2"
685
- model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
686
- logger.info("Successfully loaded with Flash Attention 2")
687
- except Exception as e:
688
- logger.warning(f"Flash Attention not available: {e}")
689
- model_kwargs.pop("attn_implementation", None)
690
- model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
691
- else:
692
- model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
693
-
694
- model.eval()
695
-
696
- return model, tokenizer
697
-
698
-
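# Illustrative usage (not part of the original file): the loader picks fp16 with
# device_map="auto" on CUDA and falls back to fp32 on CPU unless fail_on_cpu_fallback
# is set. `config` is an assumed CompressionConfig instance; "gpt2" is just a small
# stand-in checkpoint.
#
#   model, tokenizer = load_model_and_tokenizer("gpt2", config)
#   assert tokenizer.pad_token is not None
#   assert not model.training            # model.eval() was applied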
699
- def load_real_dataset_samples(config: CompressionConfig, tokenizer) -> List[str]:
700
- """Load dataset samples based on benchmark type - NO HARDCODING."""
701
- logger.info(f"Loading samples for benchmark: {config.benchmark_type}")
702
-
703
- if config.benchmark_type == "wikitext":
704
- texts = []
705
- min_tokens = config.prefill_length + config.generation_length
706
-
707
- try:
708
- for split in [config.dataset_split, "train", "validation"]:
709
- if len(texts) >= config.eval_samples:
710
- break
711
-
712
- try:
713
- dataset = load_dataset(
714
- config.dataset_name,
715
- config.dataset_config,
716
- split=split,
717
- streaming=False
718
- )
719
-
720
- logger.info(f"Trying {split} split with {len(dataset)} samples")
721
-
722
- for item in dataset:
723
- text = item.get('text', '').strip()
724
-
725
- if len(text) > 50:
726
- tokens = tokenizer.encode(text, truncation=False, add_special_tokens=False)
727
-
728
- if len(tokens) >= min(min_tokens, 256):
729
- texts.append(text)
730
- if len(texts) >= config.eval_samples * 3:
731
- break
732
-
733
- except Exception as e:
734
- logger.warning(f"Failed to load {split} split: {e}")
735
- continue
736
-
737
- except Exception as e:
738
- logger.error(f"Failed to load dataset: {e}")
739
- raise
740
-
741
- elif config.benchmark_type == "longbench":
742
- texts = []
743
- if config.benchmark_subset:
744
- try:
745
- dataset = load_dataset("THUDM/LongBench", config.benchmark_subset, split="test")
746
- for item in dataset:
747
- if len(texts) >= config.eval_samples:
748
- break
749
- context = item.get("context", "")
750
- if len(context) > 100:
751
- texts.append(context)
752
- except Exception as e:
753
- logger.error(f"Failed to load LongBench subset {config.benchmark_subset}: {e}")
754
- raise
755
-
756
- elif config.benchmark_type in ["niah", "ruler", "scbench"]:
757
- texts = ["Synthetic benchmark data"] * config.eval_samples
758
-
759
- else:
760
- raise ValueError(f"Unsupported benchmark type: {config.benchmark_type}")
761
-
762
- if len(texts) < config.eval_samples:
763
- logger.warning(f"Only loaded {len(texts)} samples, requested {config.eval_samples}")
764
-
765
- logger.info(f"Loaded {len(texts)} text samples")
766
- return texts
767
-
768
-
769
- def run_research_benchmark(model_name: str, config: CompressionConfig, dataset_texts: Optional[List[str]] = None) -> Tuple[BenchmarkMetrics, Dict, List[Dict], List[Dict]]:
770
- """Research-grade benchmark with UNIFIED compression for ALL benchmarks."""
771
- logger.info(f"Starting benchmark: {model_name} with {config.compression_type.value}")
772
- logger.info(f"Benchmark type: {config.benchmark_type}")
773
- logger.info(f"Config hash: {config.get_hash()}")
774
-
775
- if torch.cuda.is_available():
776
- os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
777
-
778
- constants = ResearchConstants()
779
- start_time = datetime.now().isoformat()
780
- per_sample_records = []
781
- per_layer_fingerprints = []
782
-
783
- model, tokenizer = load_model_and_tokenizer(model_name, config)
784
-
785
- try:
786
- n_layers = detect_model_layers(model)
787
- logger.info(f"Model architecture: {n_layers} transformer layers detected")
788
- except ValueError as e:
789
- logger.error(f"Failed to detect model layers: {e}")
790
- raise
791
-
792
- device = model.device
793
- with torch.inference_mode():
794
- dummy = torch.randint(0, tokenizer.vocab_size, (1, min(config.prefill_length, 128)), device=device)
795
- am = torch.ones_like(dummy)
796
- for _ in range(config.warmup_steps):
797
- _ = model(dummy, attention_mask=am, use_cache=True, return_dict=True)
798
-
799
- if torch.cuda.is_available():
800
- torch.cuda.synchronize()
801
- torch.cuda.reset_peak_memory_stats()
802
-
803
- if dataset_texts is None:
804
- dataset_texts = load_real_dataset_samples(config, tokenizer)
805
-
806
- all_metrics = []
807
-
808
- for seed in range(config.n_seeds):
809
- set_seed(config.seed + seed)
810
- logger.info(f"Running evaluation with seed {config.seed + seed}")
811
-
812
- metrics = BenchmarkMetrics()
813
-
814
- if config.benchmark_type == "niah":
815
- for depth in BENCHMARK_CONFIGS["niah"]["depths"]:
816
- config.niah_depth_percent = depth
817
- for idx in range(min(config.eval_samples, 10)):
818
- if config.compression_type != CompressionType.NONE:
819
- cache_manager = QuantizedKVCache(config)
820
- cache_manager.n_layers = n_layers
821
- else:
822
- cache_manager = None
823
-
824
- result = evaluate_niah(model, tokenizer, config, cache_manager)
825
-
826
- metrics.niah_retrieval_accuracy.append(result['accuracy'])
827
- metrics.compression_ratios.append(result['compression_ratio'])
828
- metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
829
- metrics.prefill_times.append(result['prefill_time'])
830
- metrics.decode_times.append(result['generation_time'] / 20)
831
-
832
- if result['prefill_peak_mem'] > 0:
833
- metrics.prefill_peak_memories.append(result['prefill_peak_mem'])
834
-
835
- per_sample_records.append({
836
- 'benchmark': 'niah',
837
- 'depth_percent': depth,
838
- 'sample_idx': idx,
839
- 'accuracy': result['accuracy'],
840
- 'compression_ratio': result['compression_ratio'],
841
- 'kv_cache_memory_mb': result['kv_cache_memory_mb'],
842
- 'compression_type': config.compression_type.value
843
- })
844
-
845
- elif config.benchmark_type == "ruler":
846
- for idx in range(config.eval_samples):
847
- if config.compression_type != CompressionType.NONE:
848
- cache_manager = QuantizedKVCache(config)
849
- cache_manager.n_layers = n_layers
850
- else:
851
- cache_manager = None
852
-
853
- result = evaluate_ruler(model, tokenizer, config, cache_manager)
854
-
855
- metrics.ruler_exact_match.append(result['exact_match'])
856
- metrics.compression_ratios.append(result['compression_ratio'])
857
- metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
858
- metrics.prefill_times.append(result['prefill_time'])
859
- metrics.decode_times.append(result['generation_time'] / 10)
860
-
861
- if result['prefill_peak_mem'] > 0:
862
- metrics.prefill_peak_memories.append(result['prefill_peak_mem'])
863
-
864
- per_sample_records.append({
865
- 'benchmark': 'ruler',
866
- 'sample_idx': idx,
867
- 'exact_match': result['exact_match'],
868
- 'compression_ratio': result['compression_ratio'],
869
- 'kv_cache_memory_mb': result['kv_cache_memory_mb'],
870
- 'compression_type': config.compression_type.value
871
- })
872
-
873
- elif config.benchmark_type == "scbench":
874
- for idx in range(config.eval_samples):
875
- if config.compression_type != CompressionType.NONE:
876
- cache_manager = QuantizedKVCache(config)
877
- cache_manager.n_layers = n_layers
878
- else:
879
- cache_manager = None
880
-
881
- result = evaluate_scbench(model, tokenizer, config, cache_manager)
882
-
883
- metrics.scbench_turn_accuracy.append(result['accuracy'])
884
- metrics.compression_ratios.append(result['compression_ratio'])
885
- metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
886
- metrics.prefill_times.append(result['prefill_time'])
887
- metrics.decode_times.append(result['generation_time'] / 20)
888
-
889
- if result['prefill_peak_mem'] > 0:
890
- metrics.prefill_peak_memories.append(result['prefill_peak_mem'])
891
-
892
- per_sample_records.append({
893
- 'benchmark': 'scbench',
894
- 'sample_idx': idx,
895
- 'accuracy': result['accuracy'],
896
- 'compression_ratio': result['compression_ratio'],
897
- 'kv_cache_memory_mb': result['kv_cache_memory_mb'],
898
- 'compression_type': config.compression_type.value
899
- })
900
-
901
- elif config.benchmark_type == "longbench":
902
- if config.benchmark_subset:
903
- if config.compression_type != CompressionType.NONE:
904
- cache_manager = QuantizedKVCache(config)
905
- cache_manager.n_layers = n_layers
906
- else:
907
- cache_manager = None
908
-
909
- result = evaluate_longbench_task(model, tokenizer, config,
910
- config.benchmark_subset, cache_manager)
911
-
912
- metrics.longbench_scores.append(result)
913
- metrics.compression_ratios.append(result['compression_ratio'])
914
- metrics.kv_cache_memory_samples_mb.append(result['kv_cache_memory_mb'])
915
- metrics.prefill_times.append(result['prefill_time'])
916
-
917
- if result['generation_time'] > 0:
918
- metrics.decode_times.append(result['generation_time'] / 50)
919
-
920
- per_sample_records.append({
921
- 'benchmark': 'longbench',
922
- 'subset': config.benchmark_subset,
923
- 'accuracy': result['accuracy'],
924
- 'compression_ratio': result['compression_ratio'],
925
- 'kv_cache_memory_mb': result['kv_cache_memory_mb'],
926
- 'compression_type': config.compression_type.value
927
- })
928
-
929
- else:
930
- for idx in range(config.eval_samples):
931
- logger.info(f"Sample {idx+1}/{config.eval_samples}")
932
-
933
- text_idx = (idx + seed * config.eval_samples) % len(dataset_texts)
934
- text = dataset_texts[text_idx]
935
-
936
- if config.compression_type != CompressionType.NONE:
937
- cache_manager = QuantizedKVCache(config)
938
- cache_manager.n_layers = n_layers
939
- cache_manager.update_position(config.prefill_length + idx)
940
- else:
941
- cache_manager = None
942
-
943
- inputs = safe_tokenize(tokenizer, text, max_length=min(config.prefill_length, 1024))
944
- input_ids = inputs.input_ids.to(device)
945
- attention_mask = inputs.attention_mask.to(device)
946
-
947
- compression_result = apply_compression_pipeline(
948
- model, tokenizer, input_ids, attention_mask, cache_manager, config
949
- )
950
-
951
- metrics.prefill_times.append(compression_result['prefill_time'])
952
- metrics.compression_ratios.append(compression_result['compression_ratio'])
953
- metrics.kv_cache_memory_samples_mb.append(compression_result['compressed_cache_size'] / (1024 * 1024))
954
-
955
- if compression_result['prefill_peak_mem'] > 0:
956
- metrics.prefill_peak_memories.append(compression_result['prefill_peak_mem'])
957
-
958
- if compression_result['prefill_loss'] is not None:
959
- prefill_perplexity = np.exp(compression_result['prefill_loss'])
960
- metrics.prefill_perplexities.append(min(prefill_perplexity, 1000))
961
-
962
- generated_ids = input_ids.clone()
963
- decode_times = []
964
- generation_losses = []
965
- past_key_values = compression_result['past_key_values']
966
-
967
- for gen_step in range(config.generation_length):
968
- if torch.cuda.is_available():
969
- torch.cuda.synchronize()
970
- step_start = time.perf_counter()
971
-
972
- with torch.inference_mode():
973
- outputs = model(
974
- generated_ids[:, -1:],
975
- past_key_values=past_key_values,
976
- use_cache=True,
977
- return_dict=True
978
- )
979
- next_token_logits = outputs.logits[:, -1, :]
980
- next_token = torch.argmax(next_token_logits, dim=-1)
981
-
982
- loss = F.cross_entropy(next_token_logits, next_token)
983
- generation_losses.append(loss.item())
984
-
985
- generated_ids = torch.cat([generated_ids, next_token.unsqueeze(-1)], dim=-1)
986
- past_key_values = outputs.past_key_values
987
-
988
- if torch.cuda.is_available():
989
- torch.cuda.synchronize()
990
-
991
- decode_time = time.perf_counter() - step_start
992
- decode_times.append(decode_time)
993
-
994
- if decode_times:
995
- metrics.decode_times.extend(decode_times)
996
-
997
- if generation_losses:
998
- generation_perplexity = np.exp(np.mean(generation_losses))
999
- metrics.generation_perplexities.append(min(generation_perplexity, 1000))
1000
-
1001
- per_sample_records.append({
1002
- 'benchmark': 'wikitext',
1003
- 'sample_idx': idx,
1004
- 'prefill_perplexity': metrics.prefill_perplexities[-1] if metrics.prefill_perplexities else None,
1005
- 'generation_perplexity': metrics.generation_perplexities[-1] if metrics.generation_perplexities else None,
1006
- 'compression_ratio': compression_result['compression_ratio'],
1007
- 'kv_cache_memory_mb': compression_result['compressed_cache_size'] / (1024 * 1024),
1008
- 'compression_type': config.compression_type.value
1009
- })
1010
-
1011
- metrics.calculate_statistics(config)
1012
- all_metrics.append(metrics)
1013
-
1014
- final_metrics = BenchmarkMetrics()
1015
- for m in all_metrics:
1016
- final_metrics.prefill_times.extend(m.prefill_times)
1017
- final_metrics.prefill_peak_memories.extend(m.prefill_peak_memories)
1018
- final_metrics.decode_times.extend(m.decode_times)
1019
- final_metrics.decode_peak_memories.extend(m.decode_peak_memories)
1020
- final_metrics.prefill_perplexities.extend(m.prefill_perplexities)
1021
- final_metrics.generation_perplexities.extend(m.generation_perplexities)
1022
- final_metrics.compression_ratios.extend(m.compression_ratios)
1023
- final_metrics.kv_cache_memory_samples_mb.extend(m.kv_cache_memory_samples_mb)
1024
- final_metrics.niah_retrieval_accuracy.extend(m.niah_retrieval_accuracy)
1025
- final_metrics.ruler_exact_match.extend(m.ruler_exact_match)
1026
- final_metrics.scbench_turn_accuracy.extend(m.scbench_turn_accuracy)
1027
- final_metrics.longbench_scores.extend(m.longbench_scores)
1028
-
1029
- final_metrics.calculate_statistics(config)
1030
-
1031
- end_time = datetime.now().isoformat()
1032
- summary = {
1033
- 'compression_type': config.compression_type.value,
1034
- 'model': model_name,
1035
- 'benchmark_type': config.benchmark_type,
1036
- 'n_seeds': config.n_seeds,
1037
- 'total_samples': config.eval_samples * config.n_seeds,
1038
- 'compression_ratio': final_metrics.compression_ratio_mean,
1039
- 'kv_cache_memory_mb': final_metrics.kv_cache_memory_mb,
1040
- 'start_time': start_time,
1041
- 'end_time': end_time
1042
- }
1043
-
1044
- if config.benchmark_type == "niah" and final_metrics.niah_retrieval_accuracy:
1045
- summary['niah_accuracy'] = float(np.mean(final_metrics.niah_retrieval_accuracy))
1046
- elif config.benchmark_type == "ruler" and final_metrics.ruler_exact_match:
1047
- summary['ruler_exact_match'] = float(np.mean(final_metrics.ruler_exact_match))
1048
- elif config.benchmark_type == "scbench" and final_metrics.scbench_turn_accuracy:
1049
- summary['scbench_accuracy'] = float(np.mean(final_metrics.scbench_turn_accuracy))
1050
- elif config.benchmark_type == "longbench" and final_metrics.longbench_scores:
1051
- summary['longbench_accuracy'] = float(np.mean([s['accuracy'] for s in final_metrics.longbench_scores]))
1052
- else:
1053
- summary['prefill_perplexity'] = final_metrics.prefill_perplexity_mean
1054
- summary['generation_perplexity'] = final_metrics.generation_perplexity_mean
1055
-
1056
- summary['prefill_time_ms'] = final_metrics.prefill_time_mean * 1000
1057
- summary['decode_time_ms'] = final_metrics.decode_time_per_token_mean_ms
1058
- summary['throughput_tokens_sec'] = final_metrics.decode_tokens_per_sec
1059
- summary['end_to_end_throughput'] = final_metrics.end_to_end_throughput
1060
- summary['end_to_end_latency_ms'] = final_metrics.end_to_end_latency_ms
1061
- summary['peak_memory_mb'] = final_metrics.prefill_peak_memory_mean_mb
1062
-
1063
- return final_metrics, summary, per_sample_records, per_layer_fingerprints
1064
-
1065
-
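# Illustrative driver sketch (not part of the original file): a typical run wires the
# config, the benchmark, and the proof bundle together. The constructor keyword
# arguments and the exact CompressionType member name are assumptions; the field names
# match how `config` is read above, and the output directory is arbitrary.
#
#   config = CompressionConfig(compression_type=CompressionType.ENHANCED_SPG,
#                              benchmark_type="wikitext")
#   metrics, summary, records, fingerprints = run_research_benchmark(
#       config.model_name, config)
#   bundle = export_proof_bundle("proof_bundle", config, metrics, summary,
#                                records, fingerprints)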
1066
- def export_proof_bundle(bundle_dir: str, config: CompressionConfig,
1067
- metrics: BenchmarkMetrics, summary: Dict[str, Any],
1068
- per_sample_records: List[Dict[str, Any]],
1069
- per_layer_fingerprints: List[Dict[str, Any]]) -> str:
1070
- """Export attestable proof bundle with all metrics and fingerprints."""
1071
- p = pathlib.Path(bundle_dir)
1072
- p.mkdir(parents=True, exist_ok=True)
1073
-
1074
- manifest = {
1075
- "config": json.loads(config.to_json()),
1076
- "config_hash": config.get_hash(),
1077
- "model": config.model_name,
1078
- "benchmark_type": config.benchmark_type,
1079
- "python": sys.version,
1080
- "torch": config.torch_version,
1081
- "transformers": config.transformers_version,
1082
- "cuda": config.cuda_version,
1083
- "device_name": config.device_name,
1084
- "start_time": summary.get("start_time"),
1085
- "end_time": summary.get("end_time"),
1086
- "hostname": platform.node()
1087
- }
1088
-
1089
- (p / "manifest.json").write_text(json.dumps(manifest, indent=2))
1090
- (p / "summary.json").write_text(json.dumps(summary, indent=2, default=str))
1091
-
1092
- records_dir = p / "records"
1093
- records_dir.mkdir(exist_ok=True)
1094
-
1095
- with open(records_dir / "metrics.jsonl", "w") as f:
1096
- for r in per_sample_records:
1097
- f.write(json.dumps(r, default=str) + "\n")
1098
-
1099
- with open(records_dir / "kv_fingerprints.jsonl", "w") as f:
1100
- for r in per_layer_fingerprints:
1101
- f.write(json.dumps(r, default=str) + "\n")
1102
-
1103
- try:
1104
- env_text = subprocess.check_output([sys.executable, "-m", "pip", "freeze"], text=True)
1105
- (p / "env.lock").write_text(env_text)
1106
- except Exception as e:
1107
- logger.warning(f"Could not capture environment: {e}")
1108
- (p / "env.lock").write_text(f"# Environment capture failed: {e}\n")
1109
-
1110
- zip_path = str(p.with_suffix(".zip"))
1111
- with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as z:
1112
- for root, _, files in os.walk(p):
1113
- for name in files:
1114
- full = pathlib.Path(root) / name
1115
- z.write(full, arcname=str(full.relative_to(p)))
1116
-
1117
- logger.info(f"Proof bundle exported: {zip_path}")
1118
- return zip_path
1119
-
1120
-
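# Illustrative sketch (not part of the original file): verification re-reads the
# summary and per-sample records written by export_proof_bundle and recomputes the
# headline numbers, so a bundle can be checked on a different machine. It assumes the
# unzipped bundle directory still exists and that ProvingConfig has usable defaults for
# its `numeric_tolerance`.
#
#   report = verify_proof_bundle("proof_bundle", config, ProvingConfig())
#   if not report["ok"]:
#       raise RuntimeError(f"Proof verification failed: {report['failures']}")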
1121
- def verify_proof_bundle(bundle_root: str, config: CompressionConfig, proving: ProvingConfig) -> Dict[str, Any]:
1122
- """Verify proof bundle - recompute metrics and check tolerances."""
1123
- try:
1124
- with open(os.path.join(bundle_root, "summary.json")) as f:
1125
- summary = json.load(f)
1126
-
1127
- records = []
1128
- with open(os.path.join(bundle_root, "records", "metrics.jsonl")) as f:
1129
- for line in f:
1130
- if line.strip():
1131
- records.append(json.loads(line))
1132
- except Exception as e:
1133
- raise RuntimeError(f"Failed to load proof bundle: {e}")
1134
-
1135
- if not records:
1136
- raise ValueError("No per-sample records found in proof bundle")
1137
-
1138
- primary_method = summary.get("compression_type", "enhanced_spg")
1139
- primary_records = [r for r in records if r.get("compression_type") == primary_method]
1140
-
1141
- if not primary_records:
1142
- raise ValueError(f"No records found for method {primary_method}")
1143
-
1144
- logger.info(f"Verifying {len(primary_records)} records for {primary_method}")
1145
-
1146
- def mean_of(key):
1147
- vals = [float(r[key]) for r in primary_records if key in r and r[key] is not None]
1148
- return float(np.mean(vals)) if vals else None
1149
-
1150
- recomputed = {}
1151
- failures = []
1152
-
1153
- if config.benchmark_type == "niah":
1154
- if "niah_accuracy" in summary:
1155
- recomputed["niah_accuracy"] = mean_of("accuracy")
1156
- elif config.benchmark_type == "ruler":
1157
- if "ruler_exact_match" in summary:
1158
- recomputed["ruler_exact_match"] = mean_of("exact_match")
1159
- elif config.benchmark_type == "scbench":
1160
- if "scbench_accuracy" in summary:
1161
- recomputed["scbench_accuracy"] = mean_of("accuracy")
1162
- elif config.benchmark_type == "longbench":
1163
- if "longbench_accuracy" in summary:
1164
- recomputed["longbench_accuracy"] = mean_of("accuracy")
1165
- elif config.benchmark_type == "wikitext":
1166
- if "prefill_perplexity" in summary:
1167
- recomputed["prefill_perplexity"] = mean_of("prefill_perplexity")
1168
- if "generation_perplexity" in summary:
1169
- recomputed["generation_perplexity"] = mean_of("generation_perplexity")
1170
-
1171
- recomputed["compression_ratio"] = mean_of("compression_ratio")
1172
- recomputed["kv_cache_memory_mb"] = mean_of("kv_cache_memory_mb")
1173
-
1174
- for k, v in recomputed.items():
1175
- s = summary.get(k)
1176
- if v is not None and s is not None:
1177
- if abs(v - float(s)) > proving.numeric_tolerance:
1178
- failures.append(f"{k}: recomputed {v:.6f} != summary {s:.6f}")
1179
-
1180
- ok = len(failures) == 0
1181
-
1182
- result = {
1183
- "ok": ok,
1184
- "failures": failures,
1185
- "recomputed": recomputed,
1186
- "summary": summary,
1187
- "n_samples": len(records)
1188
- }
1189
-
1190
- if not ok:
1191
- logger.error(f"Proof verification FAILED: {failures}")
1192
- else:
1193
- logger.info(f"Proof verification PASSED for {len(records)} samples")
1194
-
1195
- return result