# Spaces: Running on Zero
import argparse | |
import os | |
from PIL import Image | |
import gradio as gr | |
import spaces | |
from imagenet_en_cn import IMAGENET_1K_CLASSES | |
from omegaconf import OmegaConf | |
from huggingface_hub import snapshot_download | |
import torch | |
from transformers import T5EncoderModel, AutoTokenizer | |
from pixelflow.scheduling_pixelflow import PixelFlowScheduler | |
from pixelflow.pipeline_pixelflow import PixelFlowPipeline | |
from pixelflow.utils import config as config_utils | |
from pixelflow.utils.misc import seed_everything | |
# Command-line options. They are parsed for local use, but the deploy
# settings below overwrite both values unconditionally.
# allow_abbrev=False: an unrelated flag such as "--c" must not be treated as an
# ambiguous abbreviation of --checkpoint/--class_cond and abort start-up.
parser = argparse.ArgumentParser(description='Gradio Demo', add_help=False, allow_abbrev=False)
parser.add_argument('--checkpoint', type=str, help='checkpoint folder path')
parser.add_argument('--class_cond', action='store_true', help='use class conditional generation')
# parse_known_args (instead of parse_args) ignores extra argv entries injected by
# a launcher/host process rather than exiting; the parsed values are replaced by
# the deploy settings right below anyway.
args = parser.parse_known_args()[0]

# deploy: hard-coded settings for the hosted text-to-image demo.
args.checkpoint = "pixelflow_t2i"
args.class_cond = False
# Local directory the checkpoint is downloaded into and loaded from.
output_dir = args.checkpoint
def _ensure_checkpoint(repo_id, local_dir):
    """Download ``repo_id`` into ``local_dir`` unless the files this script reads exist.

    Checks for ``config.yaml`` and ``model.pt`` (the two files loaded below)
    rather than only the directory, so a partially downloaded folder is
    completed instead of failing later on a missing file.
    """
    required_files = ("config.yaml", "model.pt")
    if not all(os.path.isfile(os.path.join(local_dir, name)) for name in required_files):
        snapshot_download(repo_id=repo_id, local_dir=local_dir)


# Task-specific settings: which Hub repo holds the checkpoint, the output image
# resolution, and how many images one request generates.
if args.class_cond:
    repo_id = "ShoufaChen/PixelFlow-Class2Image"
    resolution = 256
    NUM_EXAMPLES = 4
else:
    repo_id = "ShoufaChen/PixelFlow-Text2Image"
    resolution = 1024
    NUM_EXAMPLES = 1

# Shared loading path (identical for both tasks).
_ensure_checkpoint(repo_id, output_dir)
config = OmegaConf.load(f"{output_dir}/config.yaml")
model = config_utils.instantiate_from_config(config.model)
print(f"Num of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
ckpt = torch.load(f"{output_dir}/model.pt", map_location="cpu", weights_only=True)

if args.class_cond:
    # Class-conditional generation uses no text encoder / tokenizer.
    text_encoder = None
    tokenizer = None
else:
    text_encoder = T5EncoderModel.from_pretrained("google/flan-t5-xl")
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")

model.load_state_dict(ckpt, strict=True)
# load_state_dict copies the tensors into the model, so the CPU-side state
# dict is no longer needed; release it instead of holding a second copy of the
# weights for the lifetime of the app.
del ckpt
model.eval()

print(f"outside space.GPU. {torch.cuda.is_available()=}")
if torch.cuda.is_available():
    model = model.cuda()
    text_encoder = text_encoder.cuda() if text_encoder is not None else None
    device = torch.device("cuda")
else:
    raise ValueError("No GPU")
# Multi-stage flow scheduler; the timestep count and number of stages come
# from the checkpoint's config, gamma is fixed by this demo.
scheduler = PixelFlowScheduler(
    config.scheduler.num_train_timesteps,
    num_stages=config.scheduler.num_stages,
    gamma=-1 / 3,
)

# Inference pipeline around the model. text_encoder/tokenizer are None in
# class-conditional mode; max_token_length presumably caps the prompt's token
# length (TODO confirm against PixelFlowPipeline).
pipeline = PixelFlowPipeline(
    scheduler,
    model,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    max_token_length=512,
)
# On ZeroGPU hardware a GPU is only attached while a @spaces.GPU-decorated
# function runs; outside ZeroGPU the decorator has no effect.
@spaces.GPU
def infer(noise_shift, cfg_scale, class_label, seed, *num_steps_per_stage):
    """Generate ``NUM_EXAMPLES`` images for one prompt or class label.

    Args:
        noise_shift: forwarded to the pipeline as ``shift``.
        cfg_scale: classifier-free guidance scale (``guidance_scale``).
        class_label: the text prompt, or, in class-conditional mode, the class
            index selected in the dropdown; repeated ``NUM_EXAMPLES`` times.
        seed: RNG seed passed to ``seed_everything``.
        *num_steps_per_stage: number of inference steps for each scheduler stage.

    Returns:
        A list of ``PIL.Image`` objects, one per generated sample.
    """
    print(f"inside space.GPU. {torch.cuda.is_available()=}")
    # Slider values may arrive as floats; the seed and step counts must be ints.
    seed_everything(int(seed))
    steps = [int(n) for n in num_steps_per_stage]
    with torch.autocast("cuda", dtype=torch.bfloat16), torch.no_grad():
        samples = pipeline(
            prompt=[class_label] * NUM_EXAMPLES,
            height=resolution,
            width=resolution,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            device=device,
            shift=noise_shift,
            use_ode_dopri5=False,
        )
    # Assumes the pipeline returns a NumPy float array scaled to [0, 1] (implied
    # by the *255 / astype below). Clip first: values just outside that range
    # would otherwise wrap around in the uint8 cast (e.g. 256 -> 0).
    samples = (samples.clip(0.0, 1.0) * 255).round().astype("uint8")
    return [Image.fromarray(sample) for sample in samples]
# Page styling: centered title and centered link paragraph under it.
css = """
h1 {
    text-align: center;
    display: block;
}
.follow-link {
    margin-top: 0.8em;
    font-size: 1em;
    text-align: center;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# PixelFlow: Pixel-Space Generative Models with Flow")
    gr.HTML("""
    <div class="follow-link">
        For online class-to-image generation, please try
        <a href="https://huggingface.co/spaces/ShoufaChen/PixelFlow">class-to-image</a>.
        For more details, refer to our
        <a href="https://arxiv.org/abs/2504.07963">arXiv paper</a> and <a href="https://github.com/ShoufaChen/PixelFlow">GitHub repo</a>.
    </div>
    """)
    with gr.Tabs():
        with gr.TabItem('Generate'):
            with gr.Row():
                # Left column: inputs and generation controls.
                with gr.Column():
                    with gr.Row():
                        if args.class_cond:
                            class_names = list(IMAGENET_1K_CLASSES.values())
                            # Look the default up from the class table instead of
                            # hard-coding the bilingual label: the previous literal
                            # was mis-encoded (mojibake of the Chinese name) and
                            # would almost certainly not match any choice.
                            default_class = next(
                                (name for name in class_names if name.startswith('daisy')),
                                class_names[0],
                            )
                            # type="index": infer() receives the class index.
                            user_input = gr.Dropdown(
                                class_names,
                                value=default_class,
                                type="index", label='ImageNet-1K Class'
                            )
                        else:
                            # text input
                            user_input = gr.Textbox(label='Enter your prompt', show_label=False, max_lines=1, placeholder="Enter your prompt",)
                    noise_shift = gr.Slider(minimum=1.0, maximum=100.0, step=1, value=1.0, label='Noise Shift')
                    cfg_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=4.0, label='Classifier-free Guidance Scale')
                    # One step-count slider per scheduler stage (count from the config).
                    num_steps_per_stage = [
                        gr.Slider(minimum=1, maximum=100, step=1, value=5, label=f'Num Inference Steps (Stage {stage_idx})')
                        for stage_idx in range(config.scheduler.num_stages)
                    ]
                    seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label='Seed')
                    button = gr.Button("Generate", variant="primary")
                # Right column: results.
                with gr.Column():
                    output = gr.Gallery(label='Generated Images', height=700)
    # Input order must match infer()'s positional signature.
    button.click(infer, inputs=[noise_shift, cfg_scale, user_input, seed, *num_steps_per_stage], outputs=[output])

demo.queue()
demo.launch(share=True, debug=True)