from typing import Final
import gradio as gr
import numpy as np
import torch
import ujson as json  # third-party JSON parser with a json-compatible API
from transformers import AutoModelForCausalLM, AutoTokenizer

_dev: Final = "cuda" if torch.cuda.is_available() else "cpu"
_dtype: Final = torch.bfloat16


def _perplexity(model, tokenizer, text) -> float:
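    """Return the perplexity of `text` under `model`.

    Passing `input_ids` as `labels` makes Hugging Face causal LMs compute
    the shifted cross-entropy loss; perplexity is exp of that mean loss.
    """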
encodings = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
input_ids = encodings.input_ids.to(_dev)
with torch.no_grad():
outputs = model(input_ids, labels=input_ids)
loss = outputs.loss.item()
    return float(np.exp(loss))  # perplexity = exp(mean cross-entropy loss)


def load_model(checkpoints: list[str]) -> dict:
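    """Load a model and tokenizer for every checkpoint, keyed by checkpoint name."""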
tokenizers = [AutoTokenizer.from_pretrained(c) for c in checkpoints]
models = [
AutoModelForCausalLM.from_pretrained(c, device_map="auto", torch_dtype=_dtype)
for c in checkpoints
]
    # device_map="auto" has already placed the weights, so moving the model
    # again with .to() is unnecessary (and can fail for dispatched models);
    # just switch to eval mode.
    return {
        ckpt: {"model": model.eval(), "tokenizer": tokenizer}
for ckpt, model, tokenizer in zip(checkpoints, models, tokenizers)
}


def log_perplexity(
loaded: dict,
num_samples: int,
sample_length: int,
progress=gr.Progress(),
) -> dict:
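    """Score up to `num_samples` texts from texts.json with every loaded model.

    texts.json is expected to hold a JSON array of strings. Returns the kept
    per-sample perplexities plus their mean and standard deviation per checkpoint.
    """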
    # Collect per-sample perplexities for each checkpoint.
    ppls: dict[str, list] = {ckpt: [] for ckpt in loaded}
    # Load the evaluation texts, dropping empty entries and truncating each sample.
    with open("texts.json", "r") as f:
        texts: Final[list[str]] = [
            text.strip()[:sample_length] for text in json.load(f) if text.strip()
        ]
    progress(0, desc="Starting")
    # Never index past the number of available texts.
    for i in range(min(num_samples, len(texts))):
        progress(i / num_samples, desc="Processing samples")
for ckpt, info in loaded.items(): # Calculate perplexity for each model
ppl: float = _perplexity(info["model"], info["tokenizer"], texts[i])
if 1 < ppl < 1e4: # Filter out outliers
ppls[ckpt].append(ppl)
    # Summarize the perplexities per checkpoint; an empty list yields NaN.
    means: dict = {
        ckpt: float(np.mean(vals)) if vals else float("nan")
        for ckpt, vals in ppls.items()
    }
    stds: dict = {
        ckpt: float(np.std(vals)) if vals else float("nan")
        for ckpt, vals in ppls.items()
    }
return {"ppls": ppls, "means": means, "stds": stds}


if __name__ == "__main__":
from pprint import pprint
# Example usage
checkpoints = ["HuggingFaceTB/SmolLM2-135M"]
loaded = load_model(checkpoints)
num_samples = 500
sample_length = 128
pprint(log_perplexity(loaded, num_samples, sample_length))
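    # Illustrative output shape (the numbers here are hypothetical):
    # {'means': {'HuggingFaceTB/SmolLM2-135M': 21.7},
    #  'ppls': {'HuggingFaceTB/SmolLM2-135M': [19.3, 24.0, ...]},
    #  'stds': {'HuggingFaceTB/SmolLM2-135M': 5.2}}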