import torch
import json
import tiktoken

from your_model_module import GPT, GPTConfig  # assumes your GPT class lives in your_model_module.py

# Load configuration
with open("config.json", "r") as f:
    config_dict = json.load(f)
config = GPTConfig(**config_dict)

# Load model weights onto the CPU
model = GPT(config)
model.load_state_dict(torch.load("best_model_params.pt", map_location=torch.device("cpu")))
model.eval()  # disable dropout etc. for inference

# Load tokenizer
enc = tiktoken.get_encoding("gpt2")

def generate_text(prompt, max_new_tokens=200, temperature=1.0, top_k=None):
    # Encode the prompt and add a batch dimension: shape (1, T)
    context = torch.tensor(enc.encode_ordinary(prompt)).unsqueeze(dim=0)
    with torch.no_grad():
        generated_tokens = model.generate(context, max_new_tokens, temperature=temperature, top_k=top_k)
    # Drop the batch dimension and decode the token ids back to a string
    return enc.decode(generated_tokens.squeeze().tolist())

if __name__ == "__main__":
    prompt = input("Enter your prompt: ")
    generated_text = generate_text(prompt)
    print(generated_text)
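
# The script above assumes that GPT defines a generate() method with the
# signature used in generate_text(). If yours does not, the standalone sampler
# below is a minimal sketch of the usual nanoGPT-style autoregressive loop.
# It assumes the model's forward pass returns a (logits, loss) pair and that
# the config object carries a block_size attribute; both are assumptions about
# your_model_module, not guaranteed to match your implementation.
@torch.no_grad()
def sample_tokens(model, idx, max_new_tokens, temperature=1.0, top_k=None):
    """Autoregressively extend idx (shape (B, T)) by max_new_tokens tokens."""
    for _ in range(max_new_tokens):
        # Crop the context to the model's maximum sequence length (assumed attribute).
        idx_cond = idx if idx.size(1) <= model.config.block_size else idx[:, -model.config.block_size:]
        logits, _ = model(idx_cond)  # assumed to return (logits, loss)
        # Keep only the last position's logits and apply temperature scaling.
        logits = logits[:, -1, :] / temperature
        if top_k is not None:
            # Mask out everything below the k-th largest logit.
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float("inf")
        probs = torch.nn.functional.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)  # sample one token per sequence
        idx = torch.cat((idx, idx_next), dim=1)
    return idx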