Repos

https://github.com/mit-han-lab/deepcompressor

Installation

https://github.com/mit-han-lab/deepcompressor/issues/56

https://github.com/nunchaku-tech/deepcompressor/issues/80
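
The two issues above track known installation problems. A possible install sequence (an untested sketch; it assumes an editable pip install works in your environment):

Example: git clone https://github.com/nunchaku-tech/deepcompressor.git && cd deepcompressor && pip install -e .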

Windows

https://learn.microsoft.com/en-us/windows/wsl/install

https://www.anaconda.com/docs/getting-started/miniconda/install
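
A possible setup path following the two guides above (not verified here): enable WSL from an elevated PowerShell with "wsl --install", install Miniconda inside the WSL distribution, then create a Python 3.12 environment for deepcompressor.

Example: conda create -n deepcompressor python=3.12 && conda activate deepcompressor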

Environment

Hardware:

Nvidia RTX 5060 Ti (Blackwell, sm_120)

Software (WSL):

Python 3.12.11

pip 25.1

CUDA 12.8

Torch 2.7.1+cu128

Diffusers 0.35.0.dev0

Transformers 4.53.2

flash_attn 2.7.4.post1

xformers 0.0.31.post1
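
A quick sanity check for the stack above (a minimal sketch; it only assumes the packages listed here are importable):

# Print the versions and GPU reported in this section.
# An RTX 5060 Ti (Blackwell, sm_120) should report compute capability (12, 0).
import torch
import diffusers
import transformers

print("torch:", torch.__version__, "| cuda:", torch.version.cuda)
print("diffusers:", diffusers.__version__, "| transformers:", transformers.__version__)
print("gpu:", torch.cuda.get_device_name(0), "| capability:", torch.cuda.get_device_capability(0))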

Calibration Dataset Preparation

https://github.com/nunchaku-tech/deepcompressor/blob/main/examples/diffusion/README.md#step-2-calibration-dataset-preparation

Example: python -m deepcompressor.app.diffusion.dataset.collect.calib svdq/flux.1-kontext-dev.yaml examples/diffusion/configs/collect/qdiff.yaml --pipeline-path svdq/flux.1-kontext-dev/

Sample Log

In total 32 samples
Evaluating with batch size 1
Data:   3%|β–ˆβ–ˆβ–Ž                                                                        | 1/32 [13:57<7:12:32, 837.19s/it]
Sampling:  12%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–                                                                 | 1/8 [01:34<11:01, 94.44s/it]

Quantization

https://github.com/nunchaku-tech/deepcompressor/blob/main/examples/diffusion/README.md#step-3-model-quantization

Model Path: https://github.com/nunchaku-tech/deepcompressor/issues/70#issuecomment-2788155233

Save model: --save-model true or --save-model /PATH/TO/CHECKPOINT/DIR

Example: python -m deepcompressor.app.diffusion.ptq svdq/flux.1-kontext-dev.yaml examples/diffusion/configs/svdquant/nvfp4.yaml --pipeline-path svdq/flux.1-kontext-dev/ --save-model ~/svdq/

Model Files Structure
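
The exact layout was not captured in these notes. Judging from the save logic in the ptq code under Blockers, the checkpoint directory passed to --save-model should contain roughly the following, depending on which quantization options are enabled:

  • smooth.pt: smoothing scales
  • branch.pt: low-rank branch settings
  • wgts.pt: weight quantizer settings
  • acts.pt: activation quantizer settings
  • scale.pt: quantization scales (written when saving the model)
  • model.pt: quantized model state dict (written when saving the model)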

Deploy

https://github.com/nunchaku-tech/deepcompressor/blob/main/examples/diffusion/README.md#deployment

Example: python -m deepcompressor.backend.nunchaku.convert --quant-path ~/svdq/ --output-root ~/svdq/ --model-name flux.1-kontext-dev-svdq-fp4

ComfyUI metadata reference:


Remarks

2025-07-23 Test Notes

  • The FP4-quantized model loads successfully in ComfyUI but is not fully functional yet; it needs further investigation and debugging.

  • The calibration dataset appears misaligned; the sampling code for Flux.1 Kontext Dev may need to be revisited and adjusted.

  • Later, consider running another test using the base black-forest-labs/FLUX.1-dev model for comparison.

  • Check with the deepcompressor/nunchaku team and request their latest working implementation.


Blockers

  1. NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.

Potential fix: app.diffusion.pipeline.config.py

    @staticmethod
    def _default_build(
        name: str, path: str, dtype: str | torch.dtype, device: str | torch.device, shift_activations: bool
    ) -> DiffusionPipeline:
        if not path:
            if name == "sdxl":
                path = "stabilityai/stable-diffusion-xl-base-1.0"
            elif name == "sdxl-turbo":
                path = "stabilityai/sdxl-turbo"
            elif name == "pixart-sigma":
                path = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"
            elif name == "flux.1-kontext-dev":
                path = "black-forest-labs/FLUX.1-Kontext-dev"
            elif name == "flux.1-dev":
                path = "black-forest-labs/FLUX.1-dev"
            elif name == "flux.1-canny-dev":
                path = "black-forest-labs/FLUX.1-Canny-dev"
            elif name == "flux.1-depth-dev":
                path = "black-forest-labs/FLUX.1-Depth-dev"
            elif name == "flux.1-fill-dev":
                path = "black-forest-labs/FLUX.1-Fill-dev"
            elif name == "flux.1-schnell":
                path = "black-forest-labs/FLUX.1-schnell"
            else:
                raise ValueError(f"Path for {name} is not specified.")
        if name in ["flux.1-kontext-dev"]:
            pipeline = FluxKontextPipeline.from_pretrained(path, torch_dtype=dtype)
        elif name in ["flux.1-canny-dev", "flux.1-depth-dev"]:
            pipeline = FluxControlPipeline.from_pretrained(path, torch_dtype=dtype)
        elif name == "flux.1-fill-dev":
            pipeline = FluxFillPipeline.from_pretrained(path, torch_dtype=dtype)
        elif name.startswith("sana-"):
            if dtype == torch.bfloat16:
                pipeline = SanaPipeline.from_pretrained(path, variant="bf16", torch_dtype=dtype, use_safetensors=True)
                pipeline.vae.to(dtype)
                pipeline.text_encoder.to(dtype)
            else:
                pipeline = SanaPipeline.from_pretrained(path, torch_dtype=dtype)
        else:
            pipeline = AutoPipelineForText2Image.from_pretrained(path, torch_dtype=dtype)

        # Debug output
        print(">>> DEVICE:", device)
        print(">>> PIPELINE TYPE:", type(pipeline))
    
        # Move each component to the target device, falling back to .to_empty() for
        # modules still on the meta device. Note that to_empty() allocates new,
        # uninitialized storage and does NOT copy weights.
        # A separate loop variable avoids shadowing the `name` argument.
        for component_name in ["unet", "transformer", "vae", "text_encoder"]:
            module = getattr(pipeline, component_name, None)
            if isinstance(module, torch.nn.Module):
                try:
                    print(f">>> Moving {component_name} to {device} using to()")
                    module.to(device)
                except NotImplementedError as e:
                    print(f">>> WARNING: {component_name}.to({device}) failed: {e}")
                    try:
                        print(f">>> Falling back to {component_name}.to_empty({device})")
                        module.to_empty(device=device)
                    except Exception as ee:
                        print(f">>> ERROR: {component_name}.to_empty({device}) also failed: {ee}")
    
        # Identify main model (for patching)
        model = getattr(pipeline, "unet", None) or getattr(pipeline, "transformer", None)
        if model is not None:
            replace_fused_linear_with_concat_linear(model)
            replace_up_block_conv_with_concat_conv(model)
            if shift_activations:
                shift_input_activations(model)
        else:
            print(">>> WARNING: No model (unet/transformer) found for patching")
    
        return pipeline
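
Note: to_empty() allocates new, uninitialized storage and does not copy weights, so any component materialized this way still needs its weights (re)loaded before the pipeline can produce meaningful outputs.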

  2. KeyError: <class 'diffusers.models.transformers.transformer_flux.FluxAttention'>

Potential fix: app.diffusion.nn.struct.py

    @staticmethod
    def _default_construct(
        module: Attention,
        /,
        parent: tp.Optional["DiffusionTransformerBlockStruct"] = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "DiffusionAttentionStruct":
        if isinstance(module, FluxAttention):  
            # FluxAttention has different attribute names than standard attention  
            with_rope = True  
            num_query_heads = module.heads  # FluxAttention uses 'heads', not 'num_heads'  
            num_key_value_heads = module.heads  # FLUX typically uses same for q/k/v  
              
            # FluxAttention doesn't have 'to_out', but may have other output projections  
            # Check what output projection attributes actually exist  
            o_proj = None  
            o_proj_rname = ""  
              
            # Try to find the correct output projection  
            if hasattr(module, 'to_out') and module.to_out is not None:
                # in diffusers, to_out is typically an nn.ModuleList([Linear, Dropout])
                is_seq = isinstance(module.to_out, (list, tuple, nn.ModuleList))
                o_proj = module.to_out[0] if is_seq else module.to_out
                o_proj_rname = "to_out.0" if is_seq else "to_out"
            elif hasattr(module, 'to_add_out'):  
                o_proj = module.to_add_out  
                o_proj_rname = "to_add_out"  
              
            q_proj, k_proj, v_proj = module.to_q, module.to_k, module.to_v  
            q_proj_rname, k_proj_rname, v_proj_rname = "to_q", "to_k", "to_v"  
            q, k, v = module.to_q, module.to_k, module.to_v  
            q_rname, k_rname, v_rname = "to_q", "to_k", "to_v"  
              
            # Handle the add_* projections that FluxAttention has  
            add_q_proj = getattr(module, "add_q_proj", None)  
            add_k_proj = getattr(module, "add_k_proj", None)   
            add_v_proj = getattr(module, "add_v_proj", None)  
            add_o_proj = getattr(module, "to_add_out", None)  
            add_q_proj_rname = "add_q_proj" if add_q_proj else ""  
            add_k_proj_rname = "add_k_proj" if add_k_proj else ""  
            add_v_proj_rname = "add_v_proj" if add_v_proj else ""  
            add_o_proj_rname = "to_add_out" if add_o_proj else ""  
              
            kwargs = (  
                "encoder_hidden_states",  
                "attention_mask",   
                "image_rotary_emb",  
            )  
            cross_attention = add_k_proj is not None
        elif module.is_cross_attention:
            q_proj, k_proj, v_proj = module.to_q, None, None
            add_q_proj, add_k_proj, add_v_proj, add_o_proj = None, module.to_k, module.to_v, None
            q_proj_rname, k_proj_rname, v_proj_rname = "to_q", "", ""
            add_q_proj_rname, add_k_proj_rname, add_v_proj_rname, add_o_proj_rname = "", "to_k", "to_v", ""
        else:
            q_proj, k_proj, v_proj = module.to_q, module.to_k, module.to_v
            add_q_proj = getattr(module, "add_q_proj", None)
            add_k_proj = getattr(module, "add_k_proj", None)
            add_v_proj = getattr(module, "add_v_proj", None)
            add_o_proj = getattr(module, "to_add_out", None)
            q_proj_rname, k_proj_rname, v_proj_rname = "to_q", "to_k", "to_v"
            add_q_proj_rname, add_k_proj_rname, add_v_proj_rname = "add_q_proj", "add_k_proj", "add_v_proj"
            add_o_proj_rname = "to_add_out"
        if getattr(module, "to_out", None) is not None:
            o_proj = module.to_out[0]
            o_proj_rname = "to_out.0"
            assert isinstance(o_proj, nn.Linear)
        elif parent is not None:
            assert isinstance(parent.module, FluxSingleTransformerBlock)
            assert isinstance(parent.module.proj_out, ConcatLinear)
            assert len(parent.module.proj_out.linears) == 2
            o_proj = parent.module.proj_out.linears[0]
            o_proj_rname = ".proj_out.linears.0"
        else:
            raise RuntimeError("Cannot find the output projection.")
        if isinstance(module.processor, DiffusionAttentionProcessor):
            with_rope = module.processor.rope is not None
        elif module.processor.__class__.__name__.startswith("Flux"):
            with_rope = True
        else:
            with_rope = False  # TODO: fix for other processors
        config = AttentionConfigStruct(
            hidden_size=q_proj.weight.shape[1],
            add_hidden_size=add_k_proj.weight.shape[1] if add_k_proj is not None else 0,
            inner_size=q_proj.weight.shape[0],
            num_query_heads=module.heads,
            num_key_value_heads=module.to_k.weight.shape[0] // (module.to_q.weight.shape[0] // module.heads),
            with_qk_norm=module.norm_q is not None,
            with_rope=with_rope,
            linear_attn=isinstance(module.processor, SanaLinearAttnProcessor2_0),
        )
        return DiffusionAttentionStruct(
            module=module,
            parent=parent,
            fname=fname,
            idx=idx,
            rname=rname,
            rkey=rkey,
            config=config,
            q_proj=q_proj,
            k_proj=k_proj,
            v_proj=v_proj,
            o_proj=o_proj,
            add_q_proj=add_q_proj,
            add_k_proj=add_k_proj,
            add_v_proj=add_v_proj,
            add_o_proj=add_o_proj,
            q=None,  # TODO: add q, k, v
            k=None,
            v=None,
            q_proj_rname=q_proj_rname,
            k_proj_rname=k_proj_rname,
            v_proj_rname=v_proj_rname,
            o_proj_rname=o_proj_rname,
            add_q_proj_rname=add_q_proj_rname,
            add_k_proj_rname=add_k_proj_rname,
            add_v_proj_rname=add_v_proj_rname,
            add_o_proj_rname=add_o_proj_rname,
            q_rname="",
            k_rname="",
            v_rname="",
        )
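
The KeyError appears to stem from newer diffusers releases giving FLUX its own FluxAttention class (see the transformer_flux.py link under References) instead of the shared Attention class that deepcompressor's attention structs are keyed on; the patch above adds a FluxAttention branch to _default_construct to handle it.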

  3. ValueError: Provide either prompt or prompt_embeds. Cannot leave both prompt and prompt_embeds undefined.

Potential Fix: app.diffusion.dataset.collect.calib.py

def collect(config: DiffusionPtqRunConfig, dataset: datasets.Dataset):
    samples_dirpath = os.path.join(config.output.root, "samples")
    caches_dirpath = os.path.join(config.output.root, "caches")
    os.makedirs(samples_dirpath, exist_ok=True)
    os.makedirs(caches_dirpath, exist_ok=True)
    caches = []

    pipeline = config.pipeline.build()
    model = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
    assert isinstance(model, nn.Module)
    model.register_forward_hook(CollectHook(caches=caches), with_kwargs=True)

    batch_size = config.eval.batch_size
    print(f"In total {len(dataset)} samples")
    print(f"Evaluating with batch size {batch_size}")
    pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
    for batch in tqdm(
        dataset.iter(batch_size=batch_size, drop_last_batch=False),
        desc="Data",
        leave=False,
        dynamic_ncols=True,
        total=(len(dataset) + batch_size - 1) // batch_size,
    ):
        filenames = batch["filename"]
        prompts = batch["prompt"]
        seeds = [hash_str_to_int(name) for name in filenames]
        generators = [torch.Generator(device=pipeline.device).manual_seed(seed) for seed in seeds]
        pipeline_kwargs = config.eval.get_pipeline_kwargs()

        task = config.pipeline.task
        control_root = config.eval.control_root
        if task in ["canny-to-image", "depth-to-image", "inpainting"]:
            controls = get_control(
                task,
                batch["image"],
                names=batch["filename"],
                data_root=os.path.join(
                    control_root, collect_config.dataset_name, f"{dataset.config_name}-{config.eval.num_samples}"
                ),
            )
            if task == "inpainting":
                pipeline_kwargs["image"] = controls[0]
                pipeline_kwargs["mask_image"] = controls[1]
            else:
                pipeline_kwargs["control_image"] = controls

        # Handle meta tensors by moving individual components  
        try:  
            pipeline = pipeline.to("cuda")  
        except NotImplementedError:  
            # Move individual pipeline components that have to_empty method  
            if hasattr(pipeline, 'transformer') and pipeline.transformer is not None:  
                try:  
                    pipeline.transformer = pipeline.transformer.to("cuda")  
                except NotImplementedError:  
                    pipeline.transformer = pipeline.transformer.to_empty(device="cuda")  

            if hasattr(pipeline, 'text_encoder') and pipeline.text_encoder is not None:  
                try:  
                    pipeline.text_encoder = pipeline.text_encoder.to("cuda")  
                except NotImplementedError:  
                    pipeline.text_encoder = pipeline.text_encoder.to_empty(device="cuda")  

            if hasattr(pipeline, 'text_encoder_2') and pipeline.text_encoder_2 is not None:  
                try:  
                    pipeline.text_encoder_2 = pipeline.text_encoder_2.to("cuda")  
                except NotImplementedError:  
                    pipeline.text_encoder_2 = pipeline.text_encoder_2.to_empty(device="cuda")  

            if hasattr(pipeline, 'vae') and pipeline.vae is not None:  
                try:  
                    pipeline.vae = pipeline.vae.to("cuda")  
                except NotImplementedError:  
                    pipeline.vae = pipeline.vae.to_empty(device="cuda")

        result_images = pipeline(prompt=prompts, generator=generators, **pipeline_kwargs).images
        num_guidances = (len(caches) // batch_size) // config.eval.num_steps
        num_steps = len(caches) // (batch_size * num_guidances)
        assert (
            len(caches) == batch_size * num_steps * num_guidances
        ), f"Unexpected number of caches: {len(caches)} != {batch_size} * {config.eval.num_steps} * {num_guidances}"
        for j, (filename, image) in enumerate(zip(filenames, result_images, strict=True)):
            image.save(os.path.join(samples_dirpath, f"{filename}.png"))
            for s in range(num_steps):
                for g in range(num_guidances):
                    c = caches[s * batch_size * num_guidances + g * batch_size + j]
                    c["filename"] = filename
                    c["step"] = s
                    c["guidance"] = g
                    c = tree_map(lambda x: process(x), c)
                    torch.save(c, os.path.join(caches_dirpath, f"{filename}-{s:05d}-{g}.pt"))
        caches.clear()

  4. RuntimeError: Tensor.item() cannot be called on meta tensors

Potential Fix: quantizer.impl.scale.py

def quantize_scale(
    s: torch.Tensor,
    /,
    *,
    quant_dtypes: tp.Sequence[QuantDataType],
    quant_spans: tp.Sequence[float],
    view_shapes: tp.Sequence[torch.Size],
) -> QuantScale:
    """Quantize the scale tensor.

    Args:
        s (`torch.Tensor`):
            The scale tensor.
        quant_dtypes (`Sequence[QuantDataType]`):
            The quantization dtypes of the scale tensor.
        quant_spans (`Sequence[float]`):
            The quantization spans of the scale tensor.
        view_shapes (`Sequence[torch.Size]`):
            The view shapes of the scale tensor.

    Returns:
        `QuantScale`:
            The quantized scale tensor.
    """
    # Check for meta tensors before any data-dependent validation: reductions such as
    # .any() on a meta tensor raise "Tensor.item() cannot be called on meta tensors".
    if s.is_meta:
        raise RuntimeError("Cannot quantize scale with meta tensor. Ensure model is loaded on actual device.")

    # Validate the input tensor
    if s.numel() == 0:
        raise ValueError("Input tensor is empty")
    if s.isnan().any() or s.isinf().any():
        raise ValueError("Input tensor contains NaN or Inf values")
    if (s == 0).all():
        raise ValueError("Input tensor contains all zeros")

    scale = QuantScale()
    s = s.abs()
    for view_shape, quant_dtype, quant_span in zip(view_shapes[:-1], quant_dtypes[:-1], quant_spans[:-1], strict=True):
        s = s.view(view_shape)  # (#g0, rs0, #g1, rs1, #g2, rs2, ...)
        ss = s.amax(dim=list(range(1, len(view_shape), 2)), keepdim=True)  # i.e., s_dynamic_span
        ss = simple_quantize(
            ss / quant_span, has_zero_point=False, quant_dtype=quant_dtype
        )  # i.e., s_scale = s_dynamic_span / s_quant_span
        s = s / ss
        scale.append(ss)
    view_shape = view_shapes[-1]
    s = s.view(view_shape)
    if any(v != 1 for v in view_shape[1::2]):
        ss = s.amax(dim=list(range(1, len(view_shape), 2)), keepdim=True)
        ss = simple_quantize(ss / quant_spans[-1], has_zero_point=False, quant_dtype=quant_dtypes[-1])
    else:
        assert quant_spans[-1] == 1, "The last quant span must be 1."
        ss = simple_quantize(s, has_zero_point=False, quant_dtype=quant_dtypes[-1])
    scale.append(ss)
    scale.remove_zero()
    return scale

    def quantize(
        self,
        *,
        # scale-based quantization related arguments
        scale: torch.Tensor | None = None,
        zero: torch.Tensor | None = None,
        # range-based quantization related arguments
        tensor: torch.Tensor | None = None,
        dynamic_range: DynamicRange | None = None,
    ) -> tuple[QuantScale, torch.Tensor]:
        """Get the quantization scale and zero point of the tensor to be quantized.

        Args:
            scale (`torch.Tensor` or `None`, *optional*, defaults to `None`):
                The scale tensor.
            zero (`torch.Tensor` or `None`, *optional*, defaults to `None`):
                The zero point tensor.
            tensor (`torch.Tensor` or `None`, *optional*, defaults to `None`):
                The tensor to be quantized. This is only used for range-based quantization.
            dynamic_range (`DynamicRange` or `None`, *optional*, defaults to `None`):
                The dynamic range of the tensor to be quantized.

        Returns:
            `tuple[QuantScale, torch.Tensor]`:
                The scale and the zero point.
        """
        # region step 1: get the dynamic span for range-based scale or the scale tensor
        if scale is None:
            range_based = True
            assert isinstance(tensor, torch.Tensor), "View tensor must be a tensor."
            dynamic_range = dynamic_range or DynamicRange()
            dynamic_range = dynamic_range.measure(
                tensor.view(self.tensor_view_shape),
                zero_domain=self.tensor_zero_domain,
                is_float_point=self.tensor_quant_dtype.is_float_point,
            )
            dynamic_range = dynamic_range.intersect(self.tensor_range_bound)
            dynamic_span = (dynamic_range.max - dynamic_range.min) if self.has_zero_point else dynamic_range.max
        else:
            range_based = False
            scale = scale.view(self.scale_view_shapes[-1])
            assert isinstance(scale, torch.Tensor), "Scale must be a tensor."
        # endregion
        # region step 2: get the scale
        if self.linear_scale_quant_dtypes:
            if range_based:
                linear_scale = dynamic_span / self.linear_tensor_quant_span
            elif self.exponent_scale_quant_dtypes:
                linear_scale = scale.mul(self.exponent_tensor_quant_span).div(self.linear_tensor_quant_span)
            else:
                linear_scale = scale
            lin_s = quantize_scale(
                linear_scale,
                quant_dtypes=self.linear_scale_quant_dtypes,
                quant_spans=self.linear_scale_quant_spans,
                view_shapes=self.linear_scale_view_shapes,
            )
            assert lin_s.data is not None, "Linear scale tensor is None."
            # Skip NaN/Inf checks for meta tensors: reductions on meta tensors raise
            # "Tensor.item() cannot be called on meta tensors".
            if not lin_s.data.is_meta:
                assert not lin_s.data.isnan().any(), "Linear scale tensor contains NaN."
                assert not lin_s.data.isinf().any(), "Linear scale tensor contains Inf."
        else:
            lin_s = QuantScale()
        if self.exponent_scale_quant_dtypes:
            if range_based:
                exp_scale = dynamic_span / self.exponent_tensor_quant_span
            else:
                exp_scale = scale
            if lin_s.data is not None:
                lin_s.data = lin_s.data.expand(self.linear_scale_view_shapes[-1]).reshape(self.scale_view_shapes[-1])
                exp_scale = exp_scale / lin_s.data
            exp_s = quantize_scale(
                exp_scale,
                quant_dtypes=self.exponent_scale_quant_dtypes,
                quant_spans=self.exponent_scale_quant_spans,
                view_shapes=self.exponent_scale_view_shapes,
            )
            assert exp_s.data is not None, "Exponential scale tensor is None."
            assert not exp_s.data.isnan().any(), "Exponential scale tensor contains NaN."
            assert not exp_s.data.isinf().any(), "Exponential scale tensor contains Inf."
            s = exp_s if lin_s.data is None else lin_s.extend(exp_s)
        else:
            s = lin_s

        # Before the final assertions, add debugging and validation  
        if s.data is None:  
            # Log debugging information  
            print(f"Linear scale dtypes: {self.linear_scale_quant_dtypes}")  
            print(f"Exponent scale dtypes: {self.exponent_scale_quant_dtypes}")  
            if hasattr(lin_s, 'data') and lin_s.data is not None:  
                print(f"Linear scale data shape: {lin_s.data.shape}")  
            raise RuntimeError("Scale computation failed - resulting scale is None")  
        assert s.data is not None, "Scale tensor is None."
        assert not s.data.isnan().any(), "Scale tensor contains NaN."
        assert not s.data.isinf().any(), "Scale tensor contains Inf."
        # endregion
        # region step 3: get the zero point
        if self.has_zero_point:
            if range_based:
                if self.tensor_zero_domain == ZeroPointDomain.PreScale:
                    zero = self.tensor_quant_range.min - dynamic_range.min / s.data
                else:
                    zero = self.tensor_quant_range.min * s.data - dynamic_range.min
            assert isinstance(zero, torch.Tensor), "Zero point must be a tensor."
            z = simple_quantize(zero, has_zero_point=True, quant_dtype=self.zero_quant_dtype)
        else:
            z = torch.tensor(0, dtype=s.data.dtype, device=s.data.device)
        assert not z.isnan().any(), "Zero point tensor contains NaN."
        assert not z.isinf().any(), "Zero point tensor contains Inf."
        # endregion
        return s, z

Potential Fix: app.diffusion.ptq.py

def ptq(  # noqa: C901
    model: DiffusionModelStruct,
    config: DiffusionQuantConfig,
    cache: DiffusionPtqCacheConfig | None = None,
    load_dirpath: str = "",
    save_dirpath: str = "",
    copy_on_save: bool = False,
    save_model: bool = False,
) -> DiffusionModelStruct:
    """Post-training quantization of a diffusion model.

    Args:
        model (`DiffusionModelStruct`):
            The diffusion model.
        config (`DiffusionQuantConfig`):
            The diffusion model post-training quantization configuration.
        cache (`DiffusionPtqCacheConfig`, *optional*, defaults to `None`):
            The diffusion model quantization cache path configuration.
        load_dirpath (`str`, *optional*, defaults to `""`):
            The directory path to load the quantization checkpoint.
        save_dirpath (`str`, *optional*, defaults to `""`):
            The directory path to save the quantization checkpoint.
        copy_on_save (`bool`, *optional*, defaults to `False`):
            Whether to copy the cache to the save directory.
        save_model (`bool`, *optional*, defaults to `False`):
            Whether to save the quantized model checkpoint.

    Returns:
        `DiffusionModelStruct`:
            The quantized diffusion model.
    """
    logger = tools.logging.getLogger(__name__)
    if not isinstance(model, DiffusionModelStruct):
        model = DiffusionModelStruct.construct(model)
    assert isinstance(model, DiffusionModelStruct)

    quant_wgts = config.enabled_wgts
    quant_ipts = config.enabled_ipts
    quant_opts = config.enabled_opts
    quant_acts = quant_ipts or quant_opts
    quant = quant_wgts or quant_acts

    load_model_path, load_path, save_path = "", None, None
    if load_dirpath:
        load_path = DiffusionQuantCacheConfig(
            smooth=os.path.join(load_dirpath, "smooth.pt"),
            branch=os.path.join(load_dirpath, "branch.pt"),
            wgts=os.path.join(load_dirpath, "wgts.pt"),
            acts=os.path.join(load_dirpath, "acts.pt"),
        )
        load_model_path = os.path.join(load_dirpath, "model.pt")
        if os.path.exists(load_model_path):
            if config.enabled_wgts and config.wgts.enabled_low_rank:
                if os.path.exists(load_path.branch):
                    load_model = True
                else:
                    logger.warning(f"Model low-rank branch checkpoint {load_path.branch} does not exist")
                    load_model = False
            else:
                load_model = True
            if load_model:
                logger.info(f"* Loading model from {load_model_path}")
                save_dirpath = ""  # do not save the model if loading
        else:
            logger.warning(f"Model checkpoint {load_model_path} does not exist")
            load_model = False
    else:
        load_model = False
    if save_dirpath:
        os.makedirs(save_dirpath, exist_ok=True)
        save_path = DiffusionQuantCacheConfig(
            smooth=os.path.join(save_dirpath, "smooth.pt"),
            branch=os.path.join(save_dirpath, "branch.pt"),
            wgts=os.path.join(save_dirpath, "wgts.pt"),
            acts=os.path.join(save_dirpath, "acts.pt"),
        )
    else:
        save_model = False

    if quant and config.enabled_rotation:
        logger.info("* Rotating model for quantization")
        tools.logging.Formatter.indent_inc()
        rotate_diffusion(model, config=config)
        tools.logging.Formatter.indent_dec()
        gc.collect()
        torch.cuda.empty_cache()

    # region smooth quantization
    if quant and config.enabled_smooth:
        logger.info("* Smoothing model for quantization")
        tools.logging.Formatter.indent_inc()
        load_from = ""
        if load_path and os.path.exists(load_path.smooth):
            load_from = load_path.smooth
        elif cache and cache.path.smooth and os.path.exists(cache.path.smooth):
            load_from = cache.path.smooth
        if load_from:
            logger.info(f"- Loading smooth scales from {load_from}")
            smooth_cache = torch.load(load_from)
            smooth_diffusion(model, config, smooth_cache=smooth_cache)
        else:
            logger.info("- Generating smooth scales")
            smooth_cache = smooth_diffusion(model, config)
            if cache and cache.path.smooth:
                logger.info(f"- Saving smooth scales to {cache.path.smooth}")
                os.makedirs(cache.dirpath.smooth, exist_ok=True)
                torch.save(smooth_cache, cache.path.smooth)
                load_from = cache.path.smooth
        if save_path:
            if not copy_on_save and load_from:
                logger.info(f"- Linking smooth scales to {save_path.smooth}")
                os.symlink(os.path.relpath(load_from, save_dirpath), save_path.smooth)
            else:
                logger.info(f"- Saving smooth scales to {save_path.smooth}")
                torch.save(smooth_cache, save_path.smooth)
        del smooth_cache
        tools.logging.Formatter.indent_dec()
        gc.collect()
        torch.cuda.empty_cache()
    # endregion
    # region collect original state dict
    if config.needs_acts_quantizer_cache:
        if load_path and os.path.exists(load_path.acts):
            orig_state_dict = None
        elif cache and cache.path.acts and os.path.exists(cache.path.acts):
            orig_state_dict = None
        else:
            orig_state_dict: dict[str, torch.Tensor] = {
                name: param.detach().clone() for name, param in model.module.named_parameters() if param.ndim > 1
            }
    else:
        orig_state_dict = None
    # endregion
    if load_model:
        logger.info(f"* Loading model checkpoint from {load_model_path}")
        load_diffusion_weights_state_dict(
            model,
            config,
            state_dict=torch.load(load_model_path),
            branch_state_dict=torch.load(load_path.branch) if os.path.exists(load_path.branch) else None,
        )
        gc.collect()
        torch.cuda.empty_cache()
    elif quant_wgts:
        logger.info("* Ensuring model is on actual device before quantization")  
          
        # Check if model has meta tensors  
        has_meta_tensors = any(param.is_meta for param in model.module.parameters())  
          
        if has_meta_tensors:  
            logger.info("* Model contains meta tensors, materializing to actual device")  
              
            # Option 1: Use to_empty() and reload weights (recommended)  
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  
              
            # Store original state dict if available
            # (if the parameters are on the meta device, state_dict() also returns meta
            # tensors and load_state_dict() will fail, so this usually falls through
            # to the fallback below)
            try:
                original_state_dict = model.module.state_dict()
                model.module = model.module.to_empty(device=device)
                model.module.load_state_dict(original_state_dict)
                logger.info("* Successfully materialized model with original weights")  
            except Exception as e:  
                logger.warning(f"* Failed to preserve weights during materialization: {e}")  
                # Fallback: just move to empty device (weights will be zero)  
                model.module = model.module.to_empty(device=device)  
                logger.warning("* Model moved to device but weights may be uninitialized")  
        else:  
            # Model already has real tensors, just ensure it's on the right device  
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  
            model.module = model.module.to(device)  
          
        # Verify no meta tensors remain  
        remaining_meta = [name for name, param in model.module.named_parameters() if param.is_meta]  
        if remaining_meta:  
            raise RuntimeError(f"Parameters still on meta device: {remaining_meta}")  
          
        logger.info("* Model successfully prepared for quantization")

        logger.info("* Quantizing weights")
        tools.logging.Formatter.indent_inc()
        quantizer_state_dict, quantizer_load_from = None, ""
        if load_path and os.path.exists(load_path.wgts):
            quantizer_load_from = load_path.wgts
        elif cache and cache.path.wgts and os.path.exists(cache.path.wgts):
            quantizer_load_from = cache.path.wgts
        if quantizer_load_from:
            logger.info(f"- Loading weight settings from {quantizer_load_from}")
            quantizer_state_dict = torch.load(quantizer_load_from)
        branch_state_dict, branch_load_from = None, ""
        if load_path and os.path.exists(load_path.branch):
            branch_load_from = load_path.branch
        elif cache and cache.path.branch and os.path.exists(cache.path.branch):
            branch_load_from = cache.path.branch
        if branch_load_from:
            logger.info(f"- Loading branch settings from {branch_load_from}")
            branch_state_dict = torch.load(branch_load_from)
        if not quantizer_load_from:
            logger.info("- Generating weight settings")
        if not branch_load_from:
            logger.info("- Generating branch settings")
        quantizer_state_dict, branch_state_dict, scale_state_dict = quantize_diffusion_weights(
            model,
            config,
            quantizer_state_dict=quantizer_state_dict,
            branch_state_dict=branch_state_dict,
            return_with_scale_state_dict=bool(save_dirpath),
        )
        if not quantizer_load_from and cache and cache.dirpath.wgts:
            logger.info(f"- Saving weight settings to {cache.path.wgts}")
            os.makedirs(cache.dirpath.wgts, exist_ok=True)
            torch.save(quantizer_state_dict, cache.path.wgts)
            quantizer_load_from = cache.path.wgts
        if not branch_load_from and cache and cache.dirpath.branch:
            logger.info(f"- Saving branch settings to {cache.path.branch}")
            os.makedirs(cache.dirpath.branch, exist_ok=True)
            torch.save(branch_state_dict, cache.path.branch)
            branch_load_from = cache.path.branch
        if save_path:
            if not copy_on_save and quantizer_load_from:
                logger.info(f"- Linking weight settings to {save_path.wgts}")
                os.symlink(os.path.relpath(quantizer_load_from, save_dirpath), save_path.wgts)
            else:
                logger.info(f"- Saving weight settings to {save_path.wgts}")
                torch.save(quantizer_state_dict, save_path.wgts)
            if not copy_on_save and branch_load_from:
                logger.info(f"- Linking branch settings to {save_path.branch}")
                os.symlink(os.path.relpath(branch_load_from, save_dirpath), save_path.branch)
            else:
                logger.info(f"- Saving branch settings to {save_path.branch}")
                torch.save(branch_state_dict, save_path.branch)
        if save_model:
            logger.info(f"- Saving model to {save_dirpath}")
            torch.save(scale_state_dict, os.path.join(save_dirpath, "scale.pt"))
            torch.save(model.module.state_dict(), os.path.join(save_dirpath, "model.pt"))
        del quantizer_state_dict, branch_state_dict, scale_state_dict
        tools.logging.Formatter.indent_dec()
        gc.collect()
        torch.cuda.empty_cache()
    if quant_acts:
        logger.info("  * Quantizing activations")
        tools.logging.Formatter.indent_inc()
        if config.needs_acts_quantizer_cache:
            load_from = ""
            if load_path and os.path.exists(load_path.acts):
                load_from = load_path.acts
            elif cache and cache.path.acts and os.path.exists(cache.path.acts):
                load_from = cache.path.acts
            if load_from:
                logger.info(f"- Loading activation settings from {load_from}")
                quantizer_state_dict = torch.load(load_from)
                quantize_diffusion_activations(
                    model, config, quantizer_state_dict=quantizer_state_dict, orig_state_dict=orig_state_dict
                )
            else:
                logger.info("- Generating activation settings")
                quantizer_state_dict = quantize_diffusion_activations(model, config, orig_state_dict=orig_state_dict)
                if cache and cache.dirpath.acts and quantizer_state_dict is not None:
                    logger.info(f"- Saving activation settings to {cache.path.acts}")
                    os.makedirs(cache.dirpath.acts, exist_ok=True)
                    torch.save(quantizer_state_dict, cache.path.acts)
                    load_from = cache.path.acts
            if save_dirpath:
                if not copy_on_save and load_from:
                    logger.info(f"- Linking activation quantizer settings to {save_path.acts}")
                    os.symlink(os.path.relpath(load_from, save_dirpath), save_path.acts)
                else:
                    logger.info(f"- Saving activation quantizer settings to {save_path.acts}")
                    torch.save(quantizer_state_dict, save_path.acts)
            del quantizer_state_dict
        else:
            logger.info("- No need to generate/load activation quantizer settings")
            quantize_diffusion_activations(model, config, orig_state_dict=orig_state_dict)
        tools.logging.Formatter.indent_dec()
        del orig_state_dict
        gc.collect()
        torch.cuda.empty_cache()
    return model

  5. RuntimeError: Dataset scripts are no longer supported, but found COCO.py
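
Potential fix: this error is raised by the Hugging Face datasets library, which dropped support for script-based datasets (such as COCO.py) in version 3.0. A possible workaround, untested here, is to pin an older release.

Example: pip install "datasets<3.0.0"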

References

https://github.com/nunchaku-tech/nunchaku/commit/b99fb8be615bc98c6915bbe06a1e0092cbc074a5

https://github.com/nunchaku-tech/nunchaku/blob/main/examples/flux.1-kontext-dev.py

https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L266

https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux_kontext.py

https://github.com/nunchaku-tech/deepcompressor/issues/91

https://deepwiki.com/nunchaku-tech/deepcompressor

https://huggingface.co/mit-han-lab/nunchaku-flux.1-kontext-dev/tree/main


Dependencies

https://github.com/Dao-AILab/flash-attention

https://github.com/facebookresearch/xformers

https://github.com/openai/CLIP

https://github.com/THUDM/ImageReward

Wheels

https://huggingface.co/datasets/siraxe/PrecompiledWheels_Torch-2.8-cu128-cp312

https://huggingface.co/lldacing/flash-attention-windows-wheel

https://github.com/loscrossos/lib_flashattention/releases
