# llm-ppls / app.py
# Last commit e35e3bc by sthenno: "update(core): fix code"
import gradio as gr
from utils import load_model, log_perplexity
class Manager:
"""Class to manage model loading and perplexity calculation state."""
def __init__(self):
self.loaded = None
def load_models(self, checkpoint_input_str: str) -> str:
"""Load models from a comma-separated string of checkpoint names."""
checkpoints = [
ckpt.strip() for ckpt in checkpoint_input_str.split(",") if ckpt.strip()
]
if not checkpoints:
return "Please enter at least one model checkpoint name."
try:
self.loaded = load_model(checkpoints)
return "Models loaded successfully!"
except Exception as e:
return f"Model loading failed: {e}"
def perplexity(
self,
num_samples: int | None = None,
sample_length: int | None = None,
) -> dict | str:
"""Calculate perplexity using the loaded models."""
if self.loaded is None:
return "Please load models first."
if num_samples is None or sample_length is None:
return "Please set the number of samples and sample length."
try:
return log_perplexity(self.loaded, num_samples, sample_length)
except Exception as e:
return f"Perplexity calculation failed: {e}"
def make_interface() -> gr.Blocks:
    """Create and return the Gradio interface (not yet launched).

    The UI has a checkpoint textbox with a load button, plus two numeric inputs
    with a compute button. Both buttons are wired to methods of one ``Manager``
    created here, so every click handled by this app, from any session, acts
    on the same loaded-model state.

    Returns:
        The assembled ``gr.Blocks`` app; call ``.launch()`` to serve it.
    """
    manager = Manager()
    with gr.Blocks() as demo:
        gr.Markdown("# LLM PPLs")
        checkpoints = gr.Textbox(
            label="Checkpoints", value="HuggingFaceTB/SmolLM2-135M"
        )
        load_btn = gr.Button("Load Models", variant="primary")
        with gr.Row():
            # precision=0 makes Gradio round these and pass them as ints
            # (1500, not 1500.0); both values are counts.
            num_samples = gr.Number(
                label="Number of Samples", value=1500, precision=0
            )
            sample_length = gr.Number(
                label="Sample Length", value=128, precision=0
            )
        perplexity_btn = gr.Button("Compute PPLs")
        load_output = gr.Textbox(label="Model Loading Status")
        perplexity_output = gr.JSON(label="PPL Results")

        # Connect event handlers
        load_btn.click(
            fn=manager.load_models,
            inputs=checkpoints,
            outputs=load_output,
        )
        perplexity_btn.click(
            fn=manager.perplexity,
            inputs=[num_samples, sample_length],
            outputs=perplexity_output,
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app (kept bound to the module-level
    # name `demo`) and start serving it.
    (demo := make_interface()).launch()