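"""Gradio demo for PixelFlow, a pixel-space generative model built on flow matching.

Loads either the class-conditional (ImageNet-1K) or the text-conditional
(Flan-T5-XL prompts) checkpoint from the Hugging Face Hub and serves an
interactive image-generation UI.
"""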
import argparse
import os
from PIL import Image
import gradio as gr
import spaces
from imagenet_en_cn import IMAGENET_1K_CLASSES
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download
import torch
from transformers import T5EncoderModel, AutoTokenizer
from pixelflow.scheduling_pixelflow import PixelFlowScheduler
from pixelflow.pipeline_pixelflow import PixelFlowPipeline
from pixelflow.utils import config as config_utils
from pixelflow.utils.misc import seed_everything
parser = argparse.ArgumentParser(description='Gradio Demo', add_help=False)
parser.add_argument('--checkpoint', type=str, help='checkpoint folder path')
parser.add_argument('--class_cond', action='store_true', help='use class conditional generation')
args = parser.parse_args()
# Deployment overrides: the hosted Space always serves the text-to-image model.
args.checkpoint = "pixelflow_t2i"
args.class_cond = False
output_dir = args.checkpoint
if args.class_cond:
    repo_id = "ShoufaChen/PixelFlow-Class2Image"
    text_encoder = None
    tokenizer = None
    resolution = 256
    NUM_EXAMPLES = 4
else:
    repo_id = "ShoufaChen/PixelFlow-Text2Image"
    text_encoder = T5EncoderModel.from_pretrained("google/flan-t5-xl")
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
    resolution = 1024
    NUM_EXAMPLES = 1

if not os.path.exists(output_dir):
    snapshot_download(repo_id=repo_id, local_dir=output_dir)
config = OmegaConf.load(f"{output_dir}/config.yaml")
model = config_utils.instantiate_from_config(config.model)
print(f"Num of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
ckpt = torch.load(f"{output_dir}/model.pt", map_location="cpu", weights_only=True)
model.load_state_dict(ckpt, strict=True)
model.eval()
print(f"outside spaces.GPU. {torch.cuda.is_available()=}")
if torch.cuda.is_available():
    model = model.cuda()
    text_encoder = text_encoder.cuda() if text_encoder else None
    device = torch.device("cuda")
else:
    raise ValueError("No GPU available; this demo requires CUDA.")
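# Multi-stage flow-matching setup: the scheduler splits sampling across
# config.scheduler.num_stages resolution stages, and the pipeline bundles
# the scheduler, denoising model, and optional text encoder/tokenizer
# behind a single sampling call.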
scheduler = PixelFlowScheduler(config.scheduler.num_train_timesteps, num_stages=config.scheduler.num_stages, gamma=-1/3)
pipeline = PixelFlowPipeline(
    scheduler,
    model,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    max_token_length=512,
)
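# ZeroGPU: the decorator below attaches a GPU to each call for up to
# 120 seconds, so all CUDA work must happen inside this function.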
@spaces.GPU(duration=120)
def infer(noise_shift, cfg_scale, class_label, seed, *num_steps_per_stage):
    print(f"inside spaces.GPU. {torch.cuda.is_available()=}")
    seed_everything(seed)
    with torch.autocast("cuda", dtype=torch.bfloat16), torch.no_grad():
        samples = pipeline(
            prompt=[class_label] * NUM_EXAMPLES,
            height=resolution,
            width=resolution,
            num_inference_steps=list(num_steps_per_stage),
            guidance_scale=cfg_scale,  # classifier-free guidance scale; set to 7 for the 384p variant
            device=device,
            shift=noise_shift,
            use_ode_dopri5=False,
        )
    samples = (samples * 255).round().astype("uint8")
    samples = [Image.fromarray(sample) for sample in samples]
    return samples
css = """
h1 {
text-align: center;
display: block;
}
.follow-link {
margin-top: 0.8em;
font-size: 1em;
text-align: center;
}
"""
with gr.Blocks(css=css) as demo:
    gr.Markdown("# PixelFlow: Pixel-Space Generative Models with Flow")
    gr.HTML("""
    <div class="follow-link">
        For online class-to-image generation, please try the
        <a href="https://huggingface.co/spaces/ShoufaChen/PixelFlow">class-to-image demo</a>.
        For more details, refer to our
        <a href="https://arxiv.org/abs/2504.07963">arXiv paper</a> and the <a href="https://github.com/ShoufaChen/PixelFlow">GitHub repo</a>.
    </div>
    """)
    with gr.Tabs():
        with gr.TabItem('Generate'):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        if args.class_cond:
                            user_input = gr.Dropdown(
                                list(IMAGENET_1K_CLASSES.values()),
                                value='daisy [雏菊]',
                                type="index", label='ImageNet-1K Class',
                            )
                        else:
                            user_input = gr.Textbox(label='Enter your prompt', show_label=False, max_lines=1, placeholder="Enter your prompt")
                    noise_shift = gr.Slider(minimum=1.0, maximum=100.0, step=1, value=1.0, label='Noise Shift')
                    cfg_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=4.0, label='Classifier-free Guidance Scale')
                    # One inference-steps slider per resolution stage of the scheduler.
                    num_steps_per_stage = []
                    for stage_idx in range(config.scheduler.num_stages):
                        num_steps = gr.Slider(minimum=1, maximum=100, step=1, value=5, label=f'Num Inference Steps (Stage {stage_idx})')
                        num_steps_per_stage.append(num_steps)
                    seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label='Seed')
                    button = gr.Button("Generate", variant="primary")
                with gr.Column():
                    output = gr.Gallery(label='Generated Images', height=700)
            button.click(infer, inputs=[noise_shift, cfg_scale, user_input, seed, *num_steps_per_stage], outputs=[output])
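# Queue incoming requests so generation calls are served sequentially
# rather than contending for the GPU.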
demo.queue()
demo.launch(share=True, debug=True)