File size: 2,864 Bytes
b934e2e
 
 
 
 
 
 
 
 
 
 
 
eb3b93f
b934e2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
# HACK: force a clean torchvision reinstall before torch/diffusers are imported
# below (presumably to work around a torch/torchvision mismatch in the base
# image — confirm before removing).
# pip is invoked as ``<this interpreter> -m pip`` so the packages land in the
# environment actually running this app; a bare ``pip`` on PATH may belong to
# a different Python. Arguments are passed as a list, so no shell is involved.
# Exit codes are deliberately ignored (check=False), matching the previous
# os.system calls, so a failed uninstall does not abort startup.
import subprocess
import sys

for _pip_args in (
    ["uninstall", "torchvision", "-y"],
    ["install", "torchvision", "--force-reinstall", "--no-cache-dir"],
):
    subprocess.run([sys.executable, "-m", "pip", *_pip_args], check=False)

import torch
from diffusers import AutoPipelineForText2Image
import gradio as gr
from PIL import Image
import spaces

# Load the Flex.2-preview model once at import time, in bfloat16, using the
# custom pipeline code from "pipeline.py"; every request then reuses this
# single GPU-resident pipeline object.
pipe = AutoPipelineForText2Image.from_pretrained(
    "ostris/Flex.2-preview",
    torch_dtype=torch.bfloat16,
    custom_pipeline="pipeline.py",
)
pipe = pipe.to("cuda")

def _to_rgb_or_none(img):
    """Return ``img`` converted to RGB, or ``None`` when no image was supplied.

    Empty ``gr.Image`` inputs reach the callback as ``None``; calling
    ``.convert`` on them would raise ``AttributeError``.
    """
    if img is None:
        return None
    return img.convert("RGB")


@spaces.GPU
def generate_image(
    prompt: str,
    inpaint_img: Image.Image,
    inpaint_mask: Image.Image,
    control_img: Image.Image,
    height: int,
    width: int,
    guidance_scale: float,
    num_inference_steps: int,
    seed: int,
    control_strength: float,
    control_stop: float,
):
    """Generate one image with the Flex.2-preview pipeline.

    Args:
        prompt: Text prompt.
        inpaint_img: Optional image to inpaint; ``None`` if not uploaded.
        inpaint_mask: Optional inpaint mask; ``None`` if not uploaded.
        control_img: Optional control image; ``None`` if not uploaded.
        height: Output height in pixels (coerced to ``int``).
        width: Output width in pixels (coerced to ``int``).
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps (coerced to ``int``).
        seed: Seed for the CUDA random generator. An emptied seed box
            (``None``) falls back to 0.
        control_strength: Forwarded to the pipeline unchanged.
        control_stop: Forwarded to the pipeline unchanged.

    Returns:
        The first image in the pipeline result's ``images``.
    """
    # Gradio Number/Slider values may arrive as floats (or None for a cleared
    # Number box); manual_seed and the step count need real ints.
    seed = 0 if seed is None else int(seed)
    gen = torch.Generator(device="cuda").manual_seed(seed)

    result = pipe(
        prompt=prompt,
        # Images the user did not upload are forwarded as None instead of
        # crashing on None.convert(...).
        # NOTE(review): assumes the custom pipeline treats None as "not
        # provided" for these three inputs — confirm against pipeline.py.
        inpaint_image=_to_rgb_or_none(inpaint_img),
        inpaint_mask=_to_rgb_or_none(inpaint_mask),
        control_image=_to_rgb_or_none(control_img),
        control_strength=control_strength,
        control_stop=control_stop,
        height=int(height),
        width=int(width),
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        generator=gen,
    )
    return result.images[0]

with gr.Blocks(title="Flex.2-preview Image Generator") as demo:
    gr.Markdown(value="# Flex.2-preview Text→Image Generator")

    # Free-form text prompt.
    prompt = gr.Textbox(
        label="Prompt",
        placeholder="Enter your prompt...",
        lines=2,
    )

    # Conditioning images, each handed to the callback as a PIL image.
    with gr.Row():
        inpaint_img = gr.Image(label="Inpaint Image", type="pil")
        inpaint_mask = gr.Image(label="Inpaint Mask", type="pil")
        control_img = gr.Image(label="Control Image", type="pil")

    # Output resolution, selectable in 64-pixel increments.
    with gr.Row():
        height = gr.Slider(minimum=64, maximum=2048, value=512, step=64, label="Height")
        width = gr.Slider(minimum=64, maximum=2048, value=512, step=64, label="Width")

    # Sampler settings.
    with gr.Row():
        guidance_scale = gr.Slider(
            minimum=0.0, maximum=20.0, value=3.5, step=0.1, label="Guidance Scale"
        )
        num_inference_steps = gr.Slider(
            minimum=1, maximum=100, value=50, step=1, label="Inference Steps"
        )

    seed = gr.Number(label="Random Seed", value=42, precision=0)
    control_strength = gr.Slider(
        minimum=0.0, maximum=1.0, value=0.0, step=0.01, label="Control Strength"
    )
    control_stop = gr.Slider(
        minimum=0.0, maximum=1.0, value=1.0, step=0.01, label="Control Stop"
    )

    generate_btn = gr.Button(value="Generate")
    output = gr.Image(label="Generated Image", type="pil")

    # Order must match generate_image's positional parameters exactly.
    _generate_inputs = [
        prompt,
        inpaint_img,
        inpaint_mask,
        control_img,
        height,
        width,
        guidance_scale,
        num_inference_steps,
        seed,
        control_strength,
        control_stop,
    ]
    generate_btn.click(generate_image, inputs=_generate_inputs, outputs=[output])

if __name__ == "__main__":
    # Start the Gradio server; share=True asks Gradio for a public share link.
    demo.launch(share=True)