Integrate Flux-dev on Compose

#523
by baby20cen - opened

I have a 3060 card with 12 GB of VRAM, and I want to deploy a Docker Compose project that runs FLUX.1-dev locally on my PC.
But it doesn't seem to work the normal way described in the guides, because the transformer is too large: it can't be loaded directly, so I need to use quanto or something like that.

I hope someone can help me quantize the transformer with quanto; it seems I'm going about it the wrong way.
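
For reference, the minimal pattern I was trying to adapt (pieced together from the quanto and diffusers examples; the bfloat16 dtype, quantizing only the transformer and the T5 encoder, and using enable_model_cpu_offload() instead of .to(DEVICE) are my assumptions about what should fit in 12 GB) looks roughly like this:

import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel
from optimum.quanto import freeze, qint8, quantize

BASE = "/app/models/FLUX.1-dev"
CONTROLNET = "/app/models/Flux.1-dev-Controlnet-Upscaler"

# Load the two largest components on CPU in bf16, quantize their weights to int8, then freeze
transformer = FluxTransformer2DModel.from_pretrained(
    f"{BASE}/transformer", torch_dtype=torch.bfloat16, local_files_only=True
)
quantize(transformer, weights=qint8)
freeze(transformer)

text_encoder_2 = T5EncoderModel.from_pretrained(
    f"{BASE}/text_encoder_2", torch_dtype=torch.bfloat16, local_files_only=True
)
quantize(text_encoder_2, weights=qint8)
freeze(text_encoder_2)

controlnet = FluxControlNetModel.from_pretrained(
    CONTROLNET, torch_dtype=torch.bfloat16, local_files_only=True
)

# Let the pipeline assemble the remaining components and manage device placement itself
pipe = FluxControlNetPipeline.from_pretrained(
    BASE,
    controlnet=controlnet,
    transformer=transformer,
    text_encoder_2=text_encoder_2,
    torch_dtype=torch.bfloat16,
    local_files_only=True,
)
pipe.enable_model_cpu_offload()  # moves each component to the GPU only while it runs

My full loading code from the Compose service follows; it tries to do the same thing but builds the pipeline by hand.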

import os
import json
import logging
import torch
from PIL import Image
import io
from diffusers import FluxControlNetModel, FluxControlNetPipeline, FluxTransformer2DModel, AutoencoderKL
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from torch.cuda.amp import autocast
import psutil
from accelerate import Accelerator
from optimum import quanto
from optimum.quanto import quantization_map
from safetensors.torch import save_file

# Environment setup

os.environ["STATRELOAD_DISABLE"] = "1"
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

# Logging setup

logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
logger = logging.getLogger(__name__)

# Device config

try:
    accelerator = Accelerator(mixed_precision="fp16" if torch.cuda.is_available() else "no")
    DEVICE = accelerator.device
    DTYPE = torch.float16 if torch.cuda.is_available() else torch.bfloat16  # Use bfloat16 for CPU
    logger.info(f"Using device: {DEVICE} with dtype: {DTYPE}")
    if torch.cuda.is_available():
        logger.debug(f"CUDA version: {torch.version.cuda}, Device count: {torch.cuda.device_count()}")
        logger.debug(f"GPU: {torch.cuda.get_device_name(0)}")
except Exception as e:
    logger.warning(f"GPU detection failed: {str(e)}. Falling back to CPU.")
    DEVICE = torch.device("cpu")
    DTYPE = torch.bfloat16

# Model paths

CONTROLNET_PATH = "/app/models/Flux.1-dev-Controlnet-Upscaler"
BASE_MODEL_PATH = "/app/models/FLUX.1-dev"

def log_memory():
    if torch.cuda.is_available():
        try:
            allocated = torch.cuda.memory_allocated(DEVICE) / 1024 ** 3
            reserved = torch.cuda.memory_reserved(DEVICE) / 1024 ** 3
            total = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
            logger.debug(f"GPU Memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB, Total: {total:.2f} GB")
            if reserved > 0.9 * total:
                logger.warning("High GPU memory usage detected. Consider reducing batch size.")
        except Exception as e:
            logger.warning(f"Failed to log GPU memory: {str(e)}")
    try:
        mem = psutil.virtual_memory()
        total = mem.total / 1024 ** 3
        used = mem.used / 1024 ** 3
        logger.debug(f"System RAM - Total: {total:.2f} GB, Used: {used:.2f} GB")
    except Exception as e:
        logger.warning(f"Failed to log system RAM: {str(e)}")

def validate_model_directory(model_path, required_files=None):
    logger.debug(f"Validating model directory: {model_path}")
    if not os.path.exists(model_path):
        logger.error(f"Model directory does not exist: {model_path}")
        return False
    if required_files:
        missing_files = [f for f in required_files if not os.path.isfile(os.path.join(model_path, f))]
        if missing_files:
            logger.error(f"Missing required files: {', '.join(missing_files)}")
            return False
    if model_path == BASE_MODEL_PATH:
        required_subfolders = ["transformer", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "vae", "scheduler"]
        for subfolder in required_subfolders:
            subfolder_path = os.path.join(model_path, subfolder)
            if not os.path.isdir(subfolder_path):
                logger.error(f"Missing required subfolder: {subfolder_path}")
                return False
            try:
                subfolder_contents = os.listdir(subfolder_path)
                if not subfolder_contents:
                    logger.error(f"Subfolder {subfolder_path} is empty")
                    return False
            except OSError as e:
                logger.error(f"Error accessing subfolder {subfolder_path}: {str(e)}")
                return False
    return True

class LazyLoadPipeline:
    def __init__(self):
        self.pipe = None

def load(self):
    if self.pipe is None:
        logger.info("πŸ”§ Starting FluxControlNet pipeline loading...")
        log_memory()

        if not validate_model_directory(BASE_MODEL_PATH, ["model_index.json"]):
            raise FileNotFoundError(f"Base model directory {BASE_MODEL_PATH} is missing required files or subfolders.")
        if not validate_model_directory(CONTROLNET_PATH, ["config.json"]):
            raise FileNotFoundError(f"ControlNet model directory {CONTROLNET_PATH} is missing required files.")

        try:
            torch.cuda.empty_cache()
            logger.debug("Cleared CUDA cache")
            log_memory()

            # Load models directly onto the target device
            logger.debug("Loading scheduler...")
            scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
                os.path.join(BASE_MODEL_PATH, "scheduler"),
                local_files_only=True
            )
            logger.debug("βœ… Scheduler loaded")

            with torch.no_grad():
                logger.debug("Loading ControlNet...")
                controlnet = FluxControlNetModel.from_pretrained(
                    CONTROLNET_PATH,
                    torch_dtype=DTYPE,
                    local_files_only=True
                ).to(DEVICE)
                logger.debug("Quantizing ControlNet with quanto...")
                quanto.quantize(controlnet, weights=quanto.qint8)
                quanto.freeze(controlnet)
                controlnet = accelerator.prepare(controlnet)
                logger.debug("βœ… ControlNet loaded and quantized")
                log_memory()

                logger.debug("Loading Transformer...")
                transformer = FluxTransformer2DModel.from_pretrained(
                    os.path.join(BASE_MODEL_PATH, "transformer"),
                    torch_dtype=DTYPE,
                    local_files_only=True
                ).to(DEVICE)
                logger.debug("Quantizing Transformer with quanto...")
                quanto.quantize(transformer, weights=quanto.qint8)
                quanto.freeze(transformer)
                transformer = accelerator.prepare(transformer)
                logger.debug("βœ… Transformer loaded and quantized")
                log_memory()

                logger.debug("Loading VAE...")
                vae = AutoencoderKL.from_pretrained(
                    os.path.join(BASE_MODEL_PATH, "vae"),
                    torch_dtype=DTYPE,
                    local_files_only=True
                ).to(DEVICE)
                logger.debug("Quantizing VAE with quanto...")
                quanto.quantize(vae, weights=quanto.qint8)
                quanto.freeze(vae)
                vae = accelerator.prepare(vae)
                logger.debug("βœ… VAE loaded and quantized")
                log_memory()

                logger.debug("Loading CLIP Text Encoder...")
                text_encoder = CLIPTextModel.from_pretrained(
                    os.path.join(BASE_MODEL_PATH, "text_encoder"),
                    torch_dtype=DTYPE,
                    local_files_only=True
                ).to(DEVICE)
                logger.debug("Quantizing CLIP Text Encoder with quanto...")
                quanto.quantize(text_encoder, weights=quanto.qint8)
                quanto.freeze(text_encoder)
                text_encoder = accelerator.prepare(text_encoder)
                logger.debug("βœ… CLIP Text Encoder loaded and quantized")
                log_memory()

                logger.debug("Loading T5 Text Encoder...")
                text_encoder_2 = T5EncoderModel.from_pretrained(
                    os.path.join(BASE_MODEL_PATH, "text_encoder_2"),
                    torch_dtype=DTYPE,
                    local_files_only=True
                ).to(DEVICE)
                logger.debug("Quantizing T5 Text Encoder with quanto...")
                quanto.quantize(text_encoder_2, weights=quanto.qint8)
                quanto.freeze(text_encoder_2)
                text_encoder_2 = accelerator.prepare(text_encoder_2)
                logger.debug("βœ… T5 Text Encoder loaded and quantized")
                log_memory()

            logger.debug("Loading CLIP Tokenizer...")
            tokenizer = CLIPTokenizer.from_pretrained(
                os.path.join(BASE_MODEL_PATH, "tokenizer"),
                local_files_only=True,
                add_prefix_space=True
            )
            logger.debug("βœ… CLIP Tokenizer loaded")

            logger.debug("Loading T5 Tokenizer...")
            tokenizer_2 = T5TokenizerFast.from_pretrained(
                os.path.join(BASE_MODEL_PATH, "tokenizer_2"),
                local_files_only=True,
                use_fast=True,
                add_prefix_space=True
            )
            logger.debug("βœ… T5 Tokenizer loaded")

            logger.debug("Initializing FluxControlNetPipeline...")
            with autocast(enabled=torch.cuda.is_available()):
                self.pipe = FluxControlNetPipeline(
                    scheduler=scheduler,
                    text_encoder=text_encoder,
                    text_encoder_2=text_encoder_2,
                    tokenizer=tokenizer,
                    tokenizer_2=tokenizer_2,
                    transformer=transformer,
                    vae=vae,
                    controlnet=controlnet
                )
                self.pipe = accelerator.prepare(self.pipe)
                logger.debug("βœ… Pipeline initialized and prepared")

            try:
                self.pipe.enable_xformers_memory_efficient_attention()
                logger.debug("βœ… Enabled xformers memory-efficient attention")
            except ImportError:
                logger.warning("⚠️ xformers not available, enabling attention slicing")
                self.pipe.enable_attention_slicing()

            if torch.cuda.is_available():
                try:
                    self.pipe.transformer = torch.compile(self.pipe.transformer)
                    logger.debug("βœ… Transformer compiled with torch.compile")
                except Exception as e:
                    logger.warning(f"Failed to compile transformer: {str(e)}")

            # Save quantized models and quantization maps
            logger.debug("Saving quantized models...")
            save_file(
                self.pipe.transformer.state_dict(),
                os.path.join(BASE_MODEL_PATH, "transformer", "diffusion_pytorch_model.safetensors")
            )
            with open(os.path.join(BASE_MODEL_PATH, "transformer", "quantization_map.json"), "w") as f:
                json.dump(quantization_map(self.pipe.transformer), f)

            self.pipe.text_encoder_2.save_pretrained(os.path.join(BASE_MODEL_PATH, "text_encoder_2"))
            with open(os.path.join(BASE_MODEL_PATH, "text_encoder_2", "quantization_map.json"), "w") as f:
                json.dump(quantization_map(self.pipe.text_encoder_2), f)

            self.pipe.controlnet.save_pretrained(CONTROLNET_PATH)
            with open(os.path.join(CONTROLNET_PATH, "quantization_map.json"), "w") as f:
                json.dump(quantization_map(self.pipe.controlnet), f)

            self.pipe.vae.save_pretrained(os.path.join(BASE_MODEL_PATH, "vae"))
            with open(os.path.join(BASE_MODEL_PATH, "vae", "quantization_map.json"), "w") as f:
                json.dump(quantization_map(self.pipe.vae), f)

            self.pipe.text_encoder.save_pretrained(os.path.join(BASE_MODEL_PATH, "text_encoder"))
            with open(os.path.join(BASE_MODEL_PATH, "text_encoder", "quantization_map.json"), "w") as f:
                json.dump(quantization_map(self.pipe.text_encoder), f)

            log_memory()
            logger.info("βœ… Pipeline loaded and quantized successfully")
        except Exception as e:
            logger.error(f"❌ Failed to load pipeline: {str(e)}", exc_info=True)
            raise
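
For later container starts I was also planning to reload the saved quantized weights instead of quantizing from scratch every time, roughly like this (untested; the meta-device construction and the requantize() call are just what I understood from the quanto README, shown here for the transformer only):

import json
import torch
from diffusers import FluxTransformer2DModel
from optimum.quanto import requantize
from safetensors.torch import load_file

TRANSFORMER_DIR = "/app/models/FLUX.1-dev/transformer"

# Rebuild an empty transformer from its config without allocating real weights
config = FluxTransformer2DModel.load_config(TRANSFORMER_DIR)
with torch.device("meta"):
    transformer = FluxTransformer2DModel.from_config(config)

# Restore the int8 state dict and quantization map written by the save step above
state_dict = load_file(f"{TRANSFORMER_DIR}/diffusion_pytorch_model.safetensors")
with open(f"{TRANSFORMER_DIR}/quantization_map.json") as f:
    qmap = json.load(f)
requantize(transformer, state_dict, qmap, device=torch.device("cuda"))

One thing I'm not sure about: since the script compiles the transformer with torch.compile before saving, the saved keys might carry a _orig_mod. prefix that would have to be stripped before requantizing.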
