zero2story / interfaces /story_gen_ui.py
chansung's picture
.
3332aa4
import re
import copy
import random
import gradio as gr
from gradio_client import Client
from pathlib import Path
from modules import (
ImageMaker, MusicMaker, merge_video
)
from modules.llms import get_llm_factory
from interfaces import utils
from pingpong import PingPong
from pingpong.context import CtxLastWindowStrategy
img_maker = ImageMaker('https://huggingface.co/jphan32/Zero2Story/landscapeAnimePro_v20Inspiration.safetensors',
vae="https://huggingface.co/jphan32/Zero2Story/cute20vae.safetensors")
bgm_maker = MusicMaker(model_size='small', output_format='mp3')
video_gen_client_url = None # e.g. "https://0447df3cf5f7c49c46.gradio.live"
async def update_story_gen(
cursors, cur_cursor_idx,
genre, place, mood,
main_char_name, main_char_age, main_char_personality, main_char_job,
side_char_enable1, side_char_name1, side_char_age1, side_char_personality1, side_char_job1,
side_char_enable2, side_char_name2, side_char_age2, side_char_personality2, side_char_job2,
side_char_enable3, side_char_name3, side_char_age3, side_char_personality3, side_char_job3,
):
if len(cursors) == 1:
return await first_story_gen(
cursors,
genre, place, mood,
main_char_name, main_char_age, main_char_personality, main_char_job,
side_char_enable1, side_char_name1, side_char_age1, side_char_personality1, side_char_job1,
side_char_enable2, side_char_name2, side_char_age2, side_char_personality2, side_char_job2,
side_char_enable3, side_char_name3, side_char_age3, side_char_personality3, side_char_job3,
cur_cursor_idx=cur_cursor_idx
)
else:
return await next_story_gen(
cursors,
None,
genre, place, mood,
main_char_name, main_char_age, main_char_personality, main_char_job,
side_char_enable1, side_char_name1, side_char_age1, side_char_personality1, side_char_job1,
side_char_enable2, side_char_name2, side_char_age2, side_char_personality2, side_char_job2,
side_char_enable3, side_char_name3, side_char_age3, side_char_personality3, side_char_job3,
cur_cursor_idx=cur_cursor_idx
)
async def next_story_gen(
cursors,
action,
genre, place, mood,
main_char_name, main_char_age, main_char_personality, main_char_job,
side_char_enable1, side_char_name1, side_char_age1, side_char_personality1, side_char_job1,
side_char_enable2, side_char_name2, side_char_age2, side_char_personality2, side_char_job2,
side_char_enable3, side_char_name3, side_char_age3, side_char_personality3, side_char_job3,
cur_cursor_idx=None,
llm_type="PaLM"
):
factory = get_llm_factory(llm_type)
prompts = factory.create_prompt_manager().prompts
llm_service = factory.create_llm_service()
stories = ""
cur_side_chars = 1
action = cursors[cur_cursor_idx]["action"] if cur_cursor_idx is not None else action
end_idx = len(cursors) if cur_cursor_idx is None else len(cursors)-1
for cursor in cursors[:end_idx]:
stories = stories + cursor["story"]
side_char_prompt = utils.add_side_character(
[side_char_enable1, side_char_enable2, side_char_enable3],
[side_char_name1, side_char_name2, side_char_name3],
[side_char_job1, side_char_job2, side_char_job3],
[side_char_age1, side_char_age2, side_char_age3],
[side_char_personality1, side_char_personality2, side_char_personality3],
)
prompt = prompts['story_gen']['next_story_gen'].format(
genre=genre, place=place, mood=mood,
main_char_name=main_char_name,
main_char_job=main_char_job,
main_char_age=main_char_age,
main_char_personality=main_char_personality,
side_char_placeholder=side_char_prompt,
stories=stories, action=action,
)
print(f"generated prompt:\n{prompt}")
parameters = llm_service.make_params(mode="text", temperature=1.0, top_k=40, top_p=0.9, max_output_tokens=4096)
try:
response_json = await utils.retry_until_valid_json(prompt, parameters=parameters)
except Exception as e:
print(e)
raise gr.Error(e)
story = response_json["paragraphs"]
if isinstance(story, list):
story = "\n\n".join(story)
if cur_cursor_idx is None:
cursors.append({
"title": "",
"story": story,
"action": action
})
else:
cursors[cur_cursor_idx]["story"] = story
cursors[cur_cursor_idx]["action"] = action
return (
cursors, len(cursors)-1,
story,
gr.update(
maximum=len(cursors), value=len(cursors),
label=f"{len(cursors)} out of {len(cursors)} stories",
visible=True, interactive=True
),
gr.update(interactive=True),
gr.update(interactive=True),
gr.update(value=None, visible=False, interactive=True),
gr.update(value=None, visible=False, interactive=True),
gr.update(value=None, visible=False, interactive=True),
)
async def actions_gen(
cursors,
genre, place, mood,
main_char_name, main_char_age, main_char_personality, main_char_job,
side_char_enable1, side_char_name1, side_char_age1, side_char_personality1, side_char_job1,
side_char_enable2, side_char_name2, side_char_age2, side_char_personality2, side_char_job2,
side_char_enable3, side_char_name3, side_char_age3, side_char_personality3, side_char_job3,
cur_cursor_idx=None,
llm_type="PaLM"
):
factory = get_llm_factory(llm_type)
prompts = factory.create_prompt_manager().prompts
llm_service = factory.create_llm_service()
stories = ""
cur_side_chars = 1
end_idx = len(cursors) if cur_cursor_idx is None else len(cursors)-1
for cursor in cursors[:end_idx]:
stories = stories + cursor["story"]
summary_prompt = prompts['story_gen']['summarize'].format(stories=stories)
print(f"generated prompt:\n{summary_prompt}")
parameters = llm_service.make_params(mode="text", temperature=1.0, top_k=40, top_p=1.0, max_output_tokens=4096)
try:
_, summary = await llm_service.gen_text(summary_prompt, mode="text", parameters=parameters)
except Exception as e:
print(e)
raise gr.Error(e)
side_char_prompt = utils.add_side_character(
[side_char_enable1, side_char_enable2, side_char_enable3],
[side_char_name1, side_char_name2, side_char_name3],
[side_char_job1, side_char_job2, side_char_job3],
[side_char_age1, side_char_age2, side_char_age3],
[side_char_personality1, side_char_personality2, side_char_personality3],
)
prompt = prompts['story_gen']['actions_gen'].format(
genre=genre, place=place, mood=mood,
main_char_name=main_char_name,
main_char_job=main_char_job,
main_char_age=main_char_age,
main_char_personality=main_char_personality,
side_char_placeholder=side_char_prompt,
summary=summary,
)
print(f"generated prompt:\n{prompt}")
parameters = llm_service.make_params(mode="text", temperature=1.0, top_k=40, top_p=1.0, max_output_tokens=4096)
try:
response_json = await utils.retry_until_valid_json(prompt, parameters=parameters)
except Exception as e:
print(e)
raise gr.Error(e)
actions = response_json["options"]
random_actions = random.sample(actions, 3)
return (
gr.update(value=random_actions[0], interactive=True),
gr.update(value=random_actions[1], interactive=True),
gr.update(value=random_actions[2], interactive=True),
" "
)
async def first_story_gen(
cursors,
genre, place, mood,
main_char_name, main_char_age, main_char_personality, main_char_job,
side_char_enable1, side_char_name1, side_char_age1, side_char_personality1, side_char_job1,
side_char_enable2, side_char_name2, side_char_age2, side_char_personality2, side_char_job2,
side_char_enable3, side_char_name3, side_char_age3, side_char_personality3, side_char_job3,
cur_cursor_idx=None,
llm_type="PaLM"
):
factory = get_llm_factory(llm_type)
prompts = factory.create_prompt_manager().prompts
llm_service = factory.create_llm_service()
cur_side_chars = 1
side_char_prompt = utils.add_side_character(
[side_char_enable1, side_char_enable2, side_char_enable3],
[side_char_name1, side_char_name2, side_char_name3],
[side_char_job1, side_char_job2, side_char_job3],
[side_char_age1, side_char_age2, side_char_age3],
[side_char_personality1, side_char_personality2, side_char_personality3],
)
prompt = prompts['story_gen']['first_story_gen'].format(
genre=genre, place=place, mood=mood,
main_char_name=main_char_name,
main_char_job=main_char_job,
main_char_age=main_char_age,
main_char_personality=main_char_personality,
side_char_placeholder=side_char_prompt,
)
print(f"generated prompt:\n{prompt}")
parameters = llm_service.make_params(mode="text", temperature=1.0, top_k=40, top_p=1.0, max_output_tokens=4096)
try:
response_json = await utils.retry_until_valid_json(prompt, parameters=parameters)
except Exception as e:
print(e)
raise gr.Error(e)
story = response_json["paragraphs"]
if isinstance(story, list):
story = "\n\n".join(story)
if cur_cursor_idx is None:
cursors.append({
"title": "",
"story": story
})
else:
cursors[cur_cursor_idx]["story"] = story
return (
cursors, len(cursors)-1,
story,
gr.update(
maximum=len(cursors), value=len(cursors),
label=f"{len(cursors)} out of {len(cursors)} stories",
visible=False if len(cursors) == 1 else True, interactive=True
),
gr.update(interactive=True),
gr.update(interactive=True),
gr.update(value=None, visible=False, interactive=True),
gr.update(value=None, visible=False, interactive=True),
gr.update(value=None, visible=False, interactive=True),
)
def video_gen(
image, audio, title, cursors, cur_cursor, use_ffmpeg=True
):
if use_ffmpeg:
output_filename = merge_video(image, audio, story_title="")
if not use_ffmpeg or not output_filename:
client = Client(video_gen_client_url)
result = client.predict(
"",
audio,
image,
f"{utils.id_generator()}.mp4",
api_name="/predict"
)
output_filename = result[0]
cursors[cur_cursor]["video"] = output_filename
return (
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=True, value=output_filename),
cursors,
" "
)
def image_gen(
genre, place, mood, title, story_content, cursors, cur_cursor, llm_type="PaLM"
):
# generate prompts for background image with LLM
for _ in range(3):
try:
prompt, neg_prompt = img_maker.generate_background_prompts(genre, place, mood, title, "", story_content, llm_type)
print(f"Image Prompt: {prompt}")
print(f"Negative Prompt: {neg_prompt}")
break
except Exception as e:
print(e)
raise gr.Error(e)
if not prompt:
raise ValueError("Failed to generate prompts for background image.")
# generate image
try:
img_filename = img_maker.text2image(prompt, neg_prompt=neg_prompt, ratio='16:9', cfg=6.5)
except ValueError as e:
print(e)
img_filename = str(Path('.') / 'assets' / 'nsfw_warning_wide.png')
cursors[cur_cursor]["img"] = img_filename
return (
gr.update(visible=True, value=img_filename),
cursors,
" "
)
def audio_gen(
genre, place, mood, title, story_content, cursors, cur_cursor, llm_type="PaLM"
):
# generate prompt for background music with LLM
for _ in range(3):
try:
prompt = bgm_maker.generate_prompt(genre, place, mood, title, "", story_content, llm_type)
print(f"Music Prompt: {prompt}")
break
except Exception as e:
print(e)
raise gr.Error(e)
if not prompt:
raise ValueError("Failed to generate prompt for background music.")
# generate music
bgm_filename = bgm_maker.text2music(prompt, length=60)
cursors[cur_cursor]["audio"] = bgm_filename
return (
gr.update(visible=True, value=bgm_filename),
cursors,
" "
)
def move_story_cursor(moved_cursor, cursors):
cursor_content = cursors[moved_cursor-1]
max_cursor = len(cursors)
action_btn = (
gr.update(interactive=False),
gr.update(interactive=False),
gr.update(interactive=False)
)
if moved_cursor == max_cursor:
action_btn = (
gr.update(interactive=True),
gr.update(interactive=True),
gr.update(interactive=True)
)
if "video" in cursor_content:
outputs = (
moved_cursor-1,
gr.update(label=f"{moved_cursor} out of {len(cursors)} chapters"),
cursor_content["story"],
gr.update(value=None, visible=False),
gr.update(value=None, visible=False),
gr.update(value=cursor_content["video"], visible=True),
)
else:
image_container = gr.update(value=None, visible=False)
audio_container = gr.update(value=None, visible=False)
if "img" in cursor_content:
image_container = gr.update(value=cursor_content["img"], visible=True)
if "audio" in cursor_content:
audio_container = gr.update(value=cursor_content["audio"], visible=True)
outputs = (
moved_cursor-1,
gr.update(label=f"{moved_cursor} out of {len(cursors)} stories"),
cursor_content["story"],
image_container,
audio_container,
gr.update(value=None, visible=False),
)
return outputs + action_btn
def update_story_content(story_content, cursors, cur_cursor):
cursors[cur_cursor]["story"] = story_content
return cursors
def disable_btns():
return (
gr.update(interactive=False), # image_gen_btn
gr.update(interactive=False), # audio_gen_btn
gr.update(interactive=False), # img_audio_combine_btn
gr.update(interactive=False), # regen_actions_btn
gr.update(interactive=False), # regen_story_btn
gr.update(interactive=False), # custom_prompt_txt
gr.update(interactive=False), # action_btn1
gr.update(interactive=False), # action_btn2
gr.update(interactive=False), # action_btn3
gr.update(interactive=False), # custom_action_txt
gr.update(interactive=False), # restart_from_story_generation_btn
gr.update(interactive=False), # story_writing_done_btn
)
def enable_btns(story_image, story_audio):
video_gen_btn_state = gr.update(interactive=False)
if story_image is not None and \
story_audio is not None:
video_gen_btn_state = gr.update(interactive=True)
return (
gr.update(interactive=True), # image_gen_btn
gr.update(interactive=True), # audio_gen_btn
video_gen_btn_state, # img_audio_combine_btn
gr.update(interactive=True), # regen_actions_btn
gr.update(interactive=True), # regen_story_btn
gr.update(interactive=True), # custom_prompt_txt
gr.update(interactive=True), # action_btn1
gr.update(interactive=True), # action_btn2
gr.update(interactive=True), # action_btn3
gr.update(interactive=True), # custom_action_txt
gr.update(interactive=True), # restart_from_story_generation_btn
gr.update(interactive=True), # story_writing_done_btn
)