# Spaces: Running on Zero
import spaces | |
import gradio as gr | |
import torch | |
import numpy as np | |
from PIL import Image | |
from accelerate import Accelerator | |
import os | |
import time | |
import math | |
import json | |
from torchvision import transforms | |
from safetensors.torch import load_file | |
from networks import asylora_flux as lora_flux | |
from library import flux_utils, strategy_flux | |
import flux_minimal_inference_asylora as flux_train_utils | |
import logging | |
from huggingface_hub import login | |
from huggingface_hub import hf_hub_download | |
# Compute device shared by the model-loading and inference code below.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Set up logger
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)

# The accelerator supplies the bf16 autocast context used during inference.
accelerator = Accelerator(mixed_precision='bf16', device_placement=True)

# Authenticate against the Hugging Face Hub only when a token is configured.
# login(token=None) falls back to an interactive prompt, which cannot be
# answered in a headless deployment, so skip it and leave a clear trace.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    logger.warning("HF_TOKEN is not set; skipping Hugging Face Hub login.")
# Supported domains, in order. Each domain's 1-based position is its index;
# load_target_model subtracts 1 before handing it to the LoRA network.
_DOMAIN_NAMES = (
    'LEGO', 'Cook', 'Painting', 'Icon', 'Landscape illustration',
    'Portrait', 'Transformer', 'Sand art', 'Illustration', 'Sketch',
    'Clay toys', 'Clay sculpture', 'Zbrush Modeling', 'Wood sculpture',
    'Ink painting', 'Pencil sketch', 'Fabric toys', 'Oil painting',
    'Jade Carving', 'Line draw', 'Emoji',
)
domain_index = {
    name: position for position, name in enumerate(_DOMAIN_NAMES, start=1)
}

# LoRA checkpoint (file name inside the LoRA repo) per frame-selector value.
lora_paths = dict([
    ("9 frame", "asymmetric_lora/asymmetric_lora_9f_general.safetensors"),
    ("4 frame", "asymmetric_lora/asymmetric_lora_4f_general.safetensors"),
])

# Common paths: Hub repository ids and the file names fetched from them.
flux_repo_id, flux_file = "Kijai/flux-fp8", "flux1-dev-fp8.safetensors"
lora_repo_id = "showlab/makeanything"
clip_repo_id = "comfyanonymous/flux_text_encoders"
t5xxl_file, clip_l_file = "t5xxl_fp16.safetensors", "clip_l.safetensors"
ae_repo_id, ae_file = "black-forest-labs/FLUX.1-dev", "ae.safetensors"

# Module-level model handles; they stay None until load_target_model runs.
model = clip_l = t5xxl = ae = lora_model = None
# Function to load a file from Hugging Face Hub
def download_file(repo_id, file_name):
    """Fetch ``file_name`` from the Hub repository ``repo_id``.

    Thin wrapper around ``hf_hub_download``; its return value is passed
    through unchanged.
    """
    fetched = hf_hub_download(repo_id=repo_id, filename=file_name)
    return fetched
# Load model function with dynamic paths based on the selected model
def load_target_model(frame, domain):
    """Download and load the FLUX models plus the LoRA for ``frame``/``domain``.

    On success the module-level ``model``, ``clip_l``, ``t5xxl``, ``ae`` and
    ``lora_model`` are replaced. On failure they are left as ``None`` so that
    ``infer`` reports "not loaded" instead of running on a half-loaded state.

    Args:
        frame: Key of ``lora_paths`` ("4 frame" or "9 frame").
        domain: Key of ``domain_index``.

    Returns:
        A status string for the UI: a success message, or a short explanation
        when the selection is missing or invalid.
    """
    global model, clip_l, t5xxl, ae, lora_model

    # Validate the selection before any download starts, so a missing or
    # mismatched choice yields a readable status instead of a KeyError.
    if frame not in lora_paths:
        return "Please select a model ('4 frame' or '9 frame') before loading."
    if domain not in domain_index:
        return "Please select a domain before loading."
    # Value passed as `lora_ups_num` when building the LoRA network.
    lora_ups_num = 10 if frame == "9 frame" else 21
    lora_up_index = domain_index[domain] - 1
    if lora_up_index >= lora_ups_num:
        return "Domain '{}' is not available for the '{}' model.".format(domain, frame)

    # Release any previously loaded models before loading new ones; the new
    # set is only published once every step below has succeeded.
    model = clip_l = t5xxl = ae = lora_model = None

    BASE_FLUX_CHECKPOINT = download_file(flux_repo_id, flux_file)
    CLIP_L_PATH = download_file(clip_repo_id, clip_l_file)
    T5XXL_PATH = download_file(clip_repo_id, t5xxl_file)
    AE_PATH = download_file(ae_repo_id, ae_file)
    LORA_WEIGHTS_PATH = download_file(lora_repo_id, lora_paths[frame])

    logger.info("Loading models...")
    # Everything is loaded on the CPU; infer() moves modules to the compute
    # device only while they are needed.
    _, flow_model = flux_utils.load_flow_model(
        BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
    )
    clip_l_model = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
    clip_l_model.eval()
    t5xxl_model = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
    t5xxl_model.eval()
    ae_model = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
    logger.info("Models loaded successfully.")

    # Load LoRA weights
    multiplier = 1.0
    weights_sd = load_file(LORA_WEIGHTS_PATH)
    network, _ = lora_flux.create_network_from_weights(
        multiplier,
        None,
        ae_model,
        [clip_l_model, t5xxl_model],
        flow_model,
        weights_sd,
        True,
        lora_ups_num=lora_ups_num,
    )
    # Point every LoRA module at the entry that belongs to the chosen domain.
    for sub_lora in network.unet_loras:
        sub_lora.set_lora_up_cur(lora_up_index)
    network.apply_to([clip_l_model, t5xxl_model], flow_model)
    info = network.load_state_dict(weights_sd, strict=True)
    logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
    network.eval()

    # Publish the fully initialised set in one go.
    model, clip_l, t5xxl, ae, lora_model = (
        flow_model, clip_l_model, t5xxl_model, ae_model, network
    )
    logger.info("Models loaded successfully.")
    return "Models loaded successfully. Using Frame: {}, Domain: {}".format(frame, domain)
# The function to generate image from a prompt and conditional image
# NOTE(review): `spaces` is imported at the top of the file but no
# `@spaces.GPU` decorator is applied here - confirm whether the ZeroGPU
# deployment needs one.
def infer(prompt, frame, seed=0):
    """Generate one image for ``prompt`` with the currently loaded models.

    Args:
        prompt: Text prompt, tokenized and encoded with CLIP-L and T5-XXL.
        frame: Frame-selector value such as "4 frame" or "9 frame". Its first
            character is parsed as the frame count, which selects the output
            size (1024x1024 for 4 frames, 1056x1056 otherwise).
        seed: Seed given to ``torch.manual_seed`` before the noise is sampled.

    Returns:
        A ``PIL.Image.Image``, or ``None`` when the models have not been
        loaded or no frame layout is selected.
    """
    global model, clip_l, t5xxl, ae, lora_model
    if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
        logger.error("Models not loaded. Please load the models first.")
        return None
    if not frame:
        logger.error("No frame layout selected. Please select a model first.")
        return None

    # NOTE(review): `frame` is the current UI selection and may differ from
    # the frame the LoRA was loaded with - confirm whether that should be
    # rejected.
    frame_num = int(frame[0:1])
    seed = int(seed)
    logger.info(f"Started generating image with prompt: {prompt}")

    try:
        # Use the shared `device` rather than a hard-coded "cuda" so the
        # LoRA modules end up where the rest of the pipeline runs.
        lora_model.to(device)
        for module in (model, clip_l, t5xxl, ae):
            module.eval()
        logger.info(f"Using seed: {seed}")

        # Only the text encoders need the compute device for now.
        ae.to("cpu")
        clip_l.to(device)
        t5xxl.to(device)

        # Encode the prompt
        tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
        text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(True)
        tokens_and_masks = tokenize_strategy.tokenize(prompt)
        l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(
            tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, True
        )
        logger.debug("Prompt encoded.")

        # Prepare the noise and other parameters
        side = 1024 if frame_num == 4 else 1056
        width = height = side
        packed_latent_height, packed_latent_width = math.ceil(height / 16), math.ceil(width / 16)
        torch.manual_seed(seed)
        # NOTE(review): the noise is float16 while the models are loaded as
        # bfloat16/float8 - confirm this dtype is intended.
        noise = torch.randn(
            1,
            packed_latent_height * packed_latent_width,
            16 * 2 * 2,
            device=device,
            dtype=torch.float16,
        )
        logger.debug("Noise prepared.")

        # Generate the image
        timesteps = flux_train_utils.get_schedule(20, noise.shape[1], shift=True)  # Sample steps = 20
        img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(device)
        t5_attn_mask = t5_attn_mask.to(device)
        logger.debug("Image generation parameters set.")

        # Swap the text encoders out before bringing the flow model in, so
        # both are not resident on the device at the same time.
        clip_l.to("cpu")
        t5xxl.to("cpu")
        torch.cuda.empty_cache()
        model.to(device)
        logger.debug("Model device: %s", model.device)
        logger.debug("Noise device: %s", noise.device)
        logger.debug("Image IDs device: %s", img_ids.device)
        logger.debug("T5 output device: %s", t5_out.device)
        logger.debug("Text IDs device: %s", txt_ids.device)
        logger.debug("L pooled device: %s", l_pooled.device)

        # Run the denoising process
        with accelerator.autocast(), torch.no_grad():
            x = flux_train_utils.denoise(
                model,
                noise,
                img_ids,
                t5_out,
                txt_ids,
                l_pooled,
                timesteps,
                guidance=4.0,
                t5_attn_mask=t5_attn_mask,
                cfg_scale=1.0,
            )
        logger.debug("Denoising process completed.")

        # Decode the final image
        x = x.float()
        x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
        model.to("cpu")
        ae.to(device)
        with accelerator.autocast(), torch.no_grad():
            x = ae.decode(x)
        logger.debug("Latents decoded into image.")
        ae.to("cpu")

        # Convert the tensor to an image: map [-1, 1] to [0, 255] and move
        # the channel axis last before handing the array to PIL.
        x = x.clamp(-1, 1)
        x = x.permute(0, 2, 3, 1)
        pixels = (127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]
        generated_image = Image.fromarray(pixels)
        logger.info("Image generation completed.")
        return generated_image
    finally:
        # Whatever happened above, do not leave the large modules on the
        # compute device; otherwise one failed request starves the next.
        for module in (model, clip_l, t5xxl, ae):
            module.to("cpu")
        torch.cuda.empty_cache()
def update_domains(floor):
    """Return a domain dropdown limited to the domains of the chosen model.

    ``floor`` is the frame-selector value ("4 frame" or "9 frame"); any other
    value raises ``KeyError``.
    """
    every_domain = [
        "LEGO", "Cook", "Painting", "Icon", "Landscape illustration",
        "Portrait", "Transformer", "Sand art", "Illustration", "Sketch",
        "Clay toys", "Clay sculpture", "Zbrush Modeling", "Wood sculpture", "Ink painting",
        "Pencil sketch", "Fabric toys", "Oil painting", "Jade Carving", "Line draw", "Emoji",
    ]
    # The 9-frame model offers only the first ten domains of the full list.
    choices_by_frame = {
        "4 frame": every_domain,
        "9 frame": every_domain[:10],
    }
    available = choices_by_frame[floor]
    return gr.Dropdown(choices=available, label="Select Domains")
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Asymmetric LoRA Generation")

    # Top row: model/domain selection on the left, load status on the right.
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Row():
                with gr.Column(scale=1):
                    frame_selector = gr.Radio(choices=["4 frame", "9 frame"], label="Select Model")
                with gr.Column(scale=2):
                    # Start with every known domain; update_domains narrows
                    # the list once a model is picked.
                    domain_selector = gr.Dropdown(choices=list(domain_index), label="Select Domains")
            # Load Model Button
            load_button = gr.Button("Load Model")
        with gr.Column(scale=1):
            # Status message box
            status_box = gr.Textbox(
                label="Status",
                placeholder="Model loading status",
                interactive=False,
                value="Model not loaded",
                lines=3,
            )

    # Bottom row: prompt and seed on the left, generated image on the right.
    with gr.Row():
        with gr.Column(scale=1):
            # Input for the prompt
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=8)
            with gr.Row():
                seed = gr.Slider(0, np.iinfo(np.int32).max, step=1, label="Seed", value=42)
                run_button = gr.Button("Generate Image")
        with gr.Column(scale=1):
            # Output result
            result_image = gr.Image(label="Generated Image", interactive=False)

    # Keep the domain list in sync with the selected model.
    frame_selector.change(update_domains, inputs=frame_selector, outputs=domain_selector)
    # Load model button action
    load_button.click(fn=load_target_model, inputs=[frame_selector, domain_selector], outputs=[status_box])
    # Run Button
    run_button.click(fn=infer, inputs=[prompt, frame_selector, seed], outputs=[result_image])

    gr.Markdown("### Examples")
    # Each example row is [frame, domain, prompt, seed], matching the
    # `inputs` order given to gr.Examples below.
    examples = [
        [
            "9 frame",
            "LEGO",
            "sks1, 3*3 puzzle of 9 sub-images, step-by-step construction process of a LEGO model,<image-1> Lay down a gray plate as a road surface.<image-2> Position two red 2x4 bricks side by side to start forming a sports car’s chassis.<image-3> Attach black slope bricks at the front, shaping a sleek hood.<image-4> Insert transparent pieces at the front for headlights.<image-5> Clip on black wheel assemblies at each corner.<image-6> Add a windshield piece and a small black steering wheel inside.<image-7> Place smooth tiles on top to create a glossy roof.<image-8> Add side mirrors and a spoiler at the back.<image-9> Conclude by placing a minifigure driver behind the wheel, ready to race.",
            1855705978
        ],
        [
            "9 frame",
            "Portrait",
            "sks6, 3*3 puzzle of 9 sub-images, step-by-step portrait painting process, woman with blonde curly hair",
            1062070717
        ],
        [
            "9 frame",
            "Sand art",
            "sks8, 3*3 puzzle of 9 sub-images, step-by-step description of sand art creation, <image-1>: The outline of a classic pirate ship is drawn, capturing its sails and hull. <image-2>: Basic shapes of the ship’s structure and masts are added, defining its adventurous form. <image-3>: Details of the sails and rigging begin to appear, adding complexity. <image-4>: Shadows and highlights enhance the ship’s three-dimensional appearance. <image-5>: The ship’s deck and cannons are refined, giving it character. <image-6>: Additional elements like waves and seagulls are added for movement. <image-7>: A backdrop of a stormy sea with dark clouds is introduced, adding drama. <image-8>: Further details like lightning and crashing waves are sketched for intensity. <image-9>: Final touches include vibrant blues and grays, completing the thrilling pirate ship scene.",
            641262478
        ],
    ]
    gr.Examples(
        examples=examples,
        inputs=[frame_selector, domain_selector, prompt, seed],
        outputs=[result_image],
        cache_examples=False
    )

# Launch the Gradio app
demo.launch()