jonathantiedchen commited on
Commit
8852436
·
verified ·
1 Parent(s): 52931fc

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +14 -0
utils.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import StoppingCriteria
2
+
3
+
4
+ # Define a stopping condition for generation
5
+ class SpecificStringStoppingCriteria(StoppingCriteria):
6
+ def __init__(self, tokenizer, stop_strings, input_len):
7
+ self.tokenizer = tokenizer
8
+ self.stop_strings = stop_strings
9
+ self.input_len = input_len
10
+
11
+ def __call__(self, input_ids, scores, **kwargs):
12
+ current_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)[self.input_len:]
13
+
14
+ return any(stop_string in current_text for stop_string in self.stop_strings)