# llm-ppls / utils.py
# Last change: update(core): fix code (commit e35e3bc, by sthenno)
from typing import Final
import gradio as gr
import numpy as np
import torch
import ujson as json
from transformers import AutoModelForCausalLM, AutoTokenizer
# Inference device: prefer CUDA when available, otherwise fall back to CPU.
_dev: Final = "cuda" if torch.cuda.is_available() else "cpu"
# Load model weights in bfloat16 to halve memory versus float32.
_dtype: Final = torch.bfloat16
def _perplexity(model, tokenizer, text) -> float:
    """Return the perplexity of ``model`` on ``text``.

    Args:
        model: A causal LM called as ``model(input_ids, labels=input_ids)``
            returning an object with a ``.loss`` attribute (mean cross-entropy).
        tokenizer: Matching tokenizer; invoked with ``return_tensors="pt"``
            and truncation to 512 tokens.
        text: The string to score.

    Returns:
        ``exp(mean cross-entropy loss)`` — the perplexity.
    """
    encodings = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = encodings.input_ids.to(_dev)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    loss = outputs.loss.item()
    # FIX: the original returned np.log(torch.exp(loss)) == loss, i.e. the
    # log-perplexity, not the perplexity — and the float32 torch.exp round-trip
    # could overflow to inf. Perplexity is exp(mean NLL); np.exp on a Python
    # float computes in float64 and matches the caller's 1 < ppl < 1e4 filter.
    return float(np.exp(loss))
def load_model(checkpoints: list[str]) -> dict:
    """Load a causal LM and tokenizer for every checkpoint.

    Args:
        checkpoints: Hugging Face model ids or local paths.

    Returns:
        Mapping of checkpoint name to ``{"model": ..., "tokenizer": ...}``,
        with each model in eval mode.
    """
    tokenizers = [AutoTokenizer.from_pretrained(c) for c in checkpoints]
    models = [
        AutoModelForCausalLM.from_pretrained(c, device_map="auto", torch_dtype=_dtype)
        for c in checkpoints
    ]
    # FIX: `device_map="auto"` already places the weights via accelerate;
    # calling `.to(_dev)` on a dispatched model is redundant on a single
    # device and raises RuntimeError when the model is sharded across devices.
    return {
        ckpt: {"model": model.eval(), "tokenizer": tokenizer}
        for ckpt, model, tokenizer in zip(checkpoints, models, tokenizers)
    }
def log_perplexity(
    loaded: dict,
    num_samples: int,
    sample_length: int,
    progress=gr.Progress(),
) -> dict:
    """Score every loaded model on samples from ``texts.json``.

    Args:
        loaded: Output of ``load_model``: checkpoint -> {"model", "tokenizer"}.
        num_samples: Number of text samples to evaluate; clamped to the number
            of available texts so it can no longer raise IndexError.
        sample_length: Maximum number of characters kept from each text.
        progress: Gradio progress reporter for the UI.

    Returns:
        ``{"ppls": per-model per-sample perplexities,
           "means": per-model mean, "stds": per-model standard deviation}``.
    """
    # One perplexity list per checkpoint.
    ppls: dict[str, list] = {ckpt: [] for ckpt in loaded}
    # FIX: close the file deterministically instead of leaking the handle.
    # (Dropped the misused `Final` annotation — it is meant for module/class
    # attributes, not locals.)
    with open("texts.json", "r") as f:
        texts: list[str] = [
            text.strip()[:sample_length] for text in json.load(f) if text.strip()
        ]
    progress(0, desc="Starting")
    # FIX: clamp so requesting more samples than available texts no longer
    # raises IndexError on texts[i].
    for i in range(min(num_samples, len(texts))):
        progress(i / num_samples, desc="Processing samples")
        for ckpt, info in loaded.items():  # Calculate perplexity for each model
            ppl: float = _perplexity(info["model"], info["tokenizer"], texts[i])
            if 1 < ppl < 1e4:  # Filter out outliers
                ppls[ckpt].append(ppl)
    # Per-model mean and standard deviation of the kept perplexities.
    means: dict = {ckpt: np.mean(vals) for ckpt, vals in ppls.items()}
    stds: dict = {ckpt: np.std(vals) for ckpt, vals in ppls.items()}
    return {"ppls": ppls, "means": means, "stds": stds}
if __name__ == "__main__":
    from pprint import pprint

    # Demo run: score one small model on the bundled texts and print the stats.
    demo_checkpoints = ["HuggingFaceTB/SmolLM2-135M"]
    pprint(
        log_perplexity(
            loaded=load_model(demo_checkpoints),
            num_samples=500,
            sample_length=128,
        )
    )