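"""Compare causal language-model checkpoints by perplexity.

Loads one or more Hugging Face checkpoints, scores each one on text samples
from `texts.json`, and reports per-checkpoint perplexity lists, means, and
standard deviations, with a Gradio progress bar for use inside an app.
"""
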
from typing import Final

import gradio as gr
import numpy as np
import torch
import ujson as json
from transformers import AutoModelForCausalLM, AutoTokenizer

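# Prefer the GPU when one is available; bfloat16 halves memory relative to
# float32 while keeping the same exponent range.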
_dev: Final = "cuda" if torch.cuda.is_available() else "cpu"
_dtype: Final = torch.bfloat16


def _perplexity(model, tokenizer, text: str) -> float:
    """Return the perplexity of `text` under `model` (exp of the mean token loss)."""
    encodings = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = encodings.input_ids.to(_dev)
    with torch.no_grad():
        # Passing the inputs as labels makes the model return the mean
        # cross-entropy loss over the sequence.
        outputs = model(input_ids, labels=input_ids)
        loss = outputs.loss.item()
    # Perplexity is the exponential of the mean negative log-likelihood.
    return float(np.exp(loss))


def load_model(checkpoints: list[str]) -> dict:
    """Load the model and tokenizer for each checkpoint, keyed by checkpoint name."""
    tokenizers = [AutoTokenizer.from_pretrained(c) for c in checkpoints]
    # Place each model on `_dev` explicitly; combining `device_map="auto"` with
    # a later `.to()` call conflicts with accelerate's dispatching.
    models = [
        AutoModelForCausalLM.from_pretrained(c, torch_dtype=_dtype).to(_dev)
        for c in checkpoints
    ]

    # Pair each checkpoint with its model (in eval mode) and tokenizer.
    return {
        ckpt: {"model": model.eval(), "tokenizer": tokenizer}
        for ckpt, model, tokenizer in zip(checkpoints, models, tokenizers)
    }


def log_perplexity(
    loaded: dict,
    num_samples: int,
    sample_length: int,
    progress=gr.Progress(),
) -> dict:
    """Score every loaded checkpoint on text samples and summarize the results."""
    # Collect one list of perplexity scores per checkpoint.
    ppls: dict[str, list] = {ckpt: [] for ckpt in loaded}

    # Load the evaluation texts, trimming each sample to `sample_length` characters.
    with open("texts.json", "r") as f:
        texts: list[str] = [
            text.strip()[:sample_length] for text in json.load(f) if text.strip()
        ]

    # Score each sample with every model, clamping to the number of available texts.
    num_samples = min(num_samples, len(texts))
    progress(0, desc="Starting")
    for i in range(num_samples):
        progress(i / num_samples, desc="Processing samples")
        for ckpt, info in loaded.items():  # Calculate perplexity for each model
            ppl: float = _perplexity(info["model"], info["tokenizer"], texts[i])
            if 1 < ppl < 1e4:  # Filter out outliers
                ppls[ckpt].append(ppl)

    # Summarize each checkpoint's scores with their mean and standard deviation.
    means: dict = {ckpt: np.mean(ppl) for ckpt, ppl in ppls.items()}
    stds: dict = {ckpt: np.std(ppl) for ckpt, ppl in ppls.items()}

    return {"ppls": ppls, "means": means, "stds": stds}


if __name__ == "__main__":
    from pprint import pprint

    # Example usage
    checkpoints = ["HuggingFaceTB/SmolLM2-135M"]
    loaded = load_model(checkpoints)
    num_samples = 500
    sample_length = 128
    pprint(log_perplexity(loaded, num_samples, sample_length))
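
# A minimal sketch of how this could be exposed as a Gradio app; the wrapper
# function and component choices below are assumptions, not part of this script.
#
# def run(checkpoint: str, n: int, length: int, progress=gr.Progress()):
#     return log_perplexity(load_model([checkpoint]), int(n), int(length), progress)
#
# demo = gr.Interface(
#     fn=run,
#     inputs=[
#         gr.Textbox(label="Checkpoint"),
#         gr.Number(value=500, label="Number of samples"),
#         gr.Number(value=128, label="Sample length (characters)"),
#     ],
#     outputs=gr.JSON(label="Perplexity statistics"),
# )
# demo.launch()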