File size: 3,734 Bytes
b0f3145
 
187715d
b0f3145
 
187715d
 
 
b0f3145
187715d
88dc089
187715d
b0f3145
 
187715d
 
 
 
 
b0f3145
 
e08daf1
187715d
 
 
59f3984
187715d
 
 
59f3984
187715d
 
 
 
 
0b120d5
b0f3145
187715d
 
 
 
 
 
 
0ce0e61
b0f3145
187715d
75859e2
187715d
 
 
2a8c2d1
bef93f3
187715d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a8c2d1
e9028a2
187715d
 
2a8c2d1
187715d
 
 
 
 
 
 
 
 
 
 
 
b0f3145
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces
import os
from PIL import Image

# Opt-in NSFW filtering: enabled only when the SAFETY_CHECKER env var is exactly "1".
SAFETY_CHECKER = os.getenv("SAFETY_CHECKER", "0") == "1"

# Hub repositories: the base SDXL pipeline and the repo that hosts the
# SDXL-Lightning UNet weight files.
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"

# Maps the UI label ("<N>-Step") to [UNet weights file name, number of inference steps].
checkpoints = {
    f"{steps}-Step": [weights_file, steps]
    for steps, weights_file in (
        (1, "sdxl_lightning_1step_unet_x0.safetensors"),
        (2, "sdxl_lightning_2step_unet.safetensors"),
        (4, "sdxl_lightning_4step_unet.safetensors"),
        (8, "sdxl_lightning_8step_unet.safetensors"),
    )
}


# The base pipeline is loaded once, at import time, and only when CUDA is
# available. Without CUDA the name `pipe` is never bound.
if torch.cuda.is_available():
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to("cuda")

if SAFETY_CHECKER:
    # Imported lazily so the checker's dependencies are only needed when the
    # feature is switched on.
    from safety_checker import StableDiffusionSafetyChecker
    from transformers import CLIPFeatureExtractor

    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    safety_checker = safety_checker.to("cuda")

    def check_nsfw_images(
        images: list[Image.Image],
    ) -> tuple[list[Image.Image], list[bool]]:
        """Run the safety checker over ``images``.

        Args:
            images: Images to screen.

        Returns:
            A pair ``(images, flags)``: the input list, unchanged, and
            whatever the safety checker returned for it (annotated as one
            boolean per image).
        """
        clip_inputs = feature_extractor(images, return_tensors="pt").to("cuda")
        pixel_values = clip_inputs.pixel_values.to("cuda")
        # NOTE(review): the image list is passed wrapped in another list;
        # confirm against the local safety_checker module that this is the
        # shape it expects.
        nsfw_flags = safety_checker(images=[images], clip_input=pixel_values)
        return images, nsfw_flags

# Function 
@spaces.GPU(enable_queue=True)
def generate_image(prompt, ckpt):
    """Generate one image for ``prompt`` using the selected Lightning checkpoint.

    Args:
        prompt: Text prompt handed to the pipeline.
        ckpt: Key into ``checkpoints`` (e.g. ``"4-Step"``) selecting the UNet
            weights file and the number of inference steps.

    Returns:
        The first generated image. When ``SAFETY_CHECKER`` is enabled and any
        result is flagged, a blank 512x512 RGB image is returned instead and
        a Gradio warning is raised in the UI.

    Raises:
        KeyError: If ``ckpt`` is not a key of ``checkpoints``.
    """
    print(prompt, ckpt)

    checkpoint, num_inference_steps = checkpoints[ckpt]

    # The scheduler is rebuilt from the *current* scheduler's config, so any
    # override made by a previous call is inherited. The prediction type is
    # therefore set explicitly on every call: "sample" for the 1-step
    # checkpoint, "epsilon" otherwise. Without this, a "sample" setting left
    # by a 1-step run would leak into later multi-step runs.
    prediction_type = "sample" if num_inference_steps == 1 else "epsilon"
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config,
        timestep_spacing="trailing",
        prediction_type=prediction_type,
    )

    # Swap in the UNet weights for the requested step count.
    weights_path = hf_hub_download(repo, checkpoint)
    pipe.unet.load_state_dict(load_file(weights_path, device="cuda"))
    results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)

    if not SAFETY_CHECKER:
        return results.images[0]

    images, has_nsfw_concepts = check_nsfw_images(results.images)
    if any(has_nsfw_concepts):
        gr.Warning("NSFW content detected.")
        return Image.new("RGB", (512, 512))
    return images[0]



# Gradio Interface

# Page layout: title, link to the model card, one input row (prompt,
# step selector, submit button) and the output image.
with gr.Blocks(css="style.css") as demo:
    gr.HTML("<h1><center>SDXL-Lightning ⚡</center></h1>")
    gr.HTML("<p><center>Lightning-fast text-to-image generation</center></p><p><center><a href='https://huggingface.co/ByteDance/SDXL-Lightning'>https://huggingface.co/ByteDance/SDXL-Lightning</a></center></p>")
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter your image prompt', scale=8)
            # Choices come straight from `checkpoints` so the dropdown can
            # never offer a label that generate_image cannot look up.
            ckpt = gr.Dropdown(
                label='Select inference steps',
                choices=list(checkpoints),
                value='4-Step',
                interactive=True,
            )
            submit = gr.Button(scale=1, variant='primary')
    img = gr.Image(label='SDXL-Lightning Generated Image')

    # Pressing Enter in the prompt box and clicking the button both trigger
    # the same generation call.
    for trigger in (prompt.submit, submit.click):
        trigger(fn=generate_image, inputs=[prompt, ckpt], outputs=img)

demo.queue().launch()