|
import gradio as gr |
|
|
|
from utils import load_model, log_perplexity |
|
|
|
|
|
class Manager: |
|
"""Class to manage model loading and perplexity calculation state.""" |
|
|
|
def __init__(self): |
|
self.loaded = None |
|
|
|
def load_models(self, checkpoint_input_str: str) -> str: |
|
"""Load models from a comma-separated string of checkpoint names.""" |
|
checkpoints = [ |
|
ckpt.strip() for ckpt in checkpoint_input_str.split(",") if ckpt.strip() |
|
] |
|
|
|
if not checkpoints: |
|
return "Please enter at least one model checkpoint name." |
|
|
|
try: |
|
self.loaded = load_model(checkpoints) |
|
return "Models loaded successfully!" |
|
except Exception as e: |
|
return f"Model loading failed: {e}" |
|
|
|
def perplexity( |
|
self, |
|
num_samples: int | None = None, |
|
sample_length: int | None = None, |
|
) -> dict | str: |
|
"""Calculate perplexity using the loaded models.""" |
|
if self.loaded is None: |
|
return "Please load models first." |
|
if num_samples is None or sample_length is None: |
|
return "Please set the number of samples and sample length." |
|
try: |
|
return log_perplexity(self.loaded, num_samples, sample_length) |
|
except Exception as e: |
|
return f"Perplexity calculation failed: {e}" |
|
|
|
|
|
def make_interface() -> gr.Blocks:
    """Create and return the Gradio interface.

    The page holds a checkpoint textbox with a "Load Models" button, two
    numeric inputs (number of samples and sample length) with a
    "Compute PPLs" button, a textbox for the loading status, and a JSON
    panel for the perplexity results. Both buttons operate on one
    ``Manager`` instance created here, so it is shared by every callback
    invocation of the returned app.

    Returns:
        The assembled ``gr.Blocks`` app, not yet launched.
    """
    manager = Manager()

    def compute_perplexity(num_samples, sample_length):
        """Call ``manager.perplexity`` and make the result safe for ``gr.JSON``."""
        result = manager.perplexity(num_samples, sample_length)
        # gr.JSON treats a ``str`` value as serialized JSON and tries to parse
        # it. The manager's plain-sentence messages ("Please load models
        # first." etc.) are not valid JSON, so wrap them in a dict to have
        # them displayed instead of triggering a decode error.
        if isinstance(result, str):
            return {"error": result}
        return result

    with gr.Blocks() as demo:
        gr.Markdown("# LLM PPLs")

        checkpoints = gr.Textbox(
            label="Checkpoints",
            value="HuggingFaceTB/SmolLM2-135M",
        )
        load_btn = gr.Button("Load Models", variant="primary")

        with gr.Row():
            # precision=0 makes Gradio round the entry and hand the callback
            # an ``int``, matching the ``int`` annotations on
            # ``Manager.perplexity``.
            num_samples = gr.Number(
                label="Number of Samples", value=1500, precision=0
            )
            sample_length = gr.Number(
                label="Sample Length", value=128, precision=0
            )

        perplexity_btn = gr.Button("Compute PPLs")

        load_output = gr.Textbox(label="Model Loading Status")
        perplexity_output = gr.JSON(label="PPL Results")

        load_btn.click(
            fn=manager.load_models,
            inputs=checkpoints,
            outputs=load_output,
        )
        perplexity_btn.click(
            fn=compute_perplexity,
            inputs=[num_samples, sample_length],
            outputs=perplexity_output,
        )

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Runs only when the file is executed as a script: assemble the UI, then
    # start serving it with Gradio's default launch settings.
    demo = make_interface()
    demo.launch()
|
|