File size: 2,463 Bytes
a9011a0 e35e3bc a9011a0 f1d3bf6 e35e3bc f1d3bf6 e35e3bc f1d3bf6 e35e3bc f1d3bf6 e35e3bc f1d3bf6 e35e3bc f1d3bf6 e35e3bc f1d3bf6 e35e3bc f1d3bf6 e35e3bc f1d3bf6 e35e3bc f1d3bf6 e35e3bc f1d3bf6 e35e3bc f1d3bf6 e35e3bc f1d3bf6 e35e3bc f1d3bf6 e35e3bc 5f5be0b f1d3bf6 e35e3bc f1d3bf6 e35e3bc f1d3bf6 e35e3bc f1d3bf6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import gradio as gr
from utils import load_model, log_perplexity
class Manager:
"""Class to manage model loading and perplexity calculation state."""
def __init__(self):
self.loaded = None
def load_models(self, checkpoint_input_str: str) -> str:
"""Load models from a comma-separated string of checkpoint names."""
checkpoints = [
ckpt.strip() for ckpt in checkpoint_input_str.split(",") if ckpt.strip()
]
if not checkpoints:
return "Please enter at least one model checkpoint name."
try:
self.loaded = load_model(checkpoints)
return "Models loaded successfully!"
except Exception as e:
return f"Model loading failed: {e}"
def perplexity(
self,
num_samples: int | None = None,
sample_length: int | None = None,
) -> dict | str:
"""Calculate perplexity using the loaded models."""
if self.loaded is None:
return "Please load models first."
if num_samples is None or sample_length is None:
return "Please set the number of samples and sample length."
try:
return log_perplexity(self.loaded, num_samples, sample_length)
except Exception as e:
return f"Perplexity calculation failed: {e}"
def make_interface() -> gr.Blocks:
    """Create and return the Gradio interface.

    The page has a checkpoint text box with a "Load Models" button, two
    numeric inputs with a "Compute PPLs" button, a status text box for the
    load result, and a JSON panel for the perplexity result. A single
    ``Manager`` instance created here backs both buttons.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    manager = Manager()

    def compute_ppls(num_samples, sample_length):
        """Adapt ``manager.perplexity`` for display in a JSON component."""
        result = manager.perplexity(num_samples, sample_length)
        # gr.JSON treats a ``str`` value as JSON text to be parsed, so a
        # plain status message would fail to decode. Wrap it in a dict so
        # the message is shown in the results panel instead.
        if isinstance(result, str):
            return {"error": result}
        return result

    with gr.Blocks() as demo:
        gr.Markdown("# LLM PPLs")

        checkpoints = gr.Textbox(
            label="Checkpoints", value="HuggingFaceTB/SmolLM2-135M"
        )
        load_btn = gr.Button("Load Models", variant="primary")

        with gr.Row():
            # precision=0 makes the component round to the nearest integer
            # and deliver an int, which is what a count/length should be.
            num_samples = gr.Number(
                label="Number of Samples", value=1500, precision=0
            )
            sample_length = gr.Number(
                label="Sample Length", value=128, precision=0
            )
        perplexity_btn = gr.Button("Compute PPLs")

        load_output = gr.Textbox(label="Model Loading Status")
        perplexity_output = gr.JSON(label="PPL Results")

        # Connect event handlers
        load_btn.click(
            fn=manager.load_models,
            inputs=checkpoints,
            outputs=load_output,
        )
        perplexity_btn.click(
            fn=compute_ppls,
            inputs=[num_samples, sample_length],
            outputs=perplexity_output,
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app, then start the Gradio server.
    # The app is bound to the module-level name ``demo`` before launching.
    demo = make_interface()
    demo.launch()
|