Update generate.py
Browse files- generate.py +1 -4
generate.py
CHANGED
@@ -2,10 +2,6 @@ import torch
|
|
2 |
|
3 |
def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
|
4 |
generation_config = generation_config or model.generation_config # default to the model generation config
|
5 |
-
pad_token = kwargs.pop("pad_token", None) or generation_config.pad_token_id or model.config.pad_token_id
|
6 |
-
if len(kwargs) > 0: # Let's catch unexpected kwargs, so that users don't get surprised
|
7 |
-
raise ValueError(f"Unused kwargs: {list(kwargs.keys())}")
|
8 |
-
|
9 |
cur_length = input_ids.shape[1]
|
10 |
max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
|
11 |
|
@@ -14,6 +10,7 @@ def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
|
|
14 |
if not isinstance(left_padding, int) or left_padding < 0:
|
15 |
raise ValueError(f"left_padding must be an integer larger than 0, but is {left_padding}")
|
16 |
|
|
|
17 |
if pad_token is None:
|
18 |
raise ValueError("pad_token is not defined")
|
19 |
batch_size = input_ids.shape[0]
|
|
|
2 |
|
3 |
def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
|
4 |
generation_config = generation_config or model.generation_config # default to the model generation config
|
|
|
|
|
|
|
|
|
5 |
cur_length = input_ids.shape[1]
|
6 |
max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
|
7 |
|
|
|
10 |
if not isinstance(left_padding, int) or left_padding < 0:
|
11 |
raise ValueError(f"left_padding must be an integer larger than 0, but is {left_padding}")
|
12 |
|
13 |
+
pad_token = kwargs.pop("pad_token", None) or generation_config.pad_token_id or model.config.pad_token_id
|
14 |
if pad_token is None:
|
15 |
raise ValueError("pad_token is not defined")
|
16 |
batch_size = input_ids.shape[0]
|