joaogante (HF Staff) committed
Commit 33ad63f · verified · Parent: 87fc79c

Update generate.py

Files changed (1): generate.py (+9 -6)
generate.py CHANGED
@@ -1,15 +1,19 @@
 import torch
 
-def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
+def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
     generation_config = generation_config or model.generation_config  # default to the model generation config
+    pad_token = kwargs.pop("pad_token", None) or generation_config.pad_token_id or model.config.pad_token_id
+    if len(kwargs) > 0:
+        raise ValueError(f"Unused kwargs: {list(kwargs.keys())}")
+
     cur_length = input_ids.shape[1]
     max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
-
-    # Example of custom argument: add left padding
+
+    # Example of custom argument: add `left_padding` (integer) pad tokens before the prompt
     if left_padding is not None:
         if not isinstance(left_padding, int) or left_padding < 0:
             raise ValueError(f"left_padding must be an integer larger than 0, but is {left_padding}")
-        pad_token = kwargs.get("pad_token") or generation_config.pad_token_id or model.config.pad_token_id
+
         if pad_token is None:
             raise ValueError("pad_token is not defined")
         batch_size = input_ids.shape[0]
@@ -25,5 +29,4 @@ def generate(model, input_ids, generation_config, left_padding=None, **kwargs):
         input_ids = torch.cat((input_ids, next_tokens[:, None]), dim=-1)
         cur_length += 1
 
-    return input_ids
-
+    return input_ids
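
For context, a minimal usage sketch (not part of the commit): it assumes this generate.py is importable from the working directory, and "gpt2" and the prompt are illustrative stand-ins. A file like this appears meant to be loaded through transformers' custom_generate mechanism rather than imported by hand, but a direct call shows the changed behavior.

from transformers import AutoModelForCausalLM, AutoTokenizer

from generate import generate  # the function defined in this file

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids

# `left_padding=2` exercises the custom argument; `pad_token` is popped from
# **kwargs before the new leftover-kwargs check runs.
output_ids = generate(
    model,
    input_ids,
    generation_config=None,  # falls back to model.generation_config
    left_padding=2,
    pad_token=tokenizer.eos_token_id,  # gpt2 has no dedicated pad token
)
print(tokenizer.decode(output_ids[0]))

With this commit, a typo such as generate(..., pad_tokn=0) fails fast with ValueError: Unused kwargs: ['pad_tokn'], where the previous kwargs.get("pad_token") version would have ignored it silently.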