# app.py
"""
KV Cache Compression - SOTA-Oriented App
- 4-bit NF4 weights (bitsandbytes)
- FlashAttention-2 kernels
- HF Cache subclass with KV int8 / packed int4 + optional H2O token dropping
- Clean benchmark: prefill & decode throughput, KV memory footprint, perplexity
Run:
python app.py --model gpt2 --methods none int8 int4 --h2o --ctx 1024 --gen 256
"""
import os
import json
import math
import time
import gc
import random
import logging
from dataclasses import dataclass
from typing import Dict, Tuple, Optional, Any, List
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, DynamicCache
from transformers.cache_utils import Cache
from datasets import load_dataset
from tqdm import tqdm
import pandas as pd
# ----------------------------- Logging & Repro --------------------------------
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
log = logging.getLogger("kv-sota")
def set_seed(seed: int = 42):
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# ---------------------------- Model Loader (SOTA) -----------------------------
def load_model_4bit(model_name: str):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    if device == "cuda":
        # 4-bit NF4 weights (bitsandbytes) + FlashAttention-2 kernels; both require CUDA
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",  # fused attention kernels
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            quantization_config=quant_cfg,
        )
    else:
        # CPU fallback: fp32 weights, default attention (bitsandbytes 4-bit and FA2 are GPU-only)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        ).to(device)
    model.eval()
    return tok, model
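# Illustrative usage (a sketch; on GPU this assumes bitsandbytes and flash-attn are
# installed, otherwise the CPU fallback above is taken):
#   tok, model = load_model_4bit("gpt2")
#   ids = tok("Hello world", return_tensors="pt").input_ids.to(model.device)
#   with torch.no_grad():
#       out = model(ids, use_cache=True)
#   print(out.logits.shape)  # (1, seq_len, vocab_size)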
# ----------------------- Quantization helpers (KV) ----------------------------
def quantize_int8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# per-tensor symmetric int8 for speed; good trade-off
# x_fp16 -> q_i8 = round(x/scale).clip(-127,127)
maxv = x.abs().amax()
scale = (maxv / 127.0).clamp(min=1e-8)
q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
# zero-point kept at 0 for symmetric
zp = torch.tensor(0.0, device=x.device, dtype=torch.float32)
return q, scale.to(torch.float32), zp
def dequantize_int8(q: torch.Tensor, scale: torch.Tensor, zp: torch.Tensor) -> torch.Tensor:
return (q.float()) * scale
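# Round-trip sanity check (illustrative sketch; the shape is arbitrary):
#   x = torch.randn(2, 12, 16, 64, dtype=torch.float16)
#   q, scale, zp = quantize_int8(x)
#   x_hat = dequantize_int8(q, scale, zp)
#   print((x.float() - x_hat).abs().max().item())  # roughly <= scale / 2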
def quantize_pack_int4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# symmetric int4 in [-8,7], then pack 2 nibbles per byte
maxv = x.abs().amax()
scale = (maxv / 7.0).clamp(min=1e-8)
q = torch.clamp(torch.round(x / scale), -8, 7).to(torch.int8) # store as i8 for packing step
# pack: two signed nibbles -> byte (we store sign-magnitude via +8 offset)
q_off = (q + 8).to(torch.uint8) # [0..15]
n = q_off.numel()
if n % 2 != 0:
q_off = torch.cat([q_off, torch.zeros(1, device=q_off.device, dtype=torch.uint8)], dim=0)
n += 1
q0 = q_off.view(-1, 2)
packed = (q0[:, 0] << 4) | (q0[:, 1] & 0x0F)
return packed.contiguous(), scale.to(torch.float32), torch.tensor(0.0, device=x.device, dtype=torch.float32)
def unpack_dequantize_int4(packed: torch.Tensor, out_numel: int, scale: torch.Tensor, zp: torch.Tensor) -> torch.Tensor:
# unpack two nibbles per byte back to [-8,7] then dequantize
bytes_ = packed.view(-1)
hi = (bytes_ >> 4) & 0x0F
lo = bytes_ & 0x0F
q_off = torch.stack([hi, lo], dim=1).view(-1)
q_off = q_off[:out_numel]
q = (q_off.to(torch.int16) - 8).to(torch.float32) # [-8..7]
return q * scale # symmetric, zp=0
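# Round-trip sanity check for the packed int4 path (illustrative sketch):
#   x = torch.randn(2, 12, 16, 64, dtype=torch.float16)
#   packed, scale, zp = quantize_pack_int4(x)
#   x_hat = unpack_dequantize_int4(packed, x.numel(), scale, zp).view(x.shape)
#   print(packed.numel(), "packed bytes vs", x.numel() * 2, "fp16 bytes")  # ~4x smaller
#   print((x.float() - x_hat).abs().max().item())  # roughly <= scale / 2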
# ------------------------ H2O importance (optional) ---------------------------
def select_h2o_indices(attn_scores: Optional[torch.Tensor],
seq_len: int, keep_ratio: float, sink_tokens: int = 4) -> torch.Tensor:
    # attn_scores: (seq_len,) importance; if None -> fall back to a recency window
    k = max(sink_tokens, int(seq_len * keep_ratio))
    if attn_scores is not None:
        device = attn_scores.device
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    keep = torch.zeros(seq_len, dtype=torch.bool, device=device)
    keep[:min(sink_tokens, seq_len)] = True
if seq_len > sink_tokens and k > sink_tokens:
if attn_scores is None:
# if no scores provided, keep a moving window as a safe default
keep[-(k - sink_tokens):] = True
else:
scores = attn_scores.clone()
scores[:sink_tokens] = -1e9
topk = torch.topk(scores, k - sink_tokens).indices
keep[topk] = True
return keep
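# Illustrative sketch: keep ~50% of 16 tokens, always preserving the first 4 sink tokens:
#   scores = torch.rand(16)
#   mask = select_h2o_indices(scores, seq_len=16, keep_ratio=0.5, sink_tokens=4)
#   print(int(mask.sum()))  # 8 positions kept (4 sinks + 4 top-scoring)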
# --------------------------- KV Cache (HF Cache API) --------------------------
@dataclass
class KVConfig:
method: str = "none" # "none" | "int8" | "int4"
use_h2o: bool = False
h2o_keep: float = 0.5 # fraction to keep
sink_tokens: int = 4
class QuantizedDynamicCache(Cache):
"""
Stores KV compressed (int8 or packed int4), but returns fp16 on update.
This keeps dtype consistent with Q (fp16) for FlashAttention kernels.
"""
def __init__(self, cfg: KVConfig):
super().__init__()
self.cfg = cfg
# per-layer storage
self.K: List[Optional[Any]] = []
self.V: List[Optional[Any]] = []
self.meta: List[Dict[str, Any]] = []
self.seq_lens: List[int] = []
# optional importance stats (very lightweight proxy)
self.importance: List[Optional[torch.Tensor]] = []
def _ensure(self, layer_idx: int):
while len(self.K) <= layer_idx:
self.K.append(None); self.V.append(None)
self.meta.append({})
self.seq_lens.append(0)
self.importance.append(None)
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        if layer_idx is None:
            layer_idx = 0
        if layer_idx < len(self.seq_lens):
            return self.seq_lens[layer_idx]
        return 0
def get_max_length(self) -> Optional[int]:
return None
def _compress(self, x: torch.Tensor, method: str):
if method == "int8":
return ("int8",) + quantize_int8(x)
elif method == "int4":
packed, scale, zp = quantize_pack_int4(x)
return ("int4", packed, scale, zp, x.numel(), x.shape)
else:
return ("none", x.clone())
def _decompress(self, obj: Tuple):
tag = obj[0]
if tag == "none":
return obj[1]
        elif tag == "int8":
            _tag, q, scale, zp = obj
            return dequantize_int8(q, scale, zp).to(torch.float16)
elif tag == "int4":
_tag, packed, scale, zp, numel, shape = obj
x = unpack_dequantize_int4(packed, numel, scale, zp)
return x.view(shape).to(torch.float16)
else:
raise ValueError("Unknown compression tag")
def _concat_decompressed(self, prev_obj, new_x):
if prev_obj is None:
return new_x
prev = self._decompress(prev_obj)
return torch.cat([prev, new_x], dim=-2)
def update(self, key_states: torch.Tensor, value_states: torch.Tensor,
layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
HF calls this every forward. We must:
1) Decompress previous cached K/V (if any)
2) Optionally apply H2O token selection on the concatenated sequence
3) Re-compress & store
4) Return fp16 K/V of the *current full* cache for immediate attention use
"""
self._ensure(layer_idx)
# (B, H, T, D) — HF uses this shape inside attention
assert key_states.dim() == 4 and value_states.dim() == 4, "Unexpected KV dims"
prevK = self.K[layer_idx]
prevV = self.V[layer_idx]
full_k = self._concat_decompressed(prevK, key_states)
full_v = self._concat_decompressed(prevV, value_states)
seq_len = full_k.shape[-2]
# Optional H2O: drop low-importance tokens
if self.cfg.use_h2o and seq_len > self.cfg.sink_tokens:
            # simple proxy importance: squared L2 norm per token (summed over head dim, averaged over batch and heads)
            imp = full_k.float().pow(2).sum(dim=-1).mean(dim=(0, 1))  # (T,)
keep_mask = select_h2o_indices(imp, seq_len, self.cfg.h2o_keep, self.cfg.sink_tokens)
full_k = full_k[:, :, keep_mask, :]
full_v = full_v[:, :, keep_mask, :]
self.importance[layer_idx] = imp.detach()
seq_len = full_k.shape[-2]
# Re-compress & store
if self.cfg.method == "none":
self.K[layer_idx] = ("none", full_k.to(torch.float16).contiguous())
self.V[layer_idx] = ("none", full_v.to(torch.float16).contiguous())
elif self.cfg.method == "int8":
self.K[layer_idx] = self._compress(full_k.to(torch.float16), "int8")
self.V[layer_idx] = self._compress(full_v.to(torch.float16), "int8")
elif self.cfg.method == "int4":
self.K[layer_idx] = self._compress(full_k.to(torch.float16), "int4")
self.V[layer_idx] = self._compress(full_v.to(torch.float16), "int4")
else:
raise ValueError(f"Unknown KV method {self.cfg.method}")
self.seq_lens[layer_idx] = seq_len
# Return fp16 tensors for the current step (attention immediately uses these)
# Important: these must match the dtype of Q (fp16) to avoid kernel dtype errors.
return full_k.to(torch.float16), full_v.to(torch.float16)
# Memory accounting for KV (bytes actually held by cache)
def memory_bytes(self) -> int:
total = 0
for obj in (self.K + self.V):
if obj is None:
continue
tag = obj[0]
if tag == "none":
t = obj[1]
total += t.numel() * t.element_size()
elif tag == "int8":
_, q, scale, zp = obj
total += q.numel() * q.element_size() # int8
total += 8 # scale+zp approx
elif tag == "int4":
_, packed, scale, zp, numel, shape = obj
total += packed.numel() * packed.element_size() # uint8 (2 values/byte)
total += 8 # scale+zp approx
return total
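# Illustrative usage (a sketch; assumes `tok` and `model` come from load_model_4bit above):
#   cfg = KVConfig(method="int8", use_h2o=True, h2o_keep=0.5, sink_tokens=4)
#   cache = QuantizedDynamicCache(cfg)
#   ids = tok("The quick brown fox", return_tensors="pt").input_ids.to(model.device)
#   with torch.no_grad():
#       out = model(ids, past_key_values=cache, use_cache=True)
#   print(cache.get_seq_length(), "cached tokens,", cache.memory_bytes() / 1024, "KiB held")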
# ---------------------------- Data & Utilities --------------------------------
def load_texts(tokenizer, n_samples: int, min_tokens: int) -> List[str]:
texts = []
try:
ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test", streaming=True)
for ex in ds:
t = ex.get("text", "").strip()
if not t:
continue
toks = tokenizer.encode(t, add_special_tokens=False)
if len(toks) >= min_tokens:
texts.append(t)
if len(texts) >= n_samples:
break
except Exception as e:
log.warning(f"Dataset stream failed: {e}")
if len(texts) < n_samples:
log.info("Using fallback synthetic texts")
base = "The quick brown fox jumps over the lazy dog. " * 200
while len(texts) < n_samples:
texts.append(base)
return texts[:n_samples]
def cuda_sync():
if torch.cuda.is_available():
try: torch.cuda.synchronize()
except Exception: pass
def mem_alloc():
if torch.cuda.is_available():
try: return torch.cuda.memory_allocated()
except Exception: return 0
return 0
# ------------------------------ Benchmark -------------------------------------
@dataclass
class BenchCfg:
eval_samples: int = 20
prefill_len: int = 1024
gen_len: int = 256
warmup: int = 2
def run_benchmark(tokenizer, model, methods: List[str], use_h2o: bool,
h2o_keep: float, sink_tokens: int,
bench: BenchCfg) -> pd.DataFrame:
device = "cuda" if torch.cuda.is_available() else "cpu"
texts = load_texts(tokenizer, bench.eval_samples, bench.prefill_len)
results = []
baseline = None
    # measure weights (rough): bytes actually held by parameters
    # (bitsandbytes 4-bit params are counted at their packed storage size)
    weights_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
for method in methods:
log.info("\n" + "="*60 + f"\nTesting: {method.upper()}\n" + "="*60)
        kv_cfg = KVConfig(method=method, use_h2o=use_h2o, h2o_keep=h2o_keep, sink_tokens=sink_tokens)
# warmup
log.info("Warmup...")
for _ in range(bench.warmup):
dummy = torch.randint(0, tokenizer.vocab_size, (1, 32), device=device)
with torch.no_grad():
_ = model(dummy, use_cache=False)
prefill_times = []
decode_times = []
ppl_list = []
kv_memory_last = 0
        for text in tqdm(texts, desc="Benchmark"):
            # fresh cache per sample so cached tokens don't accumulate across texts
            cache: Cache = QuantizedDynamicCache(kv_cfg) if method != "none" else DynamicCache()
            inputs = tokenizer(text, return_tensors="pt", truncation=True,
                               max_length=bench.prefill_len, padding="max_length")
            input_ids = inputs.input_ids.to(device)
            attn_mask = inputs.attention_mask.to(device)
# Prefill
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
t0 = time.perf_counter()
with torch.no_grad():
out = model(input_ids, attention_mask=attn_mask, past_key_values=cache, use_cache=True, return_dict=True)
kv = out.past_key_values
cuda_sync()
prefill_times.append(time.perf_counter() - t0)
# Decode loop
next_tok = input_ids[:, -1:]
steps = bench.gen_len
for _ in range(steps):
t1 = time.perf_counter()
with torch.no_grad():
out = model(next_tok, past_key_values=cache, use_cache=True, return_dict=True)
cuda_sync()
decode_times.append(time.perf_counter() - t1)
logits = out.logits
next_tok = torch.argmax(logits[:, -1:], dim=-1)
            # Perplexity (quick token-level loss on the prefill; pad positions masked out).
            # Note: this is a fresh full forward that does not read the compressed cache,
            # so it mainly reflects the 4-bit weight path, not KV-compression error.
            try:
                labels = input_ids.masked_fill(attn_mask == 0, -100)
                with torch.no_grad():
                    out = model(input_ids, attention_mask=attn_mask, labels=labels, return_dict=True)
                ppl = float(torch.exp(out.loss).item())
                ppl_list.append(min(ppl, 1e3))
            except Exception:
                ppl_list.append(100.0)
# KV memory
if isinstance(cache, QuantizedDynamicCache):
kv_memory_last = cache.memory_bytes()
prefill_s = np.mean(prefill_times)
decode_ms_per_tok = (np.mean(decode_times) * 1000.0) if decode_times else 0.0
tok_per_s = (1.0 / np.mean(decode_times)) if decode_times else 0.0
ppl_mean = float(np.mean(ppl_list)) if ppl_list else float("nan")
kv_mb = (kv_memory_last / (1024**2)) if kv_memory_last else 0.0
weights_mb = weights_bytes / (1024**2)
row = {
"method": method,
"prefill_s": round(prefill_s, 4),
"decode_ms_tok": round(decode_ms_per_tok, 2),
"tok_per_s": round(tok_per_s, 2),
"ppl": round(ppl_mean, 3),
"kv_MB": round(kv_mb, 2),
"weights_MB": round(weights_mb, 2),
}
if method == "none":
baseline = row
row.update({"kv_reduction_x": 1.0, "speedup_x": 1.0})
        else:
            # DynamicCache doesn't expose its memory, so estimate the fp16 baseline as the
            # same cached elements held at 2 bytes each and compare against the bytes the
            # compressed cache actually stores (both K and V, all layers). With H2O enabled,
            # dropped tokens are excluded from both sides, so token dropping itself is not
            # credited here.
            red = 1.0
            if isinstance(cache, QuantizedDynamicCache):
                fp16_bytes = 0
                for obj in (cache.K + cache.V):
                    if obj is None:
                        continue
                    tag = obj[0]
                    if tag in ("none", "int8"):
                        fp16_bytes += obj[1].numel() * 2
                    elif tag == "int4":
                        fp16_bytes += obj[4] * 2  # stored element count before packing
                comp_bytes = cache.memory_bytes()
                if comp_bytes > 0:
                    red = fp16_bytes / comp_bytes
            row["kv_reduction_x"] = round(float(red), 2)
            row["speedup_x"] = round((row["tok_per_s"] / max(baseline["tok_per_s"], 1e-6)), 2) if baseline else 1.0
log.info(
f"\n{method.upper()} Results:\n"
f" Prefill: {row['prefill_s']} s | Decode: {row['decode_ms_tok']} ms/tok "
f"({row['tok_per_s']} tok/s) | PPL: {row['ppl']}\n"
f" Weights: ~{row['weights_MB']} MB | KV: ~{row['kv_MB']} MB | KV reduction: {row['kv_reduction_x']}×\n"
)
results.append(row)
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
df = pd.DataFrame(results)
return df
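# Illustrative programmatic usage (a sketch mirroring the CLI entry point below):
#   tok, model = load_model_4bit("gpt2")
#   df = run_benchmark(tok, model, ["none", "int8", "int4"], use_h2o=False,
#                      h2o_keep=0.5, sink_tokens=4, bench=BenchCfg(eval_samples=4))
#   print(df.to_string(index=False))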
# --------------------------------- CLI ---------------------------------------
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="gpt2")
parser.add_argument("--methods", nargs="+", default=["none", "int8", "int4"],
help="KV cache methods: none int8 int4")
parser.add_argument("--h2o", action="store_true", help="Enable H2O token dropping")
parser.add_argument("--h2o_keep", type=float, default=0.5, help="Fraction to keep (0..1)")
parser.add_argument("--sink", type=int, default=4, help="Sink tokens always kept")
parser.add_argument("--samples", type=int, default=20)
parser.add_argument("--ctx", type=int, default=1024)
parser.add_argument("--gen", type=int, default=256)
parser.add_argument("--warmup", type=int, default=2)
parser.add_argument("--out", type=str, default="./kv_sota_results")
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()
set_seed(args.seed)
log.info("Loading model with 4-bit weights + FlashAttention-2...")
tok, model = load_model_4bit(args.model)
bench = BenchCfg(eval_samples=args.samples, prefill_len=args.ctx, gen_len=args.gen, warmup=args.warmup)
df = run_benchmark(
tok, model, args.methods, use_h2o=args.h2o,
h2o_keep=args.h2o_keep, sink_tokens=args.sink, bench=bench
)
os.makedirs(args.out, exist_ok=True)
df.to_csv(os.path.join(args.out, "summary.csv"), index=False)
with open(os.path.join(args.out, "summary.json"), "w") as f:
json.dump(df.to_dict(orient="records"), f, indent=2)
print("\n" + "="*80)
print("BENCHMARK SUMMARY")
print("="*80)
print(df.to_string(index=False))
print(f"\n✅ Results saved in: {args.out}")
if __name__ == "__main__":
main()