import gradio as gr
import torch
import numpy as np
import random
from PIL import Image
from accelerate import Accelerator
import os
import time
from types import SimpleNamespace
from torchvision import transforms
from safetensors.torch import load_file
from networks import lora_flux
from library import flux_utils, flux_train_utils_recraft as flux_train_utils, strategy_flux
import logging

# Set up logger
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
# Select the compute device
device = "cuda" if torch.cuda.is_available() else "cpu"
accelerator = Accelerator(mixed_precision='bf16', device_placement=True)

# Model paths (replace these with your actual model paths)
BASE_FLUX_CHECKPOINT = "/tiamat-NAS/songyiren/FYP/liucheng/sd-scripts/MergeModel/6_Portrait/6_Portrait.safetensors"
LORA_WEIGHTS_PATH = "/tiamat-NAS/songyiren/FYP/liucheng/sd-scripts/RecraftModel/6_Portrait/6_Portrait-step00025000.safetensors"
CLIP_L_PATH = "/tiamat-NAS/hailong/storage_backup/models/stabilityai/stable-diffusion-3-medium/text_encoders/clip_l.safetensors"
T5XXL_PATH = "/tiamat-NAS/hailong/storage_backup/models/stabilityai/stable-diffusion-3-medium/text_encoders/t5xxl_fp16.safetensors"
AE_PATH = "/tiamat-vePFS/share_data/storage/huggingface/models/black-forest-labs/FLUX.1-dev/ae.safetensors"
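# The paths above are specific to the author's cluster. One optional way to make
# them portable is an environment-variable override (illustrative sketch; the
# variable names are assumptions, not part of the original script):
# BASE_FLUX_CHECKPOINT = os.environ.get("BASE_FLUX_CHECKPOINT", BASE_FLUX_CHECKPOINT)
# LORA_WEIGHTS_PATH = os.environ.get("LORA_WEIGHTS_PATH", LORA_WEIGHTS_PATH)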
# Load model function
def load_target_model():
    logger.info("Loading models...")
    try:
        _, model = flux_utils.load_flow_model(
            BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
        )
        clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
        clip_l.eval()
        t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
        t5xxl.eval()
        ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
        logger.info("Models loaded successfully.")
        return model, [clip_l, t5xxl], ae
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise
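# Note: all weights are loaded on the CPU first; infer() then moves each
# component to the GPU only for its stage (VAE encode, text encoding, denoising,
# VAE decode) to keep peak VRAM low. The transformer itself is loaded in
# float8_e4m3fn to further reduce memory.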
# Image pre-processing (resize and padding)
class ResizeWithPadding:
    def __init__(self, size, fill=255):
        self.size = size
        self.fill = fill

    def __call__(self, img):
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        elif not isinstance(img, Image.Image):
            raise TypeError("Input must be a PIL Image or a NumPy array")

        width, height = img.size
        if width == height:
            img = img.resize((self.size, self.size), Image.LANCZOS)
        else:
            max_dim = max(width, height)
            new_img = Image.new("RGB", (max_dim, max_dim), (self.fill, self.fill, self.fill))
            new_img.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
            img = new_img.resize((self.size, self.size), Image.LANCZOS)
        return img
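# Example (illustrative): a 300x200 image is first centered on a 300x300 white
# square, then resized to 512x512.
#   resize = ResizeWithPadding(size=512)
#   square = resize(Image.open("input.png"))  # "input.png" is a hypothetical path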
# Generate an image from a prompt and a conditional image
def infer(prompt, sample_image, frame_num, seed=0, randomize_seed=False):
    logger.info(f"Started generating image with prompt: {prompt}")

    # Load models (note: reloaded on every call; consider caching for production)
    model, [clip_l, t5xxl], ae = load_target_model()
    model.eval()
    clip_l.eval()
    t5xxl.eval()
    ae.eval()

    # Apply the LoRA weights
    multiplier = 1.0
    weights_sd = load_file(LORA_WEIGHTS_PATH)
    lora_model, _ = lora_flux.create_network_from_weights(
        multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True
    )
    lora_model.apply_to([clip_l, t5xxl], model)
    info = lora_model.load_state_dict(weights_sd, strict=True)
    logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
    lora_model.eval()
    lora_model.to("cuda")
    # Process the seed and make the run reproducible
    if randomize_seed:
        seed = random.randint(0, np.iinfo(np.int32).max)
    torch.manual_seed(seed)  # seed the noise generator; without this the seed had no effect
    logger.debug(f"Using seed: {seed}")
    # Preprocess the conditional image (pad to a square, normalize to [-1, 1] for the AE)
    resize_transform = ResizeWithPadding(size=512) if frame_num == 4 else ResizeWithPadding(size=352)
    img_transforms = transforms.Compose([
        resize_transform,
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    image = img_transforms(np.array(sample_image, dtype=np.uint8)).unsqueeze(0).to(
        device=device, dtype=torch.bfloat16
    )
    logger.debug("Conditional image preprocessed.")
    # Encode the image to latents
    ae.to("cuda")
    latents = ae.encode(image)
    logger.debug("Image encoded to latents.")
    conditions = {prompt: latents.to("cpu")}
    ae.to("cpu")

    clip_l.to("cuda")
    t5xxl.to("cuda")
    # Encode the prompt with both text encoders: CLIP-L yields the pooled vector
    # (l_pooled); T5-XXL yields per-token embeddings (t5_out) plus ids and mask
    tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
    text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(True)
    tokens_and_masks = tokenize_strategy.tokenize(prompt)
    l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(
        tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, True
    )
    logger.debug("Prompt encoded.")
    # Prepare the noise and other parameters
    width = 1024 if frame_num == 4 else 1056
    height = 1024 if frame_num == 4 else 1056
    height = max(64, height - height % 16)
    width = max(64, width - width % 16)
    # FLUX packs latents into 2x2 patches: each sequence token covers a 2x2 patch
    # of 16-channel latents, i.e. 16 * 2 * 2 = 64 values per token
    packed_latent_height = height // 16
    packed_latent_width = width // 16
    noise = torch.randn(
        1, packed_latent_height * packed_latent_width, 16 * 2 * 2,
        device=device, dtype=torch.float16,
    )
    logger.debug("Noise prepared.")
    # Generate the image
    timesteps = flux_train_utils.get_schedule(20, noise.shape[1], shift=True)  # 20 sampling steps
    img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(device)
    t5_attn_mask = t5_attn_mask.to(device)
    ae_outputs = conditions[prompt]
    logger.debug("Image generation parameters set.")

    # Lightweight namespace; only args.frame_num is consumed by the denoise helper here
    args = SimpleNamespace(frame_num=frame_num)

    # Free text-encoder VRAM before moving the transformer onto the GPU
    clip_l.to("cpu")
    t5xxl.to("cpu")
    torch.cuda.empty_cache()
    model.to("cuda")
    # Run the denoising process
    with accelerator.autocast(), torch.no_grad():
        x = flux_train_utils.denoise(
            args, model, noise, img_ids, t5_out, txt_ids, l_pooled,
            timesteps=timesteps, guidance=1.0, t5_attn_mask=t5_attn_mask, ae_outputs=ae_outputs,
        )
    logger.debug("Denoising process completed.")
    # Decode the final image
    x = x.float()
    x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
    model.to("cpu")
    ae.to("cuda")
    with accelerator.autocast(), torch.no_grad():
        x = ae.decode(x)
    logger.debug("Latents decoded into image.")
    ae.to("cpu")
    # Convert the tensor to a PIL image: map [-1, 1] to [0, 255]
    x = x.clamp(-1, 1)
    x = x.permute(0, 2, 3, 1)
    generated_image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
    logger.info("Image generation completed.")
    return generated_image
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## FLUX Image Generation")
    with gr.Row():
        # Input for the prompt
        prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=1)
        # File upload for the conditional image
        sample_image = gr.Image(label="Upload a Conditional Image", type="pil")
        # Frame number selection
        frame_num = gr.Radio([4, 9], label="Select Frame Number", value=4)
        # Seed and randomize-seed options
        seed = gr.Slider(0, np.iinfo(np.int32).max, step=1, label="Seed", value=0)
        randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
        # Run button
        run_button = gr.Button("Generate Image")
        # Output result
        result_image = gr.Image(label="Generated Image")

    run_button.click(
        fn=infer,
        inputs=[prompt, sample_image, frame_num, seed, randomize_seed],
        outputs=[result_image],
    )

# Launch the Gradio app
demo.launch(server_port=8289, server_name="0.0.0.0", share=True)
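# Standalone smoke test kept from development; uncomment to run infer() without the UI.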
# prompt = "1girl"
# sample_image = Image.open("/tiamat-NAS/songyiren/FYP/liucheng/sd-scripts/MergeModel/test/1.png")  # use a test image
# frame_num = 9
# seed = 42
# randomize_seed = False
# result = infer(prompt, sample_image, frame_num, seed, randomize_seed)
# result.save('asy_results/generated_image.png')