abancp committed
Commit 7e1aa1c · verified · 1 Parent(s): 17a3eb0

Update inference_fine_tune.py

Files changed (1): inference_fine_tune.py (+9 -95)
inference_fine_tune.py CHANGED
@@ -7,6 +7,15 @@ from pathlib import Path
 from config import get_config, get_weights_file_path
 from train import get_model
 
+def get_tokenizer(config)->Tokenizer:
+    tokenizers_path = Path(config['tokenizer_file'])
+    if Path.exists(tokenizers_path):
+        print("Loading tokenizer from ", tokenizers_path)
+        tokenizer = Tokenizer.from_file(str(tokenizers_path))
+        return tokenizer
+    else:
+        raise FileNotFoundError("Cant find tokenizer file : ",tokenizers_path)
+
 
 config = get_config("./openweb.config.json")
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -22,101 +31,6 @@ model.eval()
 state = torch.load(model_path,map_location=torch.device('cpu'))
 model.load_state_dict(state['model_state_dict'])
 
-def generate_text(
-        model, text, tokenizer, max_len, device,
-        temperature=0.7, top_k=50
-):
-    eos_idx = tokenizer.token_to_id('</s>')
-    pad_idx = tokenizer.token_to_id('<pad>')
-
-    # Start with the input text as initial decoder input
-    decoder_input = text.to(device)
-    if decoder_input.dim() == 1:
-        decoder_input = decoder_input.unsqueeze(0)
-
-
-    # Print the initial prompt
-
-    while decoder_input.shape[1] < 2000 :
-        # Apply causal mask based on current decoder_input length
-        # decoder_mask = (decoder_input != pad_idx).unsqueeze(0).int() & causal_mask(decoder_input.size(1)).type_as(mask).to(device)
-
-        # Get model output
-        out = model.decode(decoder_input)
-        logits = model.project(out[:, -1])  # Get logits for last token
-
-        # Sampling: temperature + top-k
-        logits = logits / temperature
-        top_k_logits, top_k_indices = torch.topk(logits, top_k)
-        probs = torch.softmax(top_k_logits, dim=-1)
-        next_token = torch.multinomial(probs, num_samples=1)
-        next_token = top_k_indices.gather(-1, next_token)
-
-        # Decode and print token
-        word = tokenizer.decode([next_token.item()])
-        print(word, end="", flush=True)
-
-        # Append next token
-
-        decoder_input = torch.cat([decoder_input, next_token], dim=1)
-        if decoder_input.shape[1] > max_len:
-            decoder_input = decoder_input[:,-max_len:]
-
-
-        if next_token.item() == eos_idx:
-            break
-
-    print()
-    return decoder_input.squeeze(0)
-
-
-
-def get_tokenizer(config)->Tokenizer:
-    tokenizers_path = Path(config['tokenizer_file'])
-    if Path.exists(tokenizers_path):
-        print("Loading tokenizer from ", tokenizers_path)
-        tokenizer = Tokenizer.from_file(str(tokenizers_path))
-        return tokenizer
-    else:
-        raise FileNotFoundError("Cant find tokenizer file : ",tokenizers_path)
-
-def run_model(config):
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    print(f"Using device : {device}")
-    tokenizer = get_tokenizer(config)
-    model = get_model(config, tokenizer.get_vocab_size()).to(device)
-    model_path = get_weights_file_path(config,config['preload'])
-    model.eval()
-
-    if Path.exists(Path(model_path)):
-        print("Loading Model from : ", model_path)
-        state = torch.load(model_path)
-        model.load_state_dict(state['model_state_dict'])
-        print("You : ",end="")
-        input_text = input()
-        pad_token_id = tokenizer.token_to_id("<pad>")
-        user_token_id = tokenizer.token_to_id("<user>")
-        ai_token_id = tokenizer.token_to_id("<ai>")
-        while input_text != "exit":
-            input_tokens = tokenizer.encode(input_text).ids[:-1]
-            input_tokens.extend([user_token_id] + input_tokens + [ai_token_id] )
-
-            if len(input_tokens) > config['seq_len']:
-                print(f"exceeding max length of input : {config['seq_len']}")
-                continue
-            # if len(input_tokens) < config['seq_len']:
-            #     input_tokens += [pad_token_id] * (config['seq_len'] - len(input_tokens))
-            input_tokens = torch.tensor(input_tokens)
-            output_tokens = generate_text(model, input_tokens, tokenizer, config['seq_len'], device )
-            print("MODEL : ",output_tokens)
-            output_text = tokenizer.decode(output_tokens.detach().cpu().numpy())
-            # print("Model : "+output_text)
-            print("You : ",end="")
-            input_text = input()
-
-    else:
-        raise FileNotFoundError("Model File not found : "+ model_path)
-
 def generate_response(prompt:str):
     print("Prompt : ",prompt)
 
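For reference, the decoding strategy this commit removes along with generate_text is plain temperature-plus-top-k sampling. Below is a minimal standalone sketch of that step, assuming logits of shape (batch, vocab_size) as produced by model.project(out[:, -1]) in the deleted code; the function name and the dummy vocabulary size are illustrative only.

    import torch

    def sample_next_token(logits: torch.Tensor, temperature: float = 0.7, top_k: int = 50) -> torch.Tensor:
        # Temperature below 1.0 sharpens the distribution; above 1.0 flattens it.
        logits = logits / temperature
        # Keep only the top_k highest-scoring tokens per batch row.
        top_k_logits, top_k_indices = torch.topk(logits, top_k)
        # Renormalise over the surviving tokens and draw one sample.
        probs = torch.softmax(top_k_logits, dim=-1)
        sampled = torch.multinomial(probs, num_samples=1)
        # Map the sampled position back to its vocabulary id.
        return top_k_indices.gather(-1, sampled)

    # Dummy logits for illustration; in the deleted loop these came from
    # model.project(model.decode(decoder_input)[:, -1]).
    next_token = sample_next_token(torch.randn(1, 32000))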
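One detail worth noting in the deleted run_model: input_tokens.extend([user_token_id] + input_tokens + [ai_token_id]) appends the prompt to a list that already contains it, so each user turn was fed to the model twice. A sketch of the presumably intended assembly, reusing the <user>/<ai> markers and the trailing-id trim from the deleted code:

    def build_chat_prompt(tokenizer, user_text: str) -> list:
        # Encode the turn and drop the final id, as the deleted code did with ids[:-1].
        ids = tokenizer.encode(user_text).ids[:-1]
        # Build a fresh list instead of extend()-ing ids with itself.
        return ([tokenizer.token_to_id("<user>")]
                + ids
                + [tokenizer.token_to_id("<ai>")])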
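The get_tokenizer helper that this commit keeps at module top level can be exercised on its own. A minimal usage sketch, assuming the function is importable and that tokenizer_file points at a HuggingFace tokenizers JSON file; the path below is a hypothetical stand-in for the value normally read from openweb.config.json. Note that importing the module also runs its top-level model-loading code.

    from inference_fine_tune import get_tokenizer

    # Hypothetical config for illustration; the script builds the real one
    # via get_config("./openweb.config.json").
    config = {"tokenizer_file": "./tokenizer.json"}

    tokenizer = get_tokenizer(config)
    print("vocab size :", tokenizer.get_vocab_size())
    print("<pad> id   :", tokenizer.token_to_id("<pad>"))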