from typing import Final

import gradio as gr
import numpy as np
import torch
import ujson as json
from transformers import AutoModelForCausalLM, AutoTokenizer

_dev: Final = "cuda" if torch.cuda.is_available() else "cpu"
_dtype: Final = torch.bfloat16


def _perplexity(model, tokenizer, text: str) -> float:
    """Perplexity of `text` under `model`: exp of the mean per-token loss."""
    encodings = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = encodings.input_ids.to(_dev)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    # Perplexity is the exponential of the mean per-token cross-entropy loss;
    # note that log(exp(loss)) would simply return the loss again.
    return torch.exp(outputs.loss).item()
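# For reference: perplexity is the exponential of the mean per-token negative
# log-likelihood, ppl = exp(loss). A mean token loss of 3.0 thus corresponds
# to a perplexity of e**3 ≈ 20.09, i.e. the model is roughly as uncertain as
# a uniform choice over ~20 tokens at each step.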


def load_model(checkpoints: list[str]) -> dict:
    """Load model and tokenizer for each checkpoint, keyed by checkpoint name."""
    tokenizers = [AutoTokenizer.from_pretrained(c) for c in checkpoints]
    # Place models on `_dev` explicitly; combining `device_map="auto"` with a
    # later `.to()` call is unsupported for accelerate-dispatched models.
    models = [
        AutoModelForCausalLM.from_pretrained(c, torch_dtype=_dtype)
        for c in checkpoints
    ]
    return {
        ckpt: {"model": model.to(_dev).eval(), "tokenizer": tokenizer}
        for ckpt, model, tokenizer in zip(checkpoints, models, tokenizers)
    }


def log_perplexity(
    loaded: dict,
    num_samples: int,
    sample_length: int,
    progress=gr.Progress(),
) -> dict:
    """Collect per-checkpoint perplexities over samples drawn from texts.json."""
    ppls: dict[str, list[float]] = {ckpt: [] for ckpt in loaded}

    # Character-level truncation here; token-level truncation happens inside
    # `_perplexity` (max_length=512).
    with open("texts.json", "r") as f:
        texts: list[str] = [
            text.strip()[:sample_length] for text in json.load(f) if text.strip()
        ]

    progress(0, desc="Starting")
    # Don't index past the end of the corpus.
    num_samples = min(num_samples, len(texts))
    for i in range(num_samples):
        progress(i / num_samples, desc="Processing samples")
        for ckpt, info in loaded.items():
            ppl: float = _perplexity(info["model"], info["tokenizer"], texts[i])
            # Keep only plausible values: perplexity is >= 1 by definition, and
            # very large values usually indicate a degenerate sample.
            if 1 < ppl < 1e4:
                ppls[ckpt].append(ppl)

    # Guard the empty case: np.mean/np.std of [] warn and return nan.
    means: dict = {c: float(np.mean(p)) if p else float("nan") for c, p in ppls.items()}
    stds: dict = {c: float(np.std(p)) if p else float("nan") for c, p in ppls.items()}
    return {"ppls": ppls, "means": means, "stds": stds}


if __name__ == "__main__":
    from pprint import pprint

    checkpoints = ["HuggingFaceTB/SmolLM2-135M"]
    loaded = load_model(checkpoints)
    num_samples = 500
    sample_length = 128
    pprint(log_perplexity(loaded, num_samples, sample_length))