KingNish commited on
Commit
5f4bbff
·
verified ·
1 Parent(s): 28affd1

Upload run_inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. run_inference.py +29 -0
run_inference.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import json
4
+ import tiktoken
5
+ from your_model_module import GPT, GPTConfig # Assuming your GPT class is in a file called your_model_module.py
6
+
7
+ # Load configuration
8
+ with open("config.json", "r") as f:
9
+ config_dict = json.load(f)
10
+ config = GPTConfig(**config_dict)
11
+
12
+ # Load model
13
+ model = GPT(config)
14
+ model.load_state_dict(torch.load("best_model_params.pt", map_location=torch.device("cpu"))) # Load to CPU
15
+ model.eval()
16
+
17
+ # Load tokenizer
18
+ enc = tiktoken.get_encoding("gpt2")
19
+
20
+ def generate_text(prompt, max_new_tokens=200, temperature=1.0, top_k=None):
21
+ context = torch.tensor(enc.encode_ordinary(prompt)).unsqueeze(dim=0)
22
+ with torch.no_grad():
23
+ generated_tokens = model.generate(context, max_new_tokens, temperature=temperature, top_k=top_k)
24
+ return enc.decode(generated_tokens.squeeze().tolist())
25
+
26
+ if __name__ == "__main__":
27
+ prompt = input("Enter your prompt: ")
28
+ generated_text = generate_text(prompt)
29
+ print(generated_text)