|
import torch |
|
import gradio as gr |
|
from sm_model_train import SmolLMConfig, tokenizer, SmolLM |
|
|
|
|
|
|
|
def load_model():
    """Build a SmolLM model and load its pretrained weights from disk.

    Reads the state dict from ``sm_model.pt`` mapped onto the CPU, so the
    function works on machines without a GPU, and puts the model in
    evaluation mode before returning it.

    Returns:
        SmolLM: the model with weights loaded, in eval mode.

    Raises:
        FileNotFoundError: if ``sm_model.pt`` is not present.
        RuntimeError: if the checkpoint does not match the model's layout.
    """
    config = SmolLMConfig()
    model = SmolLM(config)

    # weights_only=True restricts unpickling to tensors/primitive containers,
    # guarding against arbitrary code execution from a tampered checkpoint.
    # (It is also the default in recent PyTorch releases.)
    state_dict = torch.load("sm_model.pt", map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)

    # Disable dropout / train-time behavior for inference.
    model.eval()
    return model
|
|
|
|
|
def generate_text(prompt, max_tokens, temperature=0.8, top_k=40):
    """Continue *prompt* with up to *max_tokens* newly sampled tokens.

    Args:
        prompt: Seed text to condition the generation on.
        max_tokens: Number of new tokens to sample after the prompt.
        temperature: Softmax temperature used when sampling.
        top_k: Restrict sampling to the k most probable tokens.

    Returns:
        The decoded continuation (prompt included), or an error message
        string so the Gradio UI degrades gracefully instead of crashing.
    """
    try:
        # Tokenize the prompt and move it to whatever device the
        # module-level model currently lives on.
        model_device = next(model.parameters()).device
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model_device)

        # Pure sampling — no gradients required.
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_k=top_k,
            )

        return tokenizer.decode(output_ids[0].tolist())
    except Exception as exc:
        # Surface the failure in the output textbox rather than raising.
        return f"An error occurred: {str(exc)}"
|
|
|
|
|
|
|
# Instantiate the model once at import time so every request reuses it.
model = load_model()

# --- Gradio UI components --------------------------------------------------

prompt_box = gr.Textbox(
    label="Enter your prompt", placeholder="Once upon a time...", lines=3
)
token_slider = gr.Slider(
    minimum=50,
    maximum=500,
    value=100,
    step=10,
    label="Maximum number of tokens",
)
output_box = gr.Textbox(label="Generated Text", lines=10)

# Wire the components to the generation function.
demo = gr.Interface(
    fn=generate_text,
    inputs=[prompt_box, token_slider],
    outputs=output_box,
    title="SmolLM2-135TextGenerator",
    description="Enter Prompt for the model to continue.",
    examples=[
        ["Once upon a time", 100],
        ["The future of AI is", 200],
        ["In a galaxy far far away", 150],
    ],
)

if __name__ == "__main__":
    demo.launch()
|
|