import gradio as gr
import torch
import numpy as np
import random
from PIL import Image
from accelerate import Accelerator
import os
import time
from types import SimpleNamespace
from torchvision import transforms
from safetensors.torch import load_file
from networks import lora_flux
from library import flux_utils, flux_train_utils_recraft as flux_train_utils, strategy_flux
import logging

# Set up logger
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)

# Pick the compute device and configure mixed-precision inference
device = "cuda" if torch.cuda.is_available() else "cpu"
accelerator = Accelerator(mixed_precision="bf16", device_placement=True)

# Model paths (replace these with your actual model paths)
BASE_FLUX_CHECKPOINT = "/tiamat-NAS/songyiren/FYP/liucheng/sd-scripts/MergeModel/6_Portrait/6_Portrait.safetensors"
LORA_WEIGHTS_PATH = "/tiamat-NAS/songyiren/FYP/liucheng/sd-scripts/RecraftModel/6_Portrait/6_Portrait-step00025000.safetensors"
CLIP_L_PATH = "/tiamat-NAS/hailong/storage_backup/models/stabilityai/stable-diffusion-3-medium/text_encoders/clip_l.safetensors"
T5XXL_PATH = "/tiamat-NAS/hailong/storage_backup/models/stabilityai/stable-diffusion-3-medium/text_encoders/t5xxl_fp16.safetensors"
AE_PATH = "/tiamat-vePFS/share_data/storage/huggingface/models/black-forest-labs/FLUX.1-dev/ae.safetensors"


def load_target_model():
    """Load the FLUX transformer, both text encoders, and the autoencoder on CPU."""
    logger.info("Loading models...")
    try:
        _, model = flux_utils.load_flow_model(
            BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
        )
        clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
        clip_l.eval()
        t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
        t5xxl.eval()
        ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
        logger.info("Models loaded successfully.")
        return model, [clip_l, t5xxl], ae
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise


class ResizeWithPadding:
    """Resize to a square of `size`, padding non-square images with a solid fill color."""

    def __init__(self, size, fill=255):
        self.size = size
        self.fill = fill

    def __call__(self, img):
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        elif not isinstance(img, Image.Image):
            raise TypeError("Input must be a PIL Image or a NumPy array")

        width, height = img.size
        if width == height:
            img = img.resize((self.size, self.size), Image.LANCZOS)
        else:
            # Center the image on a square canvas, then resize
            max_dim = max(width, height)
            new_img = Image.new("RGB", (max_dim, max_dim), (self.fill, self.fill, self.fill))
            new_img.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
            img = new_img.resize((self.size, self.size), Image.LANCZOS)
        return img


def infer(prompt, sample_image, frame_num, seed=0, randomize_seed=False):
    """Generate an image from a prompt and a conditional image."""
    logger.info(f"Started generating image with prompt: {prompt}")

    # Load models
    model, [clip_l, t5xxl], ae = load_target_model()
    model.eval()
    clip_l.eval()
    t5xxl.eval()
    ae.eval()

    # Apply the LoRA weights to the text encoders and the transformer
    multiplier = 1.0
    weights_sd = load_file(LORA_WEIGHTS_PATH)
    lora_model, _ = lora_flux.create_network_from_weights(
        multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True
    )
    lora_model.apply_to([clip_l, t5xxl], model)
    info = lora_model.load_state_dict(weights_sd, strict=True)
    logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
    lora_model.eval()
    lora_model.to("cuda")

    # Process the seed and apply it so results are reproducible
    if randomize_seed:
        seed = random.randint(0, np.iinfo(np.int32).max)
    torch.manual_seed(seed)
    logger.debug(f"Using seed: {seed}")

    # Preprocess the conditional image (512 px for a 4-frame grid, 352 px for 9)
    resize_transform = ResizeWithPadding(size=512) if frame_num == 4 else ResizeWithPadding(size=352)
    img_transforms = transforms.Compose([
        resize_transform,
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    image = img_transforms(np.array(sample_image, dtype=np.uint8)).unsqueeze(0).to(
        device=device, dtype=torch.bfloat16
    )
    logger.debug("Conditional image preprocessed.")

    # Encode the conditional image to latents
    ae.to("cuda")
    latents = ae.encode(image)
    logger.debug("Image encoded to latents.")
    conditions = {prompt: latents.to("cpu")}
    ae.to("cpu")

    # Encode the prompt
    clip_l.to("cuda")
    t5xxl.to("cuda")
    tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
    text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(True)
    tokens_and_masks = tokenize_strategy.tokenize(prompt)
    l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(
        tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, True
    )
    logger.debug("Prompt encoded.")

    # Prepare the noise and other sampling parameters
    width = 1024 if frame_num == 4 else 1056
    height = 1024 if frame_num == 4 else 1056
    height = max(64, height - height % 16)
    width = max(64, width - width % 16)
    packed_latent_height = height // 16
    packed_latent_width = width // 16
    noise = torch.randn(
        1, packed_latent_height * packed_latent_width, 16 * 2 * 2,
        device=device, dtype=torch.float16,
    )
    logger.debug("Noise prepared.")

    timesteps = flux_train_utils.get_schedule(20, noise.shape[1], shift=True)  # 20 sample steps
    img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(device)
    t5_attn_mask = t5_attn_mask.to(device)
    ae_outputs = conditions[prompt]
    logger.debug("Image generation parameters set.")

    # Minimal namespace standing in for parsed CLI args; denoise() reads frame_num
    args = SimpleNamespace(frame_num=frame_num)

    # Free the text encoders before moving the transformer onto the GPU
    clip_l.to("cpu")
    t5xxl.to("cpu")
    torch.cuda.empty_cache()
    model.to("cuda")

    # Run the denoising process
    with accelerator.autocast(), torch.no_grad():
        x = flux_train_utils.denoise(
            args, model, noise, img_ids, t5_out, txt_ids, l_pooled,
            timesteps=timesteps, guidance=1.0, t5_attn_mask=t5_attn_mask,
            ae_outputs=ae_outputs,
        )
    logger.debug("Denoising process completed.")

    # Decode the final image
    x = x.float()
    x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
    model.to("cpu")
    ae.to("cuda")
    with accelerator.autocast(), torch.no_grad():
        x = ae.decode(x)
    logger.debug("Latents decoded into image.")
    ae.to("cpu")

    # Convert the tensor to a PIL image
    x = x.clamp(-1, 1)
    x = x.permute(0, 2, 3, 1)
    generated_image = Image.fromarray(
        (127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]
    )
    logger.info("Image generation completed.")
    return generated_image


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## FLUX Image Generation")
    with gr.Row():
        # Input for the prompt
        prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=1)
        # File upload for the conditional image
        sample_image = gr.Image(label="Upload a Conditional Image", type="pil")
        # Frame number selection
        frame_num = gr.Radio([4, 9], label="Select Frame Number", value=4)
        # Seed and randomize-seed options
        seed = gr.Slider(0, np.iinfo(np.int32).max, step=1, label="Seed", value=0)
        randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
        # Run button
        run_button = gr.Button("Generate Image")
        # Output result
        result_image = gr.Image(label="Generated Image")

    run_button.click(
        fn=infer,
        inputs=[prompt, sample_image, frame_num, seed, randomize_seed],
        outputs=[result_image],
    )

# Launch the Gradio app
demo.launch(server_port=8289, server_name="0.0.0.0", share=True)

# Local test without the UI:
# prompt = "1girl"
# sample_image = Image.open("/tiamat-NAS/songyiren/FYP/liucheng/sd-scripts/MergeModel/test/1.png")  # use a test image
# frame_num = 9
# seed = 42
# randomize_seed = False
# result = infer(prompt, sample_image, frame_num, seed, randomize_seed)
# result.save('asy_results/generated_image.png')
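
# A minimal sketch of calling the running app programmatically via gradio_client.
# Assumptions not taken from the script above: the server is reachable at
# 127.0.0.1:8289, Gradio auto-named the endpoint "/infer" after the handler
# (confirm with client.view_api()), and the image path is hypothetical. On older
# gradio_client versions, pass a plain file path instead of handle_file().
#
# from gradio_client import Client, handle_file
#
# client = Client("http://127.0.0.1:8289/")
# result = client.predict(
#     "1girl",                          # prompt
#     handle_file("test/sample.png"),   # conditional image (hypothetical path)
#     4,                                # frame_num
#     0,                                # seed
#     True,                             # randomize_seed
#     api_name="/infer",
# )
# print(result)  # filepath of the generated image returned by the server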