File size: 2,463 Bytes
a9011a0
 
e35e3bc
a9011a0
f1d3bf6
e35e3bc
f1d3bf6
 
 
e35e3bc
f1d3bf6
 
 
e35e3bc
 
f1d3bf6
 
e35e3bc
f1d3bf6
 
 
e35e3bc
f1d3bf6
 
 
 
e35e3bc
 
 
 
 
f1d3bf6
e35e3bc
f1d3bf6
e35e3bc
 
f1d3bf6
e35e3bc
f1d3bf6
 
 
 
e35e3bc
f1d3bf6
e35e3bc
f1d3bf6
 
e35e3bc
f1d3bf6
e35e3bc
 
f1d3bf6
 
 
 
e35e3bc
 
 
 
 
 
 
5f5be0b
f1d3bf6
 
 
e35e3bc
 
 
f1d3bf6
 
e35e3bc
 
 
 
 
f1d3bf6
 
 
 
 
e35e3bc
f1d3bf6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr

from utils import load_model, log_perplexity


class Manager:
    """Class to manage model loading and perplexity calculation state."""

    def __init__(self):
        self.loaded = None

    def load_models(self, checkpoint_input_str: str) -> str:
        """Load models from a comma-separated string of checkpoint names."""
        checkpoints = [
            ckpt.strip() for ckpt in checkpoint_input_str.split(",") if ckpt.strip()
        ]

        if not checkpoints:
            return "Please enter at least one model checkpoint name."

        try:
            self.loaded = load_model(checkpoints)
            return "Models loaded successfully!"
        except Exception as e:
            return f"Model loading failed: {e}"

    def perplexity(
        self,
        num_samples: int | None = None,
        sample_length: int | None = None,
    ) -> dict | str:
        """Calculate perplexity using the loaded models."""
        if self.loaded is None:
            return "Please load models first."
        if num_samples is None or sample_length is None:
            return "Please set the number of samples and sample length."
        try:
            return log_perplexity(self.loaded, num_samples, sample_length)
        except Exception as e:
            return f"Perplexity calculation failed: {e}"


def make_interface() -> gr.Blocks:
    """Create and return the Gradio interface.

    The layout consists of a checkpoint textbox with a load button, two
    numeric inputs with a compute button, a status textbox and a JSON
    result panel. Every callback goes through the single ``Manager``
    created here.

    Returns:
        The assembled, not yet launched, ``gr.Blocks`` app.
    """
    manager = Manager()

    def compute_perplexity(num_samples, sample_length):
        """Adapt ``Manager.perplexity`` output for the JSON component."""
        result = manager.perplexity(num_samples, sample_length)
        # ``gr.JSON`` interprets a ``str`` as serialized JSON and tries to
        # parse it, so a plain status message would fail to render. Wrap
        # it in a dict so it is displayed instead.
        if isinstance(result, str):
            return {"error": result}
        return result

    with gr.Blocks() as demo:
        gr.Markdown("# LLM PPLs")

        checkpoints = gr.Textbox(
            label="Checkpoints", value="HuggingFaceTB/SmolLM2-135M"
        )

        load_btn = gr.Button("Load Models", variant="primary")

        with gr.Row():
            # precision=0 makes the component round to whole numbers and
            # hand integers to the callback.
            num_samples = gr.Number(
                label="Number of Samples", value=1500, precision=0
            )
            sample_length = gr.Number(
                label="Sample Length", value=128, precision=0
            )

        perplexity_btn = gr.Button("Compute PPLs")

        load_output = gr.Textbox(label="Model Loading Status")
        perplexity_output = gr.JSON(label="PPL Results")

        # Connect event handlers
        load_btn.click(
            fn=manager.load_models,
            inputs=checkpoints,
            outputs=load_output,
        )

        perplexity_btn.click(
            fn=compute_perplexity,
            inputs=[num_samples, sample_length],
            outputs=perplexity_output,
        )

    return demo


if __name__ == "__main__":
    # Script entry point: build the interface and launch it. Nothing is
    # launched when this module is merely imported.
    demo = make_interface()
    demo.launch()