manueldeprada HF Staff committed on
Commit
6891872
·
verified ·
1 Parent(s): 9da4cff

Upload folder using huggingface_hub

Browse files
.ruff_cache/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Automatically created by ruff.
2
+ *
.ruff_cache/0.12.8/5591301162804142724 ADDED
Binary file (151 Bytes). View file
 
.ruff_cache/CACHEDIR.TAG ADDED
@@ -0,0 +1 @@
 
 
1
+ Signature: 8a477f597d28d172789f06886806bc55
custom_generate/generate.py CHANGED
@@ -232,7 +232,9 @@ def _contrastive_search(
232
  ):
233
  # prepare inputs
234
  model_kwargs["use_cache"] = True
235
- model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
 
236
 
237
  # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
238
  # the `encoder_outputs`
@@ -369,6 +371,11 @@ def _contrastive_search(
369
  outputs["past_key_values"] = None
370
  # Remove last token from past K-V since we don't want to append it at this point
371
  model_kwargs["past_key_values"].crop(-1)
 
 
 
 
 
372
 
373
  all_outputs.append(outputs)
374
  outputs = stack_model_outputs(all_outputs, model.config.get_text_config())
@@ -605,5 +612,7 @@ def generate(model, *args, **kwargs):
605
  penalty_alpha (`float`): The alpha value for the degeneration penalty.
606
  top_k (`int`): The number of candidates to consider at each step.
607
  """
608
- generation_outputs = GenerationMixin.generate(model, *args, custom_generate=_contrastive_search, **kwargs)
 
 
609
  return generation_outputs
 
232
  ):
233
  # prepare inputs
234
  model_kwargs["use_cache"] = True
235
+ model_inputs = model.prepare_inputs_for_generation(
236
+ input_ids, **model_kwargs
237
+ )
238
 
239
  # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
240
  # the `encoder_outputs`
 
371
  outputs["past_key_values"] = None
372
  # Remove last token from past K-V since we don't want to append it at this point
373
  model_kwargs["past_key_values"].crop(-1)
374
+ else:
375
+ raise ValueError(
376
+ f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
377
+ "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
378
+ )
379
 
380
  all_outputs.append(outputs)
381
  outputs = stack_model_outputs(all_outputs, model.config.get_text_config())
 
612
  penalty_alpha (`float`): The alpha value for the degeneration penalty.
613
  top_k (`int`): The number of candidates to consider at each step.
614
  """
615
+ generation_outputs = GenerationMixin.generate(
616
+ model, *args, custom_generate=_contrastive_search, **kwargs
617
+ )
618
  return generation_outputs