joaogante HF staff commited on
Commit
411392e
·
verified ·
1 Parent(s): d3b6bb4

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +1 -4
generate.py CHANGED
@@ -2,10 +2,6 @@ import torch
2
 
3
  def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
4
  generation_config = generation_config or model.generation_config # default to the model generation config
5
- pad_token = kwargs.pop("pad_token", None) or generation_config.pad_token_id or model.config.pad_token_id
6
- if len(kwargs) > 0: # Let's catch unexpected kwargs, so that users don't get surprised
7
- raise ValueError(f"Unused kwargs: {list(kwargs.keys())}")
8
-
9
  cur_length = input_ids.shape[1]
10
  max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
11
 
@@ -14,6 +10,7 @@ def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
14
  if not isinstance(left_padding, int) or left_padding < 0:
15
  raise ValueError(f"left_padding must be an integer larger than 0, but is {left_padding}")
16
 
 
17
  if pad_token is None:
18
  raise ValueError("pad_token is not defined")
19
  batch_size = input_ids.shape[0]
 
2
 
3
  def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
4
  generation_config = generation_config or model.generation_config # default to the model generation config
 
 
 
 
5
  cur_length = input_ids.shape[1]
6
  max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
7
 
 
10
  if not isinstance(left_padding, int) or left_padding < 0:
11
  raise ValueError(f"left_padding must be an integer larger than 0, but is {left_padding}")
12
 
13
+ pad_token = kwargs.pop("pad_token", None) or generation_config.pad_token_id or model.config.pad_token_id
14
  if pad_token is None:
15
  raise ValueError("pad_token is not defined")
16
  batch_size = input_ids.shape[0]