Update generate.py
generate.py CHANGED (+9, -6)
@@ -1,15 +1,19 @@
 import torch
 
-def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
+def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
     generation_config = generation_config or model.generation_config  # default to the model generation config
+    pad_token = kwargs.pop("pad_token", None) or generation_config.pad_token_id or model.config.pad_token_id
+    if len(kwargs) > 0:
+        raise ValueError(f"Unused kwargs: {list(kwargs.keys())}")
+
     cur_length = input_ids.shape[1]
     max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
-
-    # Example of custom argument: add
+
+    # Example of custom argument: add `left_padding` (integer) pad tokens before the prompt
     if left_padding is not None:
         if not isinstance(left_padding, int) or left_padding < 0:
             raise ValueError(f"left_padding must be an integer larger than 0, but is {left_padding}")
-
+
         if pad_token is None:
             raise ValueError("pad_token is not defined")
         batch_size = input_ids.shape[0]
@@ -25,5 +29,4 @@ def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
         input_ids = torch.cat((input_ids, next_tokens[:, None]), dim=-1)
         cur_length += 1
 
-    return input_ids
-
+    return input_ids
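For context: before this change, `pad_token` was checked with `if pad_token is None:` without ever being assigned, so the update resolves it once at the top of `generate()` and additionally rejects any leftover `**kwargs`. Below is a minimal usage sketch of the updated function; the "gpt2" checkpoint, the prompt, and the choice to reuse EOS as the pad token are illustrative assumptions, not part of this commit.

# Usage sketch (assumptions: "gpt2" checkpoint, EOS reused as pad token).
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("Hello, world", return_tensors="pt").input_ids

# `pad_token` is consumed via **kwargs; any other leftover kwarg now raises ValueError.
output_ids = generate(
    model,
    input_ids,
    generation_config=model.generation_config,
    left_padding=2,                    # prepend two pad tokens before the prompt
    pad_token=tokenizer.eos_token_id,  # gpt2 defines no pad token, so reuse EOS
)
print(tokenizer.decode(output_ids[0]))

Note that the retained `left_padding` check still accepts 0 (only negative values and non-integers are rejected), even though its error message says "larger than 0".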