Upload folder using huggingface_hub
Browse files
- .ruff_cache/.gitignore +2 -0
- .ruff_cache/0.12.8/5591301162804142724 +0 -0
- .ruff_cache/CACHEDIR.TAG +1 -0
- custom_generate/generate.py +11 -2
.ruff_cache/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# Automatically created by ruff.
|
2 |
+
*
|
.ruff_cache/0.12.8/5591301162804142724
ADDED
Binary file (151 Bytes). View file
|
|
.ruff_cache/CACHEDIR.TAG
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Signature: 8a477f597d28d172789f06886806bc55
|
custom_generate/generate.py
CHANGED
@@ -232,7 +232,9 @@ def _contrastive_search(
|
|
232 |
):
|
233 |
# prepare inputs
|
234 |
model_kwargs["use_cache"] = True
|
235 |
-
model_inputs = model.prepare_inputs_for_generation(
|
|
|
|
|
236 |
|
237 |
# encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
|
238 |
# the `encoder_outputs`
|
@@ -369,6 +371,11 @@ def _contrastive_search(
|
|
369 |
outputs["past_key_values"] = None
|
370 |
# Remove last token from past K-V since we don't want to append it at this point
|
371 |
model_kwargs["past_key_values"].crop(-1)
|
|
|
|
|
|
|
|
|
|
|
372 |
|
373 |
all_outputs.append(outputs)
|
374 |
outputs = stack_model_outputs(all_outputs, model.config.get_text_config())
|
@@ -605,5 +612,7 @@ def generate(model, *args, **kwargs):
|
|
605 |
penalty_alpha (`float`): The alpha value for the degeneration penalty.
|
606 |
top_k (`int`): The number of candidates to consider at each step.
|
607 |
"""
|
608 |
-
generation_outputs = GenerationMixin.generate(
|
|
|
|
|
609 |
return generation_outputs
|
|
|
232 |
):
|
233 |
# prepare inputs
|
234 |
model_kwargs["use_cache"] = True
|
235 |
+
model_inputs = model.prepare_inputs_for_generation(
|
236 |
+
input_ids, **model_kwargs
|
237 |
+
)
|
238 |
|
239 |
# encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
|
240 |
# the `encoder_outputs`
|
|
|
371 |
outputs["past_key_values"] = None
|
372 |
# Remove last token from past K-V since we don't want to append it at this point
|
373 |
model_kwargs["past_key_values"].crop(-1)
|
374 |
+
else:
|
375 |
+
raise ValueError(
|
376 |
+
f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
|
377 |
+
"dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
|
378 |
+
)
|
379 |
|
380 |
all_outputs.append(outputs)
|
381 |
outputs = stack_model_outputs(all_outputs, model.config.get_text_config())
|
|
|
612 |
penalty_alpha (`float`): The alpha value for the degeneration penalty.
|
613 |
top_k (`int`): The number of candidates to consider at each step.
|
614 |
"""
|
615 |
+
generation_outputs = GenerationMixin.generate(
|
616 |
+
model, *args, custom_generate=_contrastive_search, **kwargs
|
617 |
+
)
|
618 |
return generation_outputs
|