"""Compare perplexity of causal language models over a shared text corpus."""

from typing import Final

import gradio as gr
import numpy as np
import torch
import ujson as json
from transformers import AutoModelForCausalLM, AutoTokenizer

# Run on GPU when available; bfloat16 halves memory with minimal quality loss.
_dev: Final = "cuda" if torch.cuda.is_available() else "cpu"
_dtype: Final = torch.bfloat16


def _perplexity(model, tokenizer, text) -> float:
    """Return the perplexity of `model` on `text` (truncated to 512 tokens).

    Perplexity is exp(mean next-token cross-entropy), so values start at
    1.0 (perfect prediction) and grow as the model gets more surprised.
    """
    encodings = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = encodings.input_ids.to(_dev)
    with torch.no_grad():
        # Passing the input as labels makes the model return the mean
        # next-token cross-entropy loss.
        outputs = model(input_ids, labels=input_ids)
    loss = outputs.loss.item()
    # BUG FIX: the original returned np.log(exp(loss)), which algebraically
    # is just the raw loss. Perplexity is exp(loss) — and only that makes
    # the caller's `1 < ppl < 1e4` outlier filter meaningful.
    return float(np.exp(loss))


def load_model(checkpoints: list[str]) -> dict:
    """Load each checkpoint's model (eval mode, on `_dev`) and tokenizer.

    Returns {checkpoint: {"model": ..., "tokenizer": ...}}.
    """
    loaded: dict = {}
    for ckpt in checkpoints:
        tokenizer = AutoTokenizer.from_pretrained(ckpt)
        # NOTE: the original combined device_map="auto" with a later
        # .to(_dev); moving an accelerate-dispatched model is redundant and
        # raises in recent transformers, so place the model explicitly.
        model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=_dtype)
        loaded[ckpt] = {"model": model.to(_dev).eval(), "tokenizer": tokenizer}
    return loaded


def log_perplexity(
    loaded: dict,
    num_samples: int,
    sample_length: int,
    progress=gr.Progress(),  # gradio convention: default instance wires up UI tracking
) -> dict:
    """Score every loaded model on up to `num_samples` texts from texts.json.

    Each text is stripped and truncated to `sample_length` characters;
    perplexities outside (1, 1e4) are discarded as degenerate.

    Returns {"ppls": per-model lists, "means": per-model mean, "stds": per-model std}.
    """
    ppls: dict[str, list] = {ckpt: [] for ckpt in loaded}

    # `with` guarantees the file is closed (the original leaked the handle).
    with open("texts.json", "r") as f:
        texts: list[str] = [
            text.strip()[:sample_length] for text in json.load(f) if text.strip()
        ]

    # Clamp so we never index past the available texts (the original raised
    # IndexError when num_samples > len(texts)).
    total = min(num_samples, len(texts))

    progress(0, desc="Starting")
    for i in range(total):
        progress(i / total, desc="Processing samples")
        for ckpt, info in loaded.items():
            ppl: float = _perplexity(info["model"], info["tokenizer"], texts[i])
            if 1 < ppl < 1e4:  # drop degenerate or exploding scores
                ppls[ckpt].append(ppl)

    # Guard empty lists so np.mean/np.std don't emit RuntimeWarnings; an
    # explicit NaN keeps the original "no data" result without the noise.
    means: dict = {
        ckpt: float(np.mean(vals)) if vals else float("nan")
        for ckpt, vals in ppls.items()
    }
    stds: dict = {
        ckpt: float(np.std(vals)) if vals else float("nan")
        for ckpt, vals in ppls.items()
    }
    return {"ppls": ppls, "means": means, "stds": stds}


if __name__ == "__main__":
    from pprint import pprint

    # Example usage: score a small model on 500 samples of 128 characters each.
    checkpoints = ["HuggingFaceTB/SmolLM2-135M"]
    loaded = load_model(checkpoints)
    num_samples = 500
    sample_length = 128
    pprint(log_perplexity(loaded, num_samples, sample_length))