import gradio as gr

from utils import load_model, log_perplexity


class Manager:
    """Hold the loaded models and expose load / perplexity actions for the UI.

    Both public methods catch exceptions from the underlying helpers and
    return a human-readable message instead of raising.
    """

    def __init__(self):
        # Whatever ``load_model`` returned on the last successful load;
        # ``None`` until a load has succeeded.
        self.loaded = None

    def load_models(self, checkpoint_input_str: str) -> str:
        """Load models from a comma-separated string of checkpoint names.

        Args:
            checkpoint_input_str: Checkpoint names separated by commas.
                Surrounding whitespace is stripped and empty entries are
                ignored.

        Returns:
            A status message describing success, failure, or missing input.
        """
        checkpoints = [
            ckpt.strip()
            for ckpt in checkpoint_input_str.split(",")
            if ckpt.strip()
        ]
        if not checkpoints:
            return "Please enter at least one model checkpoint name."

        try:
            self.loaded = load_model(checkpoints)
        except Exception as e:  # UI boundary: report the failure, don't crash.
            return f"Model loading failed: {e}"
        return "Models loaded successfully!"

    def perplexity(
        self,
        num_samples: int | None = None,
        sample_length: int | None = None,
    ) -> dict | str:
        """Calculate perplexity using the loaded models.

        Args:
            num_samples: Number of samples forwarded to ``log_perplexity``.
            sample_length: Sample length forwarded to ``log_perplexity``.

        Returns:
            The result of ``log_perplexity`` on success, otherwise a string
            explaining what is missing or what went wrong.
        """
        if self.loaded is None:
            return "Please load models first."
        if num_samples is None or sample_length is None:
            return "Please set the number of samples and sample length."

        try:
            return log_perplexity(self.loaded, num_samples, sample_length)
        except Exception as e:  # UI boundary: report the failure, don't crash.
            return f"Perplexity calculation failed: {e}"


def make_interface() -> gr.Blocks:
    """Create and return the Gradio interface.

    Returns:
        A ``gr.Blocks`` app with a checkpoint textbox, a load button, two
        numeric inputs, and a button that computes perplexities.
    """
    manager = Manager()

    def compute_ppls(num_samples, sample_length):
        """Run the perplexity computation and surface failures as UI errors."""
        result = manager.perplexity(num_samples, sample_length)
        # ``Manager.perplexity`` signals problems with a plain string. The
        # ``gr.JSON`` output tries to parse string values as JSON, so such a
        # message would fail in postprocessing; raise ``gr.Error`` so the
        # user sees the message instead.
        if isinstance(result, str):
            raise gr.Error(result)
        return result

    with gr.Blocks() as demo:
        gr.Markdown("# LLM PPLs")

        checkpoints = gr.Textbox(
            label="Checkpoints", value="HuggingFaceTB/SmolLM2-135M"
        )
        load_btn = gr.Button("Load Models", variant="primary")

        with gr.Row():
            # precision=0 makes Gradio round the value and hand the callback
            # an int rather than a float.
            num_samples = gr.Number(
                label="Number of Samples", value=1500, precision=0
            )
            sample_length = gr.Number(
                label="Sample Length", value=128, precision=0
            )

        perplexity_btn = gr.Button("Compute PPLs")

        load_output = gr.Textbox(label="Model Loading Status")
        perplexity_output = gr.JSON(label="PPL Results")

        # Connect event handlers
        load_btn.click(
            fn=manager.load_models,
            inputs=checkpoints,
            outputs=load_output,
        )
        perplexity_btn.click(
            fn=compute_ppls,
            inputs=[num_samples, sample_length],
            outputs=perplexity_output,
        )

    return demo


if __name__ == "__main__":
    demo = make_interface()
    demo.launch()