Update app.py
Browse files
app.py
CHANGED
@@ -9,6 +9,7 @@ tokenizer = AutoTokenizer.from_pretrained("Xibanya/DS9Bot")
|
|
9 |
model = AutoModelForCausalLM.from_pretrained("Xibanya/DS9Bot")
|
10 |
|
11 |
def generate(prompt):
|
|
|
12 |
prompt = '[Promenade] ' if prompt is None else prompt
|
13 |
encoded_prompt = tokenizer(prompt, return_tensors="pt").input_ids
|
14 |
encoded_prompt = encoded_prompt.to(model.device)
|
@@ -24,14 +25,18 @@ def generate(prompt):
|
|
24 |
num_return_sequences=1
|
25 |
)
|
26 |
text = tokenizer.batch_decode(
|
27 |
-
output, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
|
28 |
for string in cfg['split']:
|
29 |
text = text.replace(string, '\n' + string)
|
30 |
-
return text
|
31 |
|
32 |
title = 'Deep Space Nine Script Generator'
|
33 |
description = 'AI-generated scripts for the best Trek'
|
34 |
|
35 |
-
iface = gr.Interface(fn=generate,
|
36 |
-
|
|
|
|
|
|
|
|
|
37 |
iface.launch()
|
|
|
9 |
model = AutoModelForCausalLM.from_pretrained("Xibanya/DS9Bot")
|
10 |
|
11 |
def generate(prompt):
|
12 |
+
torch.Generator().manual_seed(cfg['seed'] if cfg['seed'] is not None else torch.seed())
|
13 |
prompt = '[Promenade] ' if prompt is None else prompt
|
14 |
encoded_prompt = tokenizer(prompt, return_tensors="pt").input_ids
|
15 |
encoded_prompt = encoded_prompt.to(model.device)
|
|
|
25 |
num_return_sequences=1
|
26 |
)
|
27 |
text = tokenizer.batch_decode(
|
28 |
+
output, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
|
29 |
for string in cfg['split']:
|
30 |
text = text.replace(string, '\n' + string)
|
31 |
+
return text.strip()
|
32 |
|
33 |
title = 'Deep Space Nine Script Generator'
|
34 |
description = 'AI-generated scripts for the best Trek'
|
35 |
|
36 |
+
iface = gr.Interface(fn=generate,
|
37 |
+
inputs=[gr.inputs.Textbox(label="Prompt", placeholder='Enter a prompt to generate a script from')],
|
38 |
+
outputs=[gr.outputs.TextBox(type="auto", label="Script Generated")],
|
39 |
+
title=title,
|
40 |
+
description=description,
|
41 |
+
theme='seafoam', examples=['[Promenade]', 'QUARK: ', "Commander's log, stardate 46924.5."])
|
42 |
iface.launch()
|