Update app.py
Browse files
app.py
CHANGED
@@ -9,6 +9,20 @@ with open('config.yaml', "r") as f:
|
|
9 |
tokenizer = AutoTokenizer.from_pretrained("Xibanya/DS9Bot")
|
10 |
model = AutoModelForCausalLM.from_pretrained("Xibanya/DS9Bot")
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
def generate(prompt: str = ' ', temp: float = 0.7, length: float = 250):
|
13 |
torch.Generator().manual_seed(cfg['seed'] if cfg['seed'] is not None else torch.seed())
|
14 |
prompt = '[Ops] SISKO: ' if prompt is None or '' else prompt
|
@@ -27,9 +41,7 @@ def generate(prompt: str = ' ', temp: float = 0.7, length: float = 250):
|
|
27 |
)
|
28 |
text = tokenizer.batch_decode(
|
29 |
output, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
|
30 |
-
|
31 |
-
text = text.replace(string, '\n' + string.strip())
|
32 |
-
return text.strip()
|
33 |
|
34 |
title = 'Deep Space Nine Script Generator'
|
35 |
|
|
|
9 |
tokenizer = AutoTokenizer.from_pretrained("Xibanya/DS9Bot")
|
10 |
model = AutoModelForCausalLM.from_pretrained("Xibanya/DS9Bot")
|
11 |
|
12 |
+
def split_dialogue(line):
|
13 |
+
tokens = line.split()
|
14 |
+
concat = ''
|
15 |
+
total = len(tokens)
|
16 |
+
for t in range(total):
|
17 |
+
token = tokens[t]
|
18 |
+
if token != '[OC]:' and '[on ' not in token and \
|
19 |
+
('[' in token or
|
20 |
+
(token.isupper() and (':' in token or (t < total - 1 and tokens[t + 1] == '[OC]:')))):
|
21 |
+
token = '\n' + token
|
22 |
+
|
23 |
+
concat = concat + token + ' '
|
24 |
+
return concat.strip()
|
25 |
+
|
26 |
def generate(prompt: str = ' ', temp: float = 0.7, length: float = 250):
|
27 |
torch.Generator().manual_seed(cfg['seed'] if cfg['seed'] is not None else torch.seed())
|
28 |
prompt = '[Ops] SISKO: ' if prompt is None or '' else prompt
|
|
|
41 |
)
|
42 |
text = tokenizer.batch_decode(
|
43 |
output, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
|
44 |
+
return split_dialogue(text)
|
|
|
|
|
45 |
|
46 |
title = 'Deep Space Nine Script Generator'
|
47 |
|