import torch
import json
import tiktoken

from your_model_module import GPT, GPTConfig  # assumes GPT and GPTConfig are defined in your_model_module.py

# Load configuration
with open("config.json", "r") as f:
    config_dict = json.load(f)
config = GPTConfig(**config_dict)

# Load model
model = GPT(config)
model.load_state_dict(torch.load("best_model_params.pt", map_location=torch.device("cpu")))  # load weights onto CPU
model.eval()  # inference mode: disables dropout
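# Optional: if a GPU is available, model.to("cuda") would move the model there;
# the context tensor built in generate_text would then need to live on the same
# device. This script keeps everything on CPU.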

# Load tokenizer
enc = tiktoken.get_encoding("gpt2")
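# This assumes the checkpoint was trained on text tokenized with the GPT-2 BPE
# vocabulary; a model trained with a different tokenizer would need the matching
# encoding name here.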


def generate_text(prompt, max_new_tokens=200, temperature=1.0, top_k=None):
    """Encode the prompt, sample a continuation from the model, and decode it back to text."""
    # encode_ordinary skips special tokens such as <|endoftext|>
    context = torch.tensor(enc.encode_ordinary(prompt), dtype=torch.long).unsqueeze(dim=0)
    with torch.no_grad():
        generated_tokens = model.generate(context, max_new_tokens, temperature=temperature, top_k=top_k)
    return enc.decode(generated_tokens.squeeze().tolist())
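# Sampling knobs (hypothetical values, not from the original script):
# temperature < 1.0 sharpens the distribution toward likely tokens, and top_k
# restricts sampling to the k most probable tokens at each step, e.g.:
#   generate_text("Once upon a time", max_new_tokens=100, temperature=0.8, top_k=50)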


if __name__ == "__main__":
    prompt = input("Enter your prompt: ")
    generated_text = generate_text(prompt)
    print(generated_text)