# Minimum Inference Code for FLUX

import argparse
import datetime
import math
import os
import random
from typing import Callable, List, Optional

import einops
import numpy as np
import torch
from tqdm import tqdm
from PIL import Image
import accelerate
from transformers import CLIPTextModel
from safetensors.torch import load_file

from library import device_utils
from library.device_utils import init_ipex, get_preferred_device
from networks import oft_flux

init_ipex()

from library.utils import setup_logging, str_to_dtype

setup_logging()
import logging

logger = logging.getLogger(__name__)

import networks.asylora_flux as lora_flux
from library import flux_models, flux_utils, sd3_utils, strategy_flux


def time_shift(mu: float, sigma: float, t: torch.Tensor):
    """Apply the shift ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`` to ``t``.

    Evaluated element-wise when ``t`` is a tensor. For a float tensor and
    ``sigma > 0``, ``t == 1`` maps to 1 and ``t == 0`` maps to 0 (via ``1 / 0 == inf``).
    """
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    """Return the linear function passing through ``(x1, y1)`` and ``(x2, y2)``."""
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Return ``num_steps + 1`` timesteps going from 1 down to 0.

    When ``shift`` is true, every value is passed through :func:`time_shift` with
    ``sigma = 1.0`` and ``mu`` linearly interpolated from ``image_seq_len`` between
    ``(256, base_shift)`` and ``(4096, max_shift)``.
    """
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu based on linear estimation between two points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()


def denoise(
    model: flux_models.Flux,
    img: torch.Tensor,
    img_ids: torch.Tensor,
    txt: torch.Tensor,
    txt_ids: torch.Tensor,
    vec: torch.Tensor,
    timesteps: list[float],
    guidance: float = 4.0,
    t5_attn_mask: Optional[torch.Tensor] = None,
    neg_txt: Optional[torch.Tensor] = None,
    neg_vec: Optional[torch.Tensor] = None,
    neg_t5_attn_mask: Optional[torch.Tensor] = None,
    cfg_scale: Optional[float] = None,
):
    """Integrate ``img`` along ``timesteps`` with Euler steps of the flow ``model``.

    Each step runs ``model`` and updates ``img = img + (t_prev - t_curr) * pred``.

    Classifier-free guidance is enabled when both ``neg_txt`` and ``neg_vec`` are
    given. In that case the negative and positive conditionings are concatenated
    along dim 0 (negative first), the model runs once on the doubled batch, and the
    two halves are combined as ``uncond + cfg_scale * (cond - uncond)``.
    ``cfg_scale`` must then be a number. The attention masks are concatenated
    only when both ``t5_attn_mask`` and ``neg_t5_attn_mask`` are given.

    Returns:
        The latent after the last step.
    """
    # this is ignored for schnell
    logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")

    use_cfg = neg_txt is not None and neg_vec is not None

    # prepare classifier free guidance
    if use_cfg:
        b_img_ids = torch.cat([img_ids, img_ids], dim=0)
        b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
        b_txt = torch.cat([neg_txt, txt], dim=0)
        b_vec = torch.cat([neg_vec, vec], dim=0)
        if t5_attn_mask is not None and neg_t5_attn_mask is not None:
            b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
        else:
            b_t5_attn_mask = None
    else:
        b_img_ids = img_ids
        b_txt_ids = txt_ids
        b_txt = txt
        b_vec = vec
        b_t5_attn_mask = t5_attn_mask

    # Per-sample vectors are sized like the batch actually fed to the model,
    # which is twice img's batch when CFG concatenates [img, img] below.
    model_batch_size = img.shape[0] * (2 if use_cfg else 1)
    guidance_vec = torch.full((model_batch_size,), guidance, device=img.device, dtype=img.dtype)

    for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
        t_vec = torch.full((model_batch_size,), t_curr, dtype=img.dtype, device=img.device)

        # classifier free guidance: negative and positive share one forward pass
        b_img = torch.cat([img, img], dim=0) if use_cfg else img

        pred = model(
            img=b_img,
            img_ids=b_img_ids,
            txt=b_txt,
            txt_ids=b_txt_ids,
            y=b_vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            txt_attention_mask=b_t5_attn_mask,
        )

        if use_cfg:
            pred_uncond, pred = torch.chunk(pred, 2, dim=0)
            pred = pred_uncond + cfg_scale * (pred - pred_uncond)

        img = img + (t_prev - t_curr) * pred

    return img


def do_sample(
    accelerator: Optional[accelerate.Accelerator],
    model: flux_models.Flux,
    img: torch.Tensor,
    img_ids: torch.Tensor,
    l_pooled: torch.Tensor,
    t5_out: torch.Tensor,
    txt_ids: torch.Tensor,
    num_steps: int,
    guidance: float,
    t5_attn_mask: Optional[torch.Tensor],
    is_schnell: bool,
    device: torch.device,
    flux_dtype: torch.dtype,
    neg_l_pooled: Optional[torch.Tensor] = None,
    neg_t5_out: Optional[torch.Tensor] = None,
    neg_t5_attn_mask: Optional[torch.Tensor] = None,
    cfg_scale: Optional[float] = None,
):
    """Build the timestep schedule and run :func:`denoise` without gradients.

    The schedule shift is disabled for schnell models. Autocast comes from
    ``accelerator.autocast()`` when an accelerator is given, otherwise from
    ``torch.autocast`` on ``device.type`` with ``flux_dtype``.
    """
    logger.info(f"num_steps: {num_steps}")
    timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)

    if accelerator is not None:
        autocast_ctx = accelerator.autocast()
    else:
        autocast_ctx = torch.autocast(device_type=device.type, dtype=flux_dtype)

    # denoise initial noise
    with autocast_ctx, torch.no_grad():
        x = denoise(
            model,
            img,
            img_ids,
            t5_out,
            txt_ids,
            l_pooled,
            timesteps,
            guidance,
            t5_attn_mask=t5_attn_mask,
            neg_txt=neg_t5_out,
            neg_vec=neg_l_pooled,
            neg_t5_attn_mask=neg_t5_attn_mask,
            cfg_scale=cfg_scale,
        )

    return x


def generate_image(
    model,
    clip_l: CLIPTextModel,
    t5xxl,
    ae,
    prompt: str,
    seed: Optional[int],
    image_width: int,
    image_height: int,
    steps: Optional[int],
    guidance: float,
    negative_prompt: Optional[str],
    cfg_scale: float,
):
    """Generate one image for ``prompt`` and save it as a PNG in ``args.output_dir``.

    Pipeline:
    1. Sample packed noise from ``seed`` (random when ``None``).
    2. Prepare fp8 text encoders once, if requested.
    3. Encode the prompt, and the negative prompt when one is given.
    4. Run :func:`do_sample` (``steps`` defaults to 4 for schnell, 50 otherwise).
    5. Unpack the latent, decode it with ``ae`` and write the file.

    A ``negative_prompt`` of ``None`` disables classifier-free guidance. Any
    string, including ``""``, enables it.

    NOTE: besides its parameters this function reads names that are only defined
    under ``if __name__ == "__main__"``: ``args``, ``device``, ``dtype``,
    ``accelerator``, ``is_fp8``, ``is_schnell``, ``clip_l_dtype``,
    ``t5xxl_dtype``, ``ae_dtype``, ``flux_dtype``, ``tokenize_strategy`` and
    ``encoding_strategy``.

    Raises:
        ValueError: if the encoded positive prompt contains NaN.
    """
    seed = seed if seed is not None else random.randint(0, 2**32 - 1)
    logger.info(f"Seed: {seed}")

    # make first noise with packed shape
    # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
    packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
    # noise is never drawn in an fp8 dtype; fall back to float32 in that case
    noise_dtype = torch.float32 if is_fp8(dtype) else dtype
    noise = torch.randn(
        1,
        packed_latent_height * packed_latent_width,
        16 * 2 * 2,
        device=device,
        dtype=noise_dtype,
        generator=torch.Generator(device=device).manual_seed(seed),
    )

    # prepare img and img ids
    # this is needed only for img2img
    # img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    # if img.shape[0] == 1 and bs > 1:
    #     img = repeat(img, "1 ... -> bs ...", bs=bs)

    # txt2img only needs img_ids
    img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)

    # prepare fp8 models; the fp8_prepared attribute makes this a one-time step per model
    if is_fp8(clip_l_dtype) and not getattr(clip_l, "fp8_prepared", False):
        logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
        clip_l.to(clip_l_dtype)  # fp8
        clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
        clip_l.fp8_prepared = True

    if is_fp8(t5xxl_dtype) and not getattr(t5xxl, "fp8_prepared", False):
        logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")

        def prepare_fp8(text_encoder, target_dtype):
            """Cast T5LayerNorm/Embedding modules to ``target_dtype`` and replace the
            forward of each T5DenseGatedActDense module with a local implementation."""

            def forward_hook(module):
                # gated feed-forward: act(wi_0(h)) * wi_1(h) -> dropout -> wo
                # NOTE(review): presumably replaces the stock transformers forward to avoid
                # casting activations to wo's fp8 weight dtype -- confirm against transformers.
                def forward(hidden_states):
                    hidden_gelu = module.act(module.wi_0(hidden_states))
                    hidden_linear = module.wi_1(hidden_states)
                    hidden_states = hidden_gelu * hidden_linear
                    hidden_states = module.dropout(hidden_states)
                    hidden_states = module.wo(hidden_states)
                    return hidden_states

                return forward

            for module in text_encoder.modules():
                if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
                    module.to(target_dtype)
                if module.__class__.__name__ in ["T5DenseGatedActDense"]:
                    module.forward = forward_hook(module)

        t5xxl.to(t5xxl_dtype)
        prepare_fp8(t5xxl.encoder, torch.bfloat16)
        t5xxl.fp8_prepared = True

    # prepare embeddings
    logger.info("Encoding prompts...")
    clip_l = clip_l.to(device)
    t5xxl = t5xxl.to(device)

    def autocast_for(model_dtype):
        """Autocast context for a model: the accelerator's for fp8 weights, torch's otherwise."""
        if is_fp8(model_dtype):
            return accelerator.autocast()
        return torch.autocast(device_type=device.type, dtype=model_dtype)

    def encode(prpt: str):
        """Return ``(l_pooled, t5_out, txt_ids, t5_attn_mask)`` for ``prpt``."""
        tokens_and_masks = tokenize_strategy.tokenize(prpt)
        with torch.no_grad():
            # CLIP-L only (T5 slot is None)
            with autocast_for(clip_l_dtype):
                l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)

            # T5 only: CLIP-L was encoded above, so passing it again would only run a
            # second CLIP forward whose output is discarded.
            with autocast_for(t5xxl_dtype):
                _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
                    tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
                )
        return l_pooled, t5_out, txt_ids, t5_attn_mask

    l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt)
    # An empty string is a real (empty) negative prompt -- the interactive "-" option
    # produces one -- so only None disables classifier-free guidance.
    if negative_prompt is not None:
        neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt)
    else:
        neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None

    # NaN check
    if torch.isnan(l_pooled).any():
        raise ValueError("NaN in l_pooled")
    if torch.isnan(t5_out).any():
        raise ValueError("NaN in t5_out")

    if args.offload:
        clip_l = clip_l.cpu()
        t5xxl = t5xxl.cpu()
    # del clip_l, t5xxl
    device_utils.clean_memory()

    # generate image
    logger.info("Generating image...")
    model = model.to(device)
    if steps is None:
        steps = 4 if is_schnell else 50

    img_ids = img_ids.to(device)
    # Masks are only used with --apply_t5_attn_mask. Treat both masks the same way:
    # denoise concatenates the negative mask with the positive one, so they must
    # both be on the same device (or both be dropped).
    if args.apply_t5_attn_mask:
        t5_attn_mask = t5_attn_mask.to(device)
        if neg_t5_attn_mask is not None:
            neg_t5_attn_mask = neg_t5_attn_mask.to(device)
    else:
        t5_attn_mask = None
        neg_t5_attn_mask = None

    x = do_sample(
        accelerator,
        model,
        noise,
        img_ids,
        l_pooled,
        t5_out,
        txt_ids,
        steps,
        guidance,
        t5_attn_mask,
        is_schnell,
        device,
        flux_dtype,
        neg_l_pooled,
        neg_t5_out,
        neg_t5_attn_mask,
        cfg_scale,
    )
    if args.offload:
        model = model.cpu()
    # del model
    device_utils.clean_memory()

    # unpack
    x = x.float()
    x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)

    # decode
    logger.info("Decoding image...")
    ae = ae.to(device)
    with torch.no_grad(), autocast_for(ae_dtype):
        x = ae.decode(x)
    if args.offload:
        ae = ae.cpu()

    # map [-1, 1] to [0, 255] (HWC) and keep the first image of the batch
    x = x.clamp(-1, 1)
    x = x.permute(0, 2, 3, 1)
    img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])

    # save image; the timestamp has one-second resolution, so add a numeric suffix
    # instead of overwriting an image saved earlier within the same second
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    stem = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = os.path.join(output_dir, f"{stem}.png")
    suffix = 1
    while os.path.exists(output_path):
        output_path = os.path.join(output_dir, f"{stem}_{suffix}.png")
        suffix += 1
    img.save(output_path)

    logger.info(f"Saved image to {output_path}")


if __name__ == "__main__":
    target_height = 768  # 1024
    target_width = 1360  # 1024

    device = get_preferred_device()

    parser = argparse.ArgumentParser()
    parser.add_argument("--lora_ups_num", type=int, required=True)
    parser.add_argument("--lora_up_cur", type=int, required=True)
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--clip_l", type=str, required=False)
    parser.add_argument("--t5xxl", type=str, required=False)
    parser.add_argument("--ae", type=str, required=False)
    parser.add_argument("--apply_t5_attn_mask", action="store_true")
    parser.add_argument("--prompt", type=str, default="A photo of a cat")
    parser.add_argument("--output_dir", type=str, default=".")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
    parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l")
    parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae")
    parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl")
    parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
    parser.add_argument("--guidance", type=float, default=3.5)
    parser.add_argument("--negative_prompt", type=str, default=None)
    parser.add_argument("--cfg_scale", type=float, default=1.0)
    parser.add_argument("--offload", action="store_true", help="Offload to CPU")
    parser.add_argument(
        "--lora_weights",
        type=str,
        nargs="*",
        default=[],
        help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)",
    )
    parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
    parser.add_argument("--width", type=int, default=target_width)
    parser.add_argument("--height", type=int, default=target_height)
    parser.add_argument("--interactive", action="store_true")
    args = parser.parse_args()

    seed = args.seed
    steps = args.steps
    guidance_scale = args.guidance
    lora_ups_num = args.lora_ups_num
    lora_up_cur = args.lora_up_cur

    def is_fp8(dt):
        """Return True if ``dt`` is one of torch's float8 dtypes."""
        return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]

    # per-model dtypes default to the base --dtype
    dtype = str_to_dtype(args.dtype)
    clip_l_dtype = str_to_dtype(args.clip_l_dtype, dtype)
    t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype)
    ae_dtype = str_to_dtype(args.ae_dtype, dtype)
    flux_dtype = str_to_dtype(args.flux_dtype, dtype)

    logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}")

    loading_device = "cpu" if args.offload else device

    # an Accelerator (bf16 mixed precision) is only created when some model uses fp8
    use_fp8 = [is_fp8(d) for d in [dtype, clip_l_dtype, t5xxl_dtype, ae_dtype, flux_dtype]]
    if any(use_fp8):
        accelerator = accelerate.Accelerator(mixed_precision="bf16")
    else:
        accelerator = None

    # load clip_l
    logger.info(f"Loading clip_l from {args.clip_l}...")
    clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
    clip_l.eval()

    logger.info(f"Loading t5xxl from {args.t5xxl}...")
    t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
    t5xxl.eval()

    # if is_fp8(clip_l_dtype):
    #     clip_l = accelerator.prepare(clip_l)
    # if is_fp8(t5xxl_dtype):
    #     t5xxl = accelerator.prepare(t5xxl)

    # DiT
    is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
    model.eval()
    logger.info(f"Casting model to {flux_dtype}")
    model.to(flux_dtype)  # make sure model is dtype
    # if is_fp8(flux_dtype):
    #     model = accelerator.prepare(model)
    # if args.offload:
    #     model = model.to("cpu")

    t5xxl_max_length = 256 if is_schnell else 512
    tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
    encoding_strategy = strategy_flux.FluxTextEncodingStrategy()

    # AE
    ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device)
    ae.eval()
    # if is_fp8(ae_dtype):
    #     ae = accelerator.prepare(ae)

    # LoRA
    lora_models: List[lora_flux.LoRANetwork] = []
    for weights_file in args.lora_weights:
        # "path;multiplier": split on the last ";" so the path itself may contain one
        if ";" in weights_file:
            weights_file, multiplier = weights_file.rsplit(";", 1)
            multiplier = float(multiplier)
        else:
            multiplier = 1.0

        weights_sd = load_file(weights_file)

        # network type is decided by the first key with a "lora" or "oft" prefix
        is_lora = is_oft = False
        for key in weights_sd.keys():
            if key.startswith("lora"):
                is_lora = True
            if key.startswith("oft"):
                is_oft = True
            if is_lora or is_oft:
                break
        if not (is_lora or is_oft):
            logger.warning(f"No 'lora' or 'oft' keys found in {weights_file}; loading it as OFT")

        module = lora_flux if is_lora else oft_flux
        # NOTE(review): lora_ups_num and set_lora_up_cur come from networks.asylora_flux;
        # confirm that oft_flux accepts the extra argument and exposes set_lora_up_cur.
        lora_model, _ = module.create_network_from_weights(
            multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True, lora_ups_num
        )
        # --lora_up_cur is given 1-based; set_lora_up_cur receives it minus one
        for sub_lora in lora_model.unet_loras:
            sub_lora.set_lora_up_cur(lora_up_cur - 1)

        if args.merge_lora_weights:
            lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
        else:
            lora_model.apply_to([clip_l, t5xxl], model)
            info = lora_model.load_state_dict(weights_sd, strict=True)
            logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
            lora_model.eval()
            lora_model.to(device)

        lora_models.append(lora_model)

    if not args.interactive:
        generate_image(
            model,
            clip_l,
            t5xxl,
            ae,
            args.prompt,
            args.seed,
            args.width,
            args.height,
            args.steps,
            args.guidance,
            args.negative_prompt,
            args.cfg_scale,
        )
    else:
        # loop for interactive; width/height/steps/guidance/cfg_scale persist across
        # prompts, seed and negative prompt are reset for every prompt
        width = target_width
        height = target_height
        steps = None
        guidance = args.guidance
        cfg_scale = args.cfg_scale

        while True:
            print(
                "Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
                " --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
            )
            prompt = input()
            if prompt == "":
                break

            # parse options
            options = prompt.split("--")
            prompt = options[0].strip()
            seed = None
            negative_prompt = None
            for opt in options[1:]:
                try:
                    opt = opt.strip()
                    if opt.startswith("w"):
                        width = int(opt[1:].strip())
                    elif opt.startswith("h"):
                        height = int(opt[1:].strip())
                    elif opt.startswith("s"):
                        steps = int(opt[1:].strip())
                    elif opt.startswith("d"):
                        seed = int(opt[1:].strip())
                    elif opt.startswith("g"):
                        guidance = float(opt[1:].strip())
                    elif opt.startswith("m"):
                        multipliers = opt[1:].strip().split(",")
                        if len(multipliers) != len(lora_models):
                            logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
                            continue
                        for lora_model, multiplier in zip(lora_models, multipliers):
                            lora_model.set_multiplier(float(multiplier))
                    elif opt.startswith("n"):
                        negative_prompt = opt[1:].strip()
                        if negative_prompt == "-":
                            negative_prompt = ""
                    elif opt.startswith("c"):
                        cfg_scale = float(opt[1:].strip())
                except ValueError as e:
                    logger.error(f"Invalid option: {opt}, {e}")

            generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale)

    logger.info("Done!")