# smollm2/app.py
import glob
import os

import gradio as gr
import torch
import yaml
from transformers import AutoTokenizer, AutoModelForCausalLM

from model import SmolLMModule

# Load config
with open("config_smollm2_135.yaml", "r") as file:
    config = yaml.safe_load(file)

# Load tokenizer and use the EOS token for padding
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
tokenizer.pad_token = tokenizer.eos_token
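
# NOTE (assumption): config_smollm2_135.yaml is expected to carry the SmolLM2-135M
# hyperparameters that SmolLMModule consumes; the parsed dict is passed unchanged to
# every model constructed below.
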
def get_available_checkpoints():
    """Get list of available checkpoints and the final model"""
    models = []
    model_paths = {}

    # Get checkpoints
    checkpoints = glob.glob("checkpoints/*.ckpt")
    for ckpt in checkpoints:
        try:
            # Extract step number from the filename
            filename = os.path.basename(ckpt)
            # Handle the format 'model-step=step=X.ckpt'
            if "step=step=" in filename:
                step = int(filename.split("step=step=")[1].split(".")[0])
                display_name = f"Checkpoint Step {step}"
                models.append(display_name)
                model_paths[display_name] = ckpt
        except (ValueError, IndexError) as e:
            print(
                f"Warning: Could not parse checkpoint filename: {filename}, Error: {e}"
            )
            continue
    # Add final model if it exists
    final_model_path = "final_model"
    if os.path.exists(final_model_path):
        display_name = "Final Model"
        models.append(display_name)
        model_paths[display_name] = final_model_path

    # Sort checkpoints by step number (the final model sorts last)
    def get_step_number(name):
        if name == "Final Model":
            return float("inf")
        try:
            return int(name.split("Step ")[-1])
        except (ValueError, IndexError):
            return 0

    models.sort(key=get_step_number)
    if not models:
        print(
            "Warning: No checkpoints or final model found in the following locations:"
        )
        print("- Checkpoints directory:", os.path.abspath("checkpoints"))
        print("- Final model directory:", os.path.abspath("final_model"))
    else:
        print(f"Found {len(models)} models:")
        for model in models:
            print(f"- {model}: {model_paths[model]}")

    return models, model_paths
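
# Illustrative example (hypothetical filenames): with
#   checkpoints/model-step=step=5000.ckpt
#   checkpoints/model-step=step=10000.ckpt
# and a ./final_model directory on disk, get_available_checkpoints() would return
# (["Checkpoint Step 5000", "Checkpoint Step 10000", "Final Model"], {...}),
# with each display name mapped back to its on-disk path.
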
def load_model_from_checkpoint(model_path):
    """Load model from a checkpoint or the final model directory"""
    if model_path == "final_model":
        # Load the final saved model
        model = SmolLMModule(config)
        model.model = AutoModelForCausalLM.from_pretrained(model_path)
    else:
        # Load from checkpoint
        model = SmolLMModule.load_from_checkpoint(model_path, config=config)
    model.eval()  # Set to evaluation mode
    return model
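
# NOTE (assumption): SmolLMModule is assumed to be a PyTorch Lightning module, so
# load_from_checkpoint() restores the trained weights and forwards the extra
# config=config kwarg to the module constructor; the wrapped Hugging Face model is
# then reachable as model.model, which is what generate_text() uses below.
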
def generate_text(prompt, model_choice, max_length=100, temperature=0.7, top_p=0.9):
    """Generate text based on the prompt using the selected model"""
    # Check that a model is selected and a prompt was given
    if not model_choice:
        return "Please select a model checkpoint!"
    if not prompt:
        return "Please enter a prompt!"

    try:
        # Get the model path from the mapping
        _, model_paths = get_available_checkpoints()
        model_path = model_paths.get(model_choice)
        if not model_path or not os.path.exists(model_path):
            return f"Model {model_choice} not found!"

        # Load model
        model = load_model_from_checkpoint(model_path)

        # Move model to GPU if available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)

        # Tokenize input and move it to the same device as the model
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # Generate (sampling must be enabled for temperature/top_p to take effect)
        with torch.no_grad():
            outputs = model.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=max_length,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decode and return the generated text
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_text

    except Exception as e:
        return f"Error during generation: {str(e)}"
# Get available models
display_names, _ = get_available_checkpoints()
# Create Gradio interface
with gr.Blocks(title="SmolLM2 Inference") as demo:
    gr.Markdown("# SmolLM2 Text Generation")

    if not display_names:
        gr.Markdown("⚠️ No models found! Please train the model first.")
    else:
        gr.Markdown(
            f"Found {len(display_names)} models/checkpoints. Select one and enter a prompt to generate text."
        )
        gr.Markdown("Available models: " + ", ".join(display_names))

    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=display_names,
                label="Select Model",
                value=display_names[-1] if display_names else None,
                interactive=True,
            )
            prompt = gr.Textbox(
                lines=3, placeholder="Enter your prompt here...", label="Input Prompt"
            )
            max_length = gr.Slider(
                minimum=10, maximum=500, value=100, step=10, label="Max Length"
            )
            temperature = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"
            )
            top_p = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p"
            )
            generate_btn = gr.Button("Generate")

        with gr.Column():
            output = gr.Textbox(lines=8, label="Generated Text")
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt, model_dropdown, max_length, temperature, top_p],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch(share=True)
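
# NOTE: share=True asks Gradio for a temporary public link in addition to the local
# server; when the app runs on Hugging Face Spaces this is not needed, and a plain
# demo.launch() is typically sufficient there.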