""" |
|
|
KV Cache Compression - SOTA-Oriented App |
|
|
- 4-bit NF4 weights (bitsandbytes) |
|
|
- FlashAttention-2 kernels |
|
|
- HF Cache subclass with KV int8 / packed int4 + optional H2O token dropping |
|
|
- Clean benchmark: prefill & decode throughput, KV memory footprint, perplexity |
|
|
|
|
|
Run: |
|
|
python app.py --model gpt2 --methods none int8 int4 --h2o --ctx 1024 --gen 256 |
|
|
""" |

import os
import json
import math
import time
import gc
import random
import logging
from dataclasses import dataclass
from typing import Dict, Tuple, Optional, Any, List

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, DynamicCache
from transformers.cache_utils import Cache
from datasets import load_dataset
from tqdm import tqdm
import pandas as pd

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
log = logging.getLogger("kv-sota")


def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def load_model_4bit(model_name: str):
    """Load a causal LM with 4-bit NF4 weights and FlashAttention-2 on GPU; full precision on CPU."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    kwargs: Dict[str, Any] = dict(
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    if device == "cuda":
        kwargs.update(
            device_map="auto",
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            ),
        )
    else:
        # bitsandbytes 4-bit and FlashAttention-2 both require CUDA; fall back to fp32 on CPU.
        kwargs.update(torch_dtype=torch.float32)

    try:
        model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
    except (ImportError, ValueError) as e:
        # flash-attn may be missing or unsupported by this architecture; retry with SDPA.
        log.warning(f"FlashAttention-2 unavailable ({e}); falling back to SDPA")
        kwargs["attn_implementation"] = "sdpa"
        model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)

    if device == "cpu":
        model = model.to(device)
    model.eval()
    return tok, model


def quantize_int8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Symmetric per-tensor int8 quantization: returns (int8 values, fp32 scale, zero-point)."""
    maxv = x.abs().amax()
    scale = (maxv / 127.0).clamp(min=1e-8)
    q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    # Symmetric scheme: the zero-point is always 0; it is kept only so all methods
    # share a uniform (values, scale, zero-point) interface.
    zp = torch.tensor(0.0, device=x.device, dtype=torch.float32)
    return q, scale.to(torch.float32), zp


def dequantize_int8(q: torch.Tensor, scale: torch.Tensor, zp: torch.Tensor) -> torch.Tensor:
    """Inverse of quantize_int8 (zp is unused because the scheme is symmetric)."""
    return q.float() * scale
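

# Illustrative round trip (not executed anywhere in the app; shown only to document the
# error behaviour of the symmetric per-tensor scheme above):
#   >>> x = torch.randn(4, 8, 16, 64)
#   >>> q, s, zp = quantize_int8(x)
#   >>> err = (dequantize_int8(q, s, zp) - x).abs().max()
# Because scale = max|x| / 127 and values are rounded to the nearest step, err is at most
# about scale / 2, so a single outlier inflates the error budget of the whole tensor.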


def quantize_pack_int4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Symmetric per-tensor int4 quantization, packing two 4-bit values into each uint8 byte."""
    maxv = x.abs().amax()
    scale = (maxv / 7.0).clamp(min=1e-8)
    q = torch.clamp(torch.round(x / scale), -8, 7).to(torch.int8)

    # Shift to unsigned [0, 15], flatten, and pad to an even count of nibbles before packing.
    q_off = (q + 8).to(torch.uint8).flatten()
    n = q_off.numel()
    if n % 2 != 0:
        q_off = torch.cat([q_off, torch.zeros(1, device=q_off.device, dtype=torch.uint8)], dim=0)
        n += 1
    q0 = q_off.view(-1, 2)
    packed = (q0[:, 0] << 4) | (q0[:, 1] & 0x0F)
    return packed.contiguous(), scale.to(torch.float32), torch.tensor(0.0, device=x.device, dtype=torch.float32)


def unpack_dequantize_int4(packed: torch.Tensor, out_numel: int, scale: torch.Tensor, zp: torch.Tensor) -> torch.Tensor:
    """Unpack two nibbles per byte, drop any padding nibble, and dequantize to a flat fp32 tensor."""
    bytes_ = packed.view(-1)
    hi = (bytes_ >> 4) & 0x0F
    lo = bytes_ & 0x0F
    q_off = torch.stack([hi, lo], dim=1).view(-1)
    q_off = q_off[:out_numel]
    q = (q_off.to(torch.int16) - 8).to(torch.float32)
    return q * scale
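

# Illustrative round trip for the packed int4 path (not executed by the benchmark;
# assumes only the two helpers above):
#   >>> x = torch.randn(64)
#   >>> packed, s, zp = quantize_pack_int4(x)
#   >>> packed.numel()                                   # 32 bytes for 64 values
#   >>> x_hat = unpack_dequantize_int4(packed, x.numel(), s, zp).view(x.shape)
# Storage drops 4x versus fp16 (plus one fp32 scale), at a worst-case per-element error
# of about scale / 2 = max|x| / 14.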


def select_h2o_indices(attn_scores: Optional[torch.Tensor],
                       seq_len: int, keep_ratio: float, sink_tokens: int = 4) -> torch.Tensor:
    """
    H2O-style token selection: always keep the first `sink_tokens` positions, then keep
    the highest-importance tokens (or simply the most recent ones when no importance
    scores are available) up to `keep_ratio` of the sequence. Returns a bool keep-mask
    of shape (seq_len,).
    """
    k = min(seq_len, max(sink_tokens, int(seq_len * keep_ratio)))
    if attn_scores is not None:
        device = attn_scores.device
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    keep = torch.zeros(seq_len, dtype=torch.bool, device=device)
    keep[:min(sink_tokens, seq_len)] = True
    if seq_len > sink_tokens and k > sink_tokens:
        if attn_scores is None:
            # No importance signal: fall back to a recency window.
            keep[-(k - sink_tokens):] = True
        else:
            scores = attn_scores.clone()
            scores[:sink_tokens] = -1e9  # sinks are already kept; exclude them from the top-k
            topk = torch.topk(scores, k - sink_tokens).indices
            keep[topk] = True
    return keep
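

# Worked toy example (assumed numbers, for documentation only): with seq_len=8, keep_ratio=0.5,
# sink_tokens=2 and importance scores [9, 1, 5, 0, 7, 0, 3, 2], k = 4; positions {0, 1} are kept
# as sinks and the top-2 of the remaining scores adds positions {4, 2}, so 4 of 8 tokens survive.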


@dataclass
class KVConfig:
    method: str = "none"        # one of: "none", "int8", "int4"
    use_h2o: bool = False
    h2o_keep: float = 0.5
    sink_tokens: int = 4


class QuantizedDynamicCache(Cache):
    """
    Stores KV compressed (int8 or packed int4) but returns dense tensors in the query
    dtype (fp16 on GPU) from update(), so FlashAttention kernels see a consistent dtype.
    """
    def __init__(self, cfg: KVConfig):
        super().__init__()
        self.cfg = cfg
        # Per-layer storage: compressed K/V payloads, metadata, current lengths,
        # and the last importance scores used for H2O selection.
        self.K: List[Optional[Any]] = []
        self.V: List[Optional[Any]] = []
        self.meta: List[Dict[str, Any]] = []
        self.seq_lens: List[int] = []
        self.importance: List[Optional[torch.Tensor]] = []

    def _ensure(self, layer_idx: int):
        # Lazily grow the per-layer lists until layer_idx is addressable.
        while len(self.K) <= layer_idx:
            self.K.append(None)
            self.V.append(None)
            self.meta.append({})
            self.seq_lens.append(0)
            self.importance.append(None)

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        if layer_idx < len(self.seq_lens):
            return self.seq_lens[layer_idx]
        return 0

    def get_max_length(self) -> Optional[int]:
        return None

    def _compress(self, x: torch.Tensor, method: str):
        if method == "int8":
            return ("int8",) + quantize_int8(x)
        elif method == "int4":
            packed, scale, zp = quantize_pack_int4(x)
            return ("int4", packed, scale, zp, x.numel(), x.shape)
        else:
            return ("none", x.clone())

    def _decompress(self, obj: Tuple):
        tag = obj[0]
        if tag == "none":
            return obj[1]
        elif tag == "int8":
            _tag, q, scale, zp = obj
            # Cast to fp16 so the int8 path matches the "none"/int4 paths.
            return dequantize_int8(q, scale, zp).to(torch.float16)
        elif tag == "int4":
            _tag, packed, scale, zp, numel, shape = obj
            x = unpack_dequantize_int4(packed, numel, scale, zp)
            return x.view(shape).to(torch.float16)
        else:
            raise ValueError("Unknown compression tag")

    def _concat_decompressed(self, prev_obj, new_x):
        if prev_obj is None:
            return new_x
        prev = self._decompress(prev_obj)
        return torch.cat([prev, new_x], dim=-2)

    def update(self, key_states: torch.Tensor, value_states: torch.Tensor,
               layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        HF calls this every forward. We must:
          1) Decompress previously cached K/V (if any)
          2) Optionally apply H2O token selection on the concatenated sequence
          3) Re-compress & store
          4) Return dense K/V of the *current full* cache for immediate attention use
        """
        self._ensure(layer_idx)

        # Expected layout: (batch, num_heads, seq_len, head_dim)
        assert key_states.dim() == 4 and value_states.dim() == 4, "Unexpected KV dims"

        out_dtype = key_states.dtype  # fp16 on GPU; matches Q for FlashAttention

        prevK = self.K[layer_idx]
        prevV = self.V[layer_idx]

        full_k = self._concat_decompressed(prevK, key_states)
        full_v = self._concat_decompressed(prevV, value_states)
        seq_len = full_k.shape[-2]

        # H2O-style eviction: keep attention sinks plus the highest-scoring tokens.
        if self.cfg.use_h2o and seq_len > self.cfg.sink_tokens:
            # Proxy importance: mean squared key norm per position (cheaper than
            # accumulating attention probabilities as in the original H2O recipe).
            imp = full_k.float().pow(2).sum(dim=-1).mean(dim=(0, 1))
            keep_mask = select_h2o_indices(imp, seq_len, self.cfg.h2o_keep, self.cfg.sink_tokens)
            full_k = full_k[:, :, keep_mask, :]
            full_v = full_v[:, :, keep_mask, :]
            self.importance[layer_idx] = imp.detach()
            seq_len = full_k.shape[-2]

        # Re-compress and store the (possibly pruned) full sequence.
        if self.cfg.method == "none":
            self.K[layer_idx] = ("none", full_k.to(out_dtype).contiguous())
            self.V[layer_idx] = ("none", full_v.to(out_dtype).contiguous())
        elif self.cfg.method == "int8":
            self.K[layer_idx] = self._compress(full_k.to(torch.float16), "int8")
            self.V[layer_idx] = self._compress(full_v.to(torch.float16), "int8")
        elif self.cfg.method == "int4":
            self.K[layer_idx] = self._compress(full_k.to(torch.float16), "int4")
            self.V[layer_idx] = self._compress(full_v.to(torch.float16), "int4")
        else:
            raise ValueError(f"Unknown KV method {self.cfg.method}")

        self.seq_lens[layer_idx] = seq_len

        # Attention consumes the freshly decompressed dense tensors; quantization error
        # only propagates through what is stored for subsequent steps.
        return full_k.to(out_dtype), full_v.to(out_dtype)

    def memory_bytes(self) -> int:
        """Approximate bytes held by the cache: stored payloads plus fp32 scale/zero-point pairs."""
        total = 0
        for obj in (self.K + self.V):
            if obj is None:
                continue
            tag = obj[0]
            if tag == "none":
                t = obj[1]
                total += t.numel() * t.element_size()
            elif tag == "int8":
                _, q, scale, zp = obj
                total += q.numel() * q.element_size()
                total += 8  # fp32 scale + fp32 zero-point
            elif tag == "int4":
                _, packed, scale, zp, numel, shape = obj
                total += packed.numel() * packed.element_size()
                total += 8  # fp32 scale + fp32 zero-point
        return total
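

# Minimal usage sketch (mirrors how run_benchmark drives the cache; assumes a CUDA fp16 model
# already loaded via load_model_4bit):
#   cfg = KVConfig(method="int4", use_h2o=True, h2o_keep=0.5)
#   cache = QuantizedDynamicCache(cfg)
#   out = model(input_ids, past_key_values=cache, use_cache=True)   # prefill
#   out = model(next_tok, past_key_values=cache, use_cache=True)    # one decode step
#   print(cache.memory_bytes() / 2**20, "MiB of compressed KV")
# The same cache object must be passed on every step; HF calls cache.update() once per layer.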


def load_texts(tokenizer, n_samples: int, min_tokens: int) -> List[str]:
    """Stream WikiText-2 test texts with at least `min_tokens` tokens; fall back to synthetic text."""
    texts = []
    try:
        ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test", streaming=True)
        for ex in ds:
            t = ex.get("text", "").strip()
            if not t:
                continue
            toks = tokenizer.encode(t, add_special_tokens=False)
            if len(toks) >= min_tokens:
                texts.append(t)
            if len(texts) >= n_samples:
                break
    except Exception as e:
        log.warning(f"Dataset stream failed: {e}")

    if len(texts) < n_samples:
        log.info("Using fallback synthetic texts")
        base = "The quick brown fox jumps over the lazy dog. " * 200
        while len(texts) < n_samples:
            texts.append(base)
    return texts[:n_samples]


def cuda_sync():
    if torch.cuda.is_available():
        try:
            torch.cuda.synchronize()
        except Exception:
            pass


def mem_alloc():
    if torch.cuda.is_available():
        try:
            return torch.cuda.memory_allocated()
        except Exception:
            return 0
    return 0


@dataclass
class BenchCfg:
    eval_samples: int = 20
    prefill_len: int = 1024
    gen_len: int = 256
    warmup: int = 2


def run_benchmark(tokenizer, model, methods: List[str], use_h2o: bool,
                  h2o_keep: float, sink_tokens: int,
                  bench: BenchCfg) -> pd.DataFrame:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    texts = load_texts(tokenizer, bench.eval_samples, bench.prefill_len)
    results = []
    baseline = None

    # Storage footprint of the (possibly 4-bit packed) weights; bitsandbytes keeps packed
    # uint8 buffers, so numel * element_size reflects the quantized storage.
    weights_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

    for method in methods:
        log.info("\n" + "=" * 60 + f"\nTesting: {method.upper()}\n" + "=" * 60)
        kv_cfg = KVConfig(method=method, use_h2o=use_h2o, h2o_keep=h2o_keep, sink_tokens=sink_tokens)

        log.info("Warmup...")
        for _ in range(bench.warmup):
            dummy = torch.randint(0, tokenizer.vocab_size, (1, 32), device=device)
            with torch.no_grad():
                _ = model(dummy, use_cache=False)

        prefill_times = []
        decode_times = []
        ppl_list = []
        kv_memory_last = 0

        for text in tqdm(texts, desc="Benchmark"):
            inputs = tokenizer(text, return_tensors="pt", truncation=True,
                               max_length=bench.prefill_len, padding="max_length")
            input_ids = inputs.input_ids.to(device)
            attn_mask = inputs.attention_mask.to(device)

            # Fresh cache for every prompt, so KV from one text never leaks into the next.
            cache: Cache = QuantizedDynamicCache(kv_cfg) if method != "none" else DynamicCache()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.reset_peak_memory_stats()

            # --- Prefill: a single forward over the whole prompt populates the cache ---
            t0 = time.perf_counter()
            with torch.no_grad():
                out = model(input_ids, attention_mask=attn_mask, past_key_values=cache,
                            use_cache=True, return_dict=True)
            cuda_sync()
            prefill_times.append(time.perf_counter() - t0)

            # --- Decode: greedy generation, timing each single-token step ---
            next_tok = input_ids[:, -1:]
            steps = bench.gen_len
            for _ in range(steps):
                t1 = time.perf_counter()
                with torch.no_grad():
                    out = model(next_tok, past_key_values=cache, use_cache=True, return_dict=True)
                cuda_sync()
                decode_times.append(time.perf_counter() - t1)
                logits = out.logits
                next_tok = torch.argmax(logits[:, -1:], dim=-1)

            # --- Perplexity on the prompt (pad positions excluded from the loss) ---
            # Note: this is a plain full forward with the quantized weights; it does not
            # route through the compressed KV cache, so PPL is shared across KV methods.
            try:
                labels = input_ids.masked_fill(attn_mask == 0, -100)
                with torch.no_grad():
                    out = model(input_ids, attention_mask=attn_mask, labels=labels, return_dict=True)
                ppl = float(torch.exp(out.loss).item())
                ppl_list.append(min(ppl, 1e3))
            except Exception:
                ppl_list.append(100.0)

            if isinstance(cache, QuantizedDynamicCache):
                kv_memory_last = cache.memory_bytes()

        prefill_s = np.mean(prefill_times)
        decode_ms_per_tok = (np.mean(decode_times) * 1000.0) if decode_times else 0.0
        tok_per_s = (1.0 / np.mean(decode_times)) if decode_times else 0.0
        ppl_mean = float(np.mean(ppl_list)) if ppl_list else float("nan")
        kv_mb = (kv_memory_last / (1024**2)) if kv_memory_last else 0.0
        weights_mb = weights_bytes / (1024**2)

        row = {
            "method": method,
            "prefill_s": round(prefill_s, 4),
            "decode_ms_tok": round(decode_ms_per_tok, 2),
            "tok_per_s": round(tok_per_s, 2),
            "ppl": round(ppl_mean, 3),
            "kv_MB": round(kv_mb, 2),
            "weights_MB": round(weights_mb, 2),
        }

        if method == "none":
            baseline = row
            row.update({"kv_reduction_x": 1.0, "speedup_x": 1.0})
        else:
            # KV reduction = fp16-equivalent footprint of everything cached (all layers, K and V)
            # divided by the bytes the compressed cache actually stores.
            red = 1.0
            if isinstance(cache, QuantizedDynamicCache):
                fp16_bytes = 0
                for obj in (cache.K + cache.V):
                    if obj is None:
                        continue
                    fp16_bytes += cache._decompress(obj).numel() * 2  # 2 bytes per fp16 element
                comp_bytes = cache.memory_bytes()
                if fp16_bytes > 0 and comp_bytes > 0:
                    red = fp16_bytes / comp_bytes
            row["kv_reduction_x"] = round(float(red), 2)
            row["speedup_x"] = round(row["tok_per_s"] / max(baseline["tok_per_s"], 1e-6), 2) if baseline else 1.0

        log.info(
            f"\n{method.upper()} Results:\n"
            f"  Prefill: {row['prefill_s']} s | Decode: {row['decode_ms_tok']} ms/tok "
            f"({row['tok_per_s']} tok/s) | PPL: {row['ppl']}\n"
            f"  Weights: ~{row['weights_MB']} MB | KV: ~{row['kv_MB']} MB | KV reduction: {row['kv_reduction_x']}×\n"
        )

        results.append(row)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    df = pd.DataFrame(results)
    return df


def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="gpt2")
    parser.add_argument("--methods", nargs="+", default=["none", "int8", "int4"],
                        help="KV cache methods: none int8 int4")
    parser.add_argument("--h2o", action="store_true", help="Enable H2O token dropping")
    parser.add_argument("--h2o_keep", type=float, default=0.5, help="Fraction to keep (0..1)")
    parser.add_argument("--sink", type=int, default=4, help="Sink tokens always kept")
    parser.add_argument("--samples", type=int, default=20)
    parser.add_argument("--ctx", type=int, default=1024)
    parser.add_argument("--gen", type=int, default=256)
    parser.add_argument("--warmup", type=int, default=2)
    parser.add_argument("--out", type=str, default="./kv_sota_results")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    set_seed(args.seed)

    log.info("Loading model with 4-bit weights + FlashAttention-2...")
    tok, model = load_model_4bit(args.model)

    bench = BenchCfg(eval_samples=args.samples, prefill_len=args.ctx, gen_len=args.gen, warmup=args.warmup)
    df = run_benchmark(
        tok, model, args.methods, use_h2o=args.h2o,
        h2o_keep=args.h2o_keep, sink_tokens=args.sink, bench=bench
    )

    os.makedirs(args.out, exist_ok=True)
    df.to_csv(os.path.join(args.out, "summary.csv"), index=False)
    with open(os.path.join(args.out, "summary.json"), "w") as f:
        json.dump(df.to_dict(orient="records"), f, indent=2)

    print("\n" + "=" * 80)
    print("BENCHMARK SUMMARY")
    print("=" * 80)
    print(df.to_string(index=False))
    print(f"\n✅ Results saved in: {args.out}")


if __name__ == "__main__":
    main()