Xibanya commited on
Commit
c3241ad
·
1 Parent(s): 369bd20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -9,6 +9,20 @@ with open('config.yaml', "r") as f:
9
  tokenizer = AutoTokenizer.from_pretrained("Xibanya/DS9Bot")
10
  model = AutoModelForCausalLM.from_pretrained("Xibanya/DS9Bot")
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def generate(prompt: str = ' ', temp: float = 0.7, length: float = 250):
13
  torch.Generator().manual_seed(cfg['seed'] if cfg['seed'] is not None else torch.seed())
14
  prompt = '[Ops] SISKO: ' if prompt is None or '' else prompt
@@ -27,9 +41,7 @@ def generate(prompt: str = ' ', temp: float = 0.7, length: float = 250):
27
  )
28
  text = tokenizer.batch_decode(
29
  output, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
30
- for string in cfg['split']:
31
- text = text.replace(string, '\n' + string.strip())
32
- return text.strip()
33
 
34
  title = 'Deep Space Nine Script Generator'
35
 
 
9
  tokenizer = AutoTokenizer.from_pretrained("Xibanya/DS9Bot")
10
  model = AutoModelForCausalLM.from_pretrained("Xibanya/DS9Bot")
11
 
12
+ def split_dialogue(line):
13
+ tokens = line.split()
14
+ concat = ''
15
+ total = len(tokens)
16
+ for t in range(total):
17
+ token = tokens[t]
18
+ if token != '[OC]:' and '[on ' not in token and \
19
+ ('[' in token or
20
+ (token.isupper() and (':' in token or (t < total - 1 and tokens[t + 1] == '[OC]:')))):
21
+ token = '\n' + token
22
+
23
+ concat = concat + token + ' '
24
+ return concat.strip()
25
+
26
  def generate(prompt: str = ' ', temp: float = 0.7, length: float = 250):
27
  torch.Generator().manual_seed(cfg['seed'] if cfg['seed'] is not None else torch.seed())
28
  prompt = '[Ops] SISKO: ' if prompt is None or '' else prompt
 
41
  )
42
  text = tokenizer.batch_decode(
43
  output, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
44
+ return split_dialogue(text)
 
 
45
 
46
  title = 'Deep Space Nine Script Generator'
47