math_test / utils.py
jonathantiedchen's picture
Create utils.py
8852436 verified
raw
history blame contribute delete
565 Bytes
from transformers import StoppingCriteria
# Define a stopping condition for generation
class SpecificStringStoppingCriteria(StoppingCriteria):
    """Stop generation once any of the given strings appears in the new text.

    On each call the first sequence in ``input_ids`` is decoded, the leading
    ``input_len`` characters of the decoded text are dropped, and the
    remainder is searched for every stop string.

    Args:
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
        stop_strings: Strings that end generation. A single ``str`` is
            treated as one stop string rather than iterated per character.
        input_len: Number of leading characters of the decoded text to
            ignore (presumably the prompt's length in characters, not in
            tokens -- confirm against the caller).
    """

    def __init__(self, tokenizer, stop_strings, input_len):
        super().__init__()
        self.tokenizer = tokenizer
        # A bare string would otherwise be iterated character by character
        # in __call__, so any single character of it would stop generation.
        if isinstance(stop_strings, str):
            stop_strings = [stop_strings]
        self.stop_strings = stop_strings
        self.input_len = input_len

    def __call__(self, input_ids, scores, **kwargs):
        """Return True if the text produced so far contains a stop string.

        Only ``input_ids[0]`` is inspected; ``scores`` and ``kwargs`` are
        accepted for interface compatibility and ignored.
        """
        decoded = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        # input_len is applied to the decoded string, so it counts characters.
        current_text = decoded[self.input_len:]
        return any(stop_string in current_text for stop_string in self.stop_strings)