Tzktz's picture
Upload 7664 files
6fc683c verified
from app_model import AppModel
from app_utils import *
from controlnet.app_canny import create_demo_canny
from controlnet.app_depth import create_demo_depth
from controlnet.app_ip2p import create_demo_ip2p
from controlnet.app_lineart import create_demo_lineart
from controlnet.app_mlsd import create_demo_mlsd
from controlnet.app_normal import create_demo_normal
from controlnet.app_openpose import create_demo_openpose
from controlnet.app_scribble import create_demo_scribble
from controlnet.app_scribble_interactive import create_demo_scribble_interactive
from controlnet.app_segmentation import create_demo_segmentation
from controlnet.app_shuffle import create_demo_shuffle
from controlnet.app_softedge import create_demo_softedge
from fairseq import options
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.distributed import utils as distributed_utils
def main(cfg):
appmodel = AppModel(cfg)
with gr.Blocks() as app:
with gr.Row():
with gr.Column(scale=5):
ckpt = gr.Textbox(value=cfg.model.pretrained_ckpt_path, show_label=False, container=False)
with gr.Column(scale=4):
current_ckpt = gr.Textbox(show_label=False, container=False)
with gr.Column(scale=1, min_width=100):
scheduler = gr.Dropdown(['dpms', 'pndm', 'ddim'], value='dpms', show_label=False, container=False,
min_width=60)
with gr.Row():
with gr.Column(scale=5):
lora = gr.Dropdown(['None'] + appmodel.get_available_lora(), value=cfg.model.lora_name,
show_label=False, container=False)
with gr.Column(scale=4):
current_lora = gr.Textbox(show_label=False, container=False)
with gr.Column(scale=1, min_width=60):
set_ckpt_scheduler_button = gr.Button('Set', container=False, min_width=60)
set_ckpt_scheduler_button.click(
fn=appmodel.set_ckpt_scheduler_fn, inputs=[ckpt, scheduler], outputs=current_ckpt, queue=False
).then(fn=appmodel.load_lora, inputs=lora, outputs=current_lora, queue=False)
with gr.Tabs():
with gr.TabItem('KOSMOS-G'):
with gr.Blocks():
with gr.Row():
with gr.Column(scale=1):
prompt = gr.Textbox(label="Prompt", max_lines=1,
placeholder="Use <i> to represent the images in prompt")
num_input_images = gr.Slider(1, MAX_INPUT_IMAGES, value=DEFAULT_INPUT_IMAGES, step=1,
label="Number of input images:")
input_images = [gr.Image(label=f'img{i}', type="pil",
visible=True if i < DEFAULT_INPUT_IMAGES else False)
for i in range(MAX_INPUT_IMAGES)]
num_input_images.change(variable_images, num_input_images, input_images)
text_guidance_scale = gr.Slider(1, 15, value=6, step=0.5, label="Text Guidance Scale")
seed = gr.Slider(label="Seed", minimum=MIN_SEED, maximum=MAX_SEED, step=1, value=0)
randomize_seed = gr.Checkbox(label='Randomize seed', value=True)
run_button = gr.Button(label="Run")
with gr.Accordion("Advanced options", open=False):
lora_scale = gr.Slider(0, 1, value=0, step=0.05, label="LoRA Scale")
num_inference_steps = gr.Slider(label="num_inference_steps", minimum=10, maximum=100,
value=50, step=5)
negative_prompt = gr.Textbox(label="Negative Prompt", max_lines=1, value="")
num_images_per_prompt = gr.Slider(1, MAX_IMAGES_PER_PROMPT,
value=4, step=1, label="Number of Images")
with gr.Column(scale=2):
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery",
columns=2, height='100%')
ips = [prompt, lora_scale, num_inference_steps, text_guidance_scale, negative_prompt,
num_images_per_prompt, *input_images]
prompt.submit(
fn=appmodel.set_ckpt_scheduler_fn, inputs=[ckpt, scheduler], outputs=current_ckpt, queue=False
).then(fn=appmodel.load_lora, inputs=lora, outputs=current_lora, queue=False).then(
fn=randomize_seed_fn, inputs=[seed, randomize_seed], outputs=seed, queue=False, api_name=False
).then(fn=appmodel.kosmosg_generation, inputs=ips, outputs=result_gallery)
run_button.click(
fn=appmodel.set_ckpt_scheduler_fn, inputs=[ckpt, scheduler], outputs=current_ckpt, queue=False
).then(fn=appmodel.load_lora, inputs=lora, outputs=current_lora, queue=False).then(
fn=randomize_seed_fn, inputs=[seed, randomize_seed], outputs=seed, queue=False, api_name=False
).then(fn=appmodel.kosmosg_generation, inputs=ips, outputs=result_gallery)
gr.Examples(
examples=[
['<i>', 'appimg/dog.jpg', None],
['<i> swimming underwater', 'appimg/dog.jpg', None],
['<i> in Batman suit', 'appimg/dog.jpg', None],
['<i> as an oil painting by Vincent van Gogh', 'appimg/dog.jpg', None],
['<i> in Minecraft', 'appimg/dog.jpg', None],
['<i> in the suit of <i>', 'appimg/dog2.jpg', 'appimg/ironman.jpg'],
['<i> in Unity3D', 'appimg/car.jpg', None],
['<i>', 'appimg/bengio.jpg', None],
['<i> as an oil painting in the style of <i>', 'appimg/bengio.jpg', 'appimg/vangogh.jpg'],
['<i> wearing <i>', 'appimg/bengio.jpg', 'appimg/sunglasses.jpg'],
['<i> in <i>\'s jacket', 'appimg/bengio.jpg', 'appimg/huang.jpg'],
['<i> taking a selfie at <i>', 'appimg/bengio.jpg', 'appimg/ijen.jpg'],
['<i> in the style of <i>', 'appimg/bengio.jpg', 'appimg/uname.jpg'],
],
inputs=[prompt, input_images[0], input_images[1]],
cache_examples=False,
examples_per_page=100
)
with gr.TabItem('ControlNet KOSMOS-G'):
with gr.Tabs():
with gr.TabItem('Canny'):
create_demo_canny(appmodel.controlnet_generation_canny)
with gr.TabItem('MLSD'):
create_demo_mlsd(appmodel.controlnet_generation_mlsd)
with gr.TabItem('Scribble'):
create_demo_scribble(appmodel.controlnet_generation_scribble)
with gr.TabItem('Scribble Interactive'):
create_demo_scribble_interactive(appmodel.controlnet_generation_scribble_interactive)
with gr.TabItem('SoftEdge'):
create_demo_softedge(appmodel.controlnet_generation_softedge)
with gr.TabItem('OpenPose'):
create_demo_openpose(appmodel.controlnet_generation_openpose)
with gr.TabItem('Segmentation'):
create_demo_segmentation(appmodel.controlnet_generation_segmentation)
with gr.TabItem('Depth'):
create_demo_depth(appmodel.controlnet_generation_depth)
with gr.TabItem('Normal map'):
create_demo_normal(appmodel.controlnet_generation_normal)
with gr.TabItem('Lineart'):
create_demo_lineart(appmodel.controlnet_generation_lineart)
with gr.TabItem('Content Shuffle'):
create_demo_shuffle(appmodel.controlnet_generation_shuffle)
with gr.TabItem('Instruct Pix2Pix'):
create_demo_ip2p(appmodel.controlnet_generation_ip2p)
app.queue(concurrency_count=1).launch(share=True)
if __name__ == "__main__":
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser, modify_parser=None)
cfg = convert_namespace_to_omegaconf(args)
distributed_utils.call_main(cfg, main)