Repos
https://github.com/mit-han-lab/deepcompressor
Installation
https://github.com/mit-han-lab/deepcompressor/issues/56
https://github.com/nunchaku-tech/deepcompressor/issues/80
Windows
https://learn.microsoft.com/en-us/windows/wsl/install
https://www.anaconda.com/docs/getting-started/miniconda/install
Environment
Hardware:
Nvidia RTX 5060 Ti (Blackwell, sm_120)
Software (WSL):
Python 3.12.11
pip 25.1
CUDA 12.8
Torch 2.7.1+cu128
Diffusers 0.35.0.dev0
Transformers 4.53.2
flash_attn 2.7.4.post1
xformers 0.0.31.post1
Calibration Dataset Preparation
Example: python -m deepcompressor.app.diffusion.dataset.collect.calib svdq/flux.1-kontext-dev.yaml examples/diffusion/configs/collect/qdiff.yaml --pipeline-path svdq/flux.1-kontext-dev/
Sample Log
In total 32 samples
Evaluating with batch size 1
Data: 3%|███ | 1/32 [13:57<7:12:32, 837.19s/it]
Sampling: 12%|██████████ | 1/8 [01:34<11:01, 94.44s/it]
Quantization
Model Path: https://github.com/nunchaku-tech/deepcompressor/issues/70#issuecomment-2788155233
Save model: --save-model true
or --save-model /PATH/TO/CHECKPOINT/DIR
Example: python -m deepcompressor.app.diffusion.ptq svdq/flux.1-kontext-dev.yaml examples/diffusion/configs/svdquant/nvfp4.yaml --pipeline-path svdq/flux.1-kontext-dev/ --save-model ~/svdq/
Model Files Structure
Deploy
https://github.com/nunchaku-tech/deepcompressor/blob/main/examples/diffusion/README.md#deployment
Example: python -m deepcompressor.backend.nunchaku.convert --quant-path ~/svdq/ --output-root ~/svdq/ --model-name flux.1-kontext-dev-svdq-fp4
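The converted checkpoint can then be loaded through nunchaku, along the lines of the flux.1-kontext-dev.py example linked under References. A rough sketch only; the class usage should be checked against that script, and the local checkpoint path below is an assumption based on the convert command above:
import os
import torch
from diffusers import FluxKontextPipeline
from nunchaku import NunchakuFluxTransformer2dModel

# Path assumed from --output-root/--model-name above; adjust to the actual output layout
ckpt_path = os.path.expanduser("~/svdq/flux.1-kontext-dev-svdq-fp4")
transformer = NunchakuFluxTransformer2dModel.from_pretrained(ckpt_path)
pipeline = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")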
ComfyUI metadata reference:
Remarks
2025-07-23 Test Notes
The FP4-quantized model loads successfully in ComfyUI but isn't fully functional yet; needs further investigation and debugging.
The calibration dataset appears misaligned; the sampling code for Flux.1 Kontext Dev may need to be revisited and adjusted.
Later, consider running another test using the base black-forest-labs/FLUX.1-dev model for comparison.
Check with the deepcompressor/nunchaku team and request their latest working implementation.
Blockers
- NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.
Potential fix: app.diffusion.pipeline.config.py
@staticmethod
def _default_build(
name: str, path: str, dtype: str | torch.dtype, device: str | torch.device, shift_activations: bool
) -> DiffusionPipeline:
if not path:
if name == "sdxl":
path = "stabilityai/stable-diffusion-xl-base-1.0"
elif name == "sdxl-turbo":
path = "stabilityai/sdxl-turbo"
elif name == "pixart-sigma":
path = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"
elif name == "flux.1-kontext-dev":
path = "black-forest-labs/FLUX.1-Kontext-dev"
elif name == "flux.1-dev":
path = "black-forest-labs/FLUX.1-dev"
elif name == "flux.1-canny-dev":
path = "black-forest-labs/FLUX.1-Canny-dev"
elif name == "flux.1-depth-dev":
path = "black-forest-labs/FLUX.1-Depth-dev"
elif name == "flux.1-fill-dev":
path = "black-forest-labs/FLUX.1-Fill-dev"
elif name == "flux.1-schnell":
path = "black-forest-labs/FLUX.1-schnell"
else:
raise ValueError(f"Path for {name} is not specified.")
if name in ["flux.1-kontext-dev"]:
pipeline = FluxKontextPipeline.from_pretrained(path, torch_dtype=dtype)
elif name in ["flux.1-canny-dev", "flux.1-depth-dev"]:
pipeline = FluxControlPipeline.from_pretrained(path, torch_dtype=dtype)
elif name == "flux.1-fill-dev":
pipeline = FluxFillPipeline.from_pretrained(path, torch_dtype=dtype)
elif name.startswith("sana-"):
if dtype == torch.bfloat16:
pipeline = SanaPipeline.from_pretrained(path, variant="bf16", torch_dtype=dtype, use_safetensors=True)
pipeline.vae.to(dtype)
pipeline.text_encoder.to(dtype)
else:
pipeline = SanaPipeline.from_pretrained(path, torch_dtype=dtype)
else:
pipeline = AutoPipelineForText2Image.from_pretrained(path, torch_dtype=dtype)
# Debug output
print(">>> DEVICE:", device)
print(">>> PIPELINE TYPE:", type(pipeline))
# Try to move each component using .to_empty()
# Use a separate loop variable so the `name` parameter is not shadowed
for component_name in ["unet", "transformer", "vae", "text_encoder"]:
    module = getattr(pipeline, component_name, None)
    if isinstance(module, torch.nn.Module):
        try:
            print(f">>> Moving {component_name} to {device} using to_empty()")
            module.to_empty(device=device)
        except Exception as e:
            print(f">>> WARNING: {component_name}.to_empty({device}) failed: {e}")
            try:
                print(f">>> Falling back to {component_name}.to({device})")
                module.to(device)
            except Exception as ee:
                print(f">>> ERROR: {component_name}.to({device}) also failed: {ee}")
# Identify main model (for patching)
model = getattr(pipeline, "unet", None) or getattr(pipeline, "transformer", None)
if model is not None:
replace_fused_linear_with_concat_linear(model)
replace_up_block_conv_with_concat_conv(model)
if shift_activations:
shift_input_activations(model)
else:
print(">>> WARNING: No model (unet/transformer) found for patching")
return pipeline
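For reference, the to()/to_empty() behavior behind this error is easy to reproduce outside deepcompressor (minimal sketch, not project code):
import torch
import torch.nn as nn

# Parameters created on the meta device have no storage, so .to() cannot copy them.
with torch.device("meta"):
    meta_layer = nn.Linear(8, 8)

try:
    meta_layer.to("cpu")
except NotImplementedError as err:
    print(f"to() failed as expected: {err}")

# to_empty() allocates real but uninitialized storage on the target device;
# the actual weights still have to be loaded back from a checkpoint afterwards.
layer = meta_layer.to_empty(device="cpu")
layer.load_state_dict(nn.Linear(8, 8).state_dict())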
- KeyError: <class 'diffusers.models.transformers.transformer_flux.FluxAttention'>
Potential fix: app.diffusion.nn.struct.py
@staticmethod
def _default_construct(
module: Attention,
/,
parent: tp.Optional["DiffusionTransformerBlockStruct"] = None,
fname: str = "",
rname: str = "",
rkey: str = "",
idx: int = 0,
**kwargs,
) -> "DiffusionAttentionStruct":
if isinstance(module, FluxAttention):
# FluxAttention has different attribute names than standard attention
with_rope = True
num_query_heads = module.heads # FluxAttention uses 'heads', not 'num_heads'
num_key_value_heads = module.heads # FLUX typically uses same for q/k/v
# FluxAttention doesn't have 'to_out', but may have other output projections
# Check what output projection attributes actually exist
o_proj = None
o_proj_rname = ""
# Try to find the correct output projection
if hasattr(module, 'to_out') and module.to_out is not None:
    # to_out is typically nn.ModuleList([Linear, Dropout]); take the Linear itself
    if isinstance(module.to_out, (nn.ModuleList, list, tuple)):
        o_proj, o_proj_rname = module.to_out[0], "to_out.0"
    else:
        o_proj, o_proj_rname = module.to_out, "to_out"
elif hasattr(module, 'to_add_out'):
o_proj = module.to_add_out
o_proj_rname = "to_add_out"
q_proj, k_proj, v_proj = module.to_q, module.to_k, module.to_v
q_proj_rname, k_proj_rname, v_proj_rname = "to_q", "to_k", "to_v"
q, k, v = module.to_q, module.to_k, module.to_v
q_rname, k_rname, v_rname = "to_q", "to_k", "to_v"
# Handle the add_* projections that FluxAttention has
add_q_proj = getattr(module, "add_q_proj", None)
add_k_proj = getattr(module, "add_k_proj", None)
add_v_proj = getattr(module, "add_v_proj", None)
add_o_proj = getattr(module, "to_add_out", None)
add_q_proj_rname = "add_q_proj" if add_q_proj else ""
add_k_proj_rname = "add_k_proj" if add_k_proj else ""
add_v_proj_rname = "add_v_proj" if add_v_proj else ""
add_o_proj_rname = "to_add_out" if add_o_proj else ""
kwargs = (
"encoder_hidden_states",
"attention_mask",
"image_rotary_emb",
)
cross_attention = add_k_proj is not None
elif module.is_cross_attention:
q_proj, k_proj, v_proj = module.to_q, None, None
add_q_proj, add_k_proj, add_v_proj, add_o_proj = None, module.to_k, module.to_v, None
q_proj_rname, k_proj_rname, v_proj_rname = "to_q", "", ""
add_q_proj_rname, add_k_proj_rname, add_v_proj_rname, add_o_proj_rname = "", "to_k", "to_v", ""
else:
q_proj, k_proj, v_proj = module.to_q, module.to_k, module.to_v
add_q_proj = getattr(module, "add_q_proj", None)
add_k_proj = getattr(module, "add_k_proj", None)
add_v_proj = getattr(module, "add_v_proj", None)
add_o_proj = getattr(module, "to_add_out", None)
q_proj_rname, k_proj_rname, v_proj_rname = "to_q", "to_k", "to_v"
add_q_proj_rname, add_k_proj_rname, add_v_proj_rname = "add_q_proj", "add_k_proj", "add_v_proj"
add_o_proj_rname = "to_add_out"
if getattr(module, "to_out", None) is not None:
o_proj = module.to_out[0]
o_proj_rname = "to_out.0"
assert isinstance(o_proj, nn.Linear)
elif parent is not None:
assert isinstance(parent.module, FluxSingleTransformerBlock)
assert isinstance(parent.module.proj_out, ConcatLinear)
assert len(parent.module.proj_out.linears) == 2
o_proj = parent.module.proj_out.linears[0]
o_proj_rname = ".proj_out.linears.0"
else:
raise RuntimeError("Cannot find the output projection.")
if isinstance(module.processor, DiffusionAttentionProcessor):
with_rope = module.processor.rope is not None
elif module.processor.__class__.__name__.startswith("Flux"):
with_rope = True
else:
with_rope = False # TODO: fix for other processors
config = AttentionConfigStruct(
hidden_size=q_proj.weight.shape[1],
add_hidden_size=add_k_proj.weight.shape[1] if add_k_proj is not None else 0,
inner_size=q_proj.weight.shape[0],
num_query_heads=module.heads,
num_key_value_heads=module.to_k.weight.shape[0] // (module.to_q.weight.shape[0] // module.heads),
with_qk_norm=module.norm_q is not None,
with_rope=with_rope,
linear_attn=isinstance(module.processor, SanaLinearAttnProcessor2_0),
)
return DiffusionAttentionStruct(
module=module,
parent=parent,
fname=fname,
idx=idx,
rname=rname,
rkey=rkey,
config=config,
q_proj=q_proj,
k_proj=k_proj,
v_proj=v_proj,
o_proj=o_proj,
add_q_proj=add_q_proj,
add_k_proj=add_k_proj,
add_v_proj=add_v_proj,
add_o_proj=add_o_proj,
q=None, # TODO: add q, k, v
k=None,
v=None,
q_proj_rname=q_proj_rname,
k_proj_rname=k_proj_rname,
v_proj_rname=v_proj_rname,
o_proj_rname=o_proj_rname,
add_q_proj_rname=add_q_proj_rname,
add_k_proj_rname=add_k_proj_rname,
add_v_proj_rname=add_v_proj_rname,
add_o_proj_rname=add_o_proj_rname,
q_rname="",
k_rname="",
v_rname="",
)
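To double-check the attribute names this patch assumes (heads, to_q/to_k/to_v, add_q_proj, to_add_out), the attention modules can be inspected directly. A small sketch, assuming pipeline is the FluxKontextPipeline built during calibration:
# Inspect one double-stream and one single-stream attention block
double_attn = pipeline.transformer.transformer_blocks[0].attn
single_attn = pipeline.transformer.single_transformer_blocks[0].attn
for label, attn in [("double", double_attn), ("single", single_attn)]:
    print(label, type(attn).__name__, "heads =", attn.heads)
    print("  children:", [name for name, _ in attn.named_children()])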
- ValueError: Provide either prompt or prompt_embeds. Cannot leave both prompt and prompt_embeds undefined.
Potential fix: app.diffusion.dataset.collect.calib.py
def collect(config: DiffusionPtqRunConfig, dataset: datasets.Dataset):
samples_dirpath = os.path.join(config.output.root, "samples")
caches_dirpath = os.path.join(config.output.root, "caches")
os.makedirs(samples_dirpath, exist_ok=True)
os.makedirs(caches_dirpath, exist_ok=True)
caches = []
pipeline = config.pipeline.build()
model = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
assert isinstance(model, nn.Module)
model.register_forward_hook(CollectHook(caches=caches), with_kwargs=True)
batch_size = config.eval.batch_size
print(f"In total {len(dataset)} samples")
print(f"Evaluating with batch size {batch_size}")
pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
for batch in tqdm(
dataset.iter(batch_size=batch_size, drop_last_batch=False),
desc="Data",
leave=False,
dynamic_ncols=True,
total=(len(dataset) + batch_size - 1) // batch_size,
):
filenames = batch["filename"]
prompts = batch["prompt"]
seeds = [hash_str_to_int(name) for name in filenames]
generators = [torch.Generator(device=pipeline.device).manual_seed(seed) for seed in seeds]
pipeline_kwargs = config.eval.get_pipeline_kwargs()
task = config.pipeline.task
control_root = config.eval.control_root
if task in ["canny-to-image", "depth-to-image", "inpainting"]:
controls = get_control(
task,
batch["image"],
names=batch["filename"],
data_root=os.path.join(
control_root, collect_config.dataset_name, f"{dataset.config_name}-{config.eval.num_samples}"
),
)
if task == "inpainting":
pipeline_kwargs["image"] = controls[0]
pipeline_kwargs["mask_image"] = controls[1]
else:
pipeline_kwargs["control_image"] = controls
# Handle meta tensors by moving individual components
try:
pipeline = pipeline.to("cuda")
except NotImplementedError:
# Move individual pipeline components that have to_empty method
if hasattr(pipeline, 'transformer') and pipeline.transformer is not None:
try:
pipeline.transformer = pipeline.transformer.to("cuda")
except NotImplementedError:
pipeline.transformer = pipeline.transformer.to_empty(device="cuda")
if hasattr(pipeline, 'text_encoder') and pipeline.text_encoder is not None:
try:
pipeline.text_encoder = pipeline.text_encoder.to("cuda")
except NotImplementedError:
pipeline.text_encoder = pipeline.text_encoder.to_empty(device="cuda")
if hasattr(pipeline, 'text_encoder_2') and pipeline.text_encoder_2 is not None:
try:
pipeline.text_encoder_2 = pipeline.text_encoder_2.to("cuda")
except NotImplementedError:
pipeline.text_encoder_2 = pipeline.text_encoder_2.to_empty(device="cuda")
if hasattr(pipeline, 'vae') and pipeline.vae is not None:
try:
pipeline.vae = pipeline.vae.to("cuda")
except NotImplementedError:
pipeline.vae = pipeline.vae.to_empty(device="cuda")
result_images = pipeline(prompt=prompts, generator=generators, **pipeline_kwargs).images
num_guidances = (len(caches) // batch_size) // config.eval.num_steps
num_steps = len(caches) // (batch_size * num_guidances)
assert (
len(caches) == batch_size * num_steps * num_guidances
), f"Unexpected number of caches: {len(caches)} != {batch_size} * {config.eval.num_steps} * {num_guidances}"
for j, (filename, image) in enumerate(zip(filenames, result_images, strict=True)):
image.save(os.path.join(samples_dirpath, f"{filename}.png"))
for s in range(num_steps):
for g in range(num_guidances):
c = caches[s * batch_size * num_guidances + g * batch_size + j]
c["filename"] = filename
c["step"] = s
c["guidance"] = g
c = tree_map(lambda x: process(x), c)
torch.save(c, os.path.join(caches_dirpath, f"{filename}-{s:05d}-{g}.pt"))
caches.clear()
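The per-component fallback above could also be factored into a small helper with the same behavior (names here are illustrative, not deepcompressor API):
def _cuda_or_materialize(module: torch.nn.Module) -> torch.nn.Module:
    """Move a module to CUDA, materializing meta tensors when a plain .to() fails."""
    try:
        return module.to("cuda")
    except NotImplementedError:
        # Meta tensors cannot be copied, only given fresh (uninitialized) storage.
        return module.to_empty(device="cuda")

for component in ("transformer", "text_encoder", "text_encoder_2", "vae"):
    mod = getattr(pipeline, component, None)
    if isinstance(mod, torch.nn.Module):
        setattr(pipeline, component, _cuda_or_materialize(mod))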
- RuntimeError: Tensor.item() cannot be called on meta tensors
Potential fix: quantizer.impl.scale.py
def quantize_scale(
s: torch.Tensor,
/,
*,
quant_dtypes: tp.Sequence[QuantDataType],
quant_spans: tp.Sequence[float],
view_shapes: tp.Sequence[torch.Size],
) -> QuantScale:
"""Quantize the scale tensor.
Args:
s (`torch.Tensor`):
The scale tensor.
quant_dtypes (`Sequence[QuantDataType]`):
The quantization dtypes of the scale tensor.
quant_spans (`Sequence[float]`):
The quantization spans of the scale tensor.
view_shapes (`Sequence[torch.Size]`):
The view shapes of the scale tensor.
Returns:
`QuantScale`:
The quantized scale tensor.
"""
# Validate input before any computation; the meta check must come first because
# isnan()/isinf()/all() cannot be evaluated on meta tensors.
if s.is_meta:
    raise RuntimeError("Cannot quantize scale with meta tensor. Ensure model is loaded on actual device.")
if s.numel() == 0:
    raise ValueError("Input tensor is empty")
if s.isnan().any() or s.isinf().any():
    raise ValueError("Input tensor contains NaN or Inf values")
if (s == 0).all():
    raise ValueError("Input tensor contains all zeros")
scale = QuantScale()
s = s.abs()
for view_shape, quant_dtype, quant_span in zip(view_shapes[:-1], quant_dtypes[:-1], quant_spans[:-1], strict=True):
s = s.view(view_shape) # (#g0, rs0, #g1, rs1, #g2, rs2, ...)
ss = s.amax(dim=list(range(1, len(view_shape), 2)), keepdim=True) # i.e., s_dynamic_span
ss = simple_quantize(
ss / quant_span, has_zero_point=False, quant_dtype=quant_dtype
) # i.e., s_scale = s_dynamic_span / s_quant_span
s = s / ss
scale.append(ss)
view_shape = view_shapes[-1]
s = s.view(view_shape)
if any(v != 1 for v in view_shape[1::2]):
ss = s.amax(dim=list(range(1, len(view_shape), 2)), keepdim=True)
ss = simple_quantize(ss / quant_spans[-1], has_zero_point=False, quant_dtype=quant_dtypes[-1])
else:
assert quant_spans[-1] == 1, "The last quant span must be 1."
ss = simple_quantize(s, has_zero_point=False, quant_dtype=quant_dtypes[-1])
scale.append(ss)
scale.remove_zero()
return scale
def quantize(
self,
*,
# scale-based quantization related arguments
scale: torch.Tensor | None = None,
zero: torch.Tensor | None = None,
# range-based quantization related arguments
tensor: torch.Tensor | None = None,
dynamic_range: DynamicRange | None = None,
) -> tuple[QuantScale, torch.Tensor]:
"""Get the quantization scale and zero point of the tensor to be quantized.
Args:
scale (`torch.Tensor` or `None`, *optional*, defaults to `None`):
The scale tensor.
zero (`torch.Tensor` or `None`, *optional*, defaults to `None`):
The zero point tensor.
tensor (`torch.Tensor` or `None`, *optional*, defaults to `None`):
The tensor to be quantized. This is only used for range-based quantization.
dynamic_range (`DynamicRange` or `None`, *optional*, defaults to `None`):
The dynamic range of the tensor to be quantized.
Returns:
`tuple[QuantScale, torch.Tensor]`:
The scale and the zero point.
"""
# region step 1: get the dynamic span for range-based scale or the scale tensor
if scale is None:
range_based = True
assert isinstance(tensor, torch.Tensor), "View tensor must be a tensor."
dynamic_range = dynamic_range or DynamicRange()
dynamic_range = dynamic_range.measure(
tensor.view(self.tensor_view_shape),
zero_domain=self.tensor_zero_domain,
is_float_point=self.tensor_quant_dtype.is_float_point,
)
dynamic_range = dynamic_range.intersect(self.tensor_range_bound)
dynamic_span = (dynamic_range.max - dynamic_range.min) if self.has_zero_point else dynamic_range.max
else:
range_based = False
scale = scale.view(self.scale_view_shapes[-1])
assert isinstance(scale, torch.Tensor), "Scale must be a tensor."
# endregion
# region step 2: get the scale
if self.linear_scale_quant_dtypes:
if range_based:
linear_scale = dynamic_span / self.linear_tensor_quant_span
elif self.exponent_scale_quant_dtypes:
linear_scale = scale.mul(self.exponent_tensor_quant_span).div(self.linear_tensor_quant_span)
else:
linear_scale = scale
lin_s = quantize_scale(
linear_scale,
quant_dtypes=self.linear_scale_quant_dtypes,
quant_spans=self.linear_scale_quant_spans,
view_shapes=self.linear_scale_view_shapes,
)
assert lin_s.data is not None, "Linear scale tensor is None."
if not lin_s.data.is_meta:
assert not lin_s.data.isnan().any(), "Linear scale tensor contains NaN."
assert not lin_s.data.isinf().any(), "Linear scale tensor contains Inf."
else:
lin_s = QuantScale()
if self.exponent_scale_quant_dtypes:
if range_based:
exp_scale = dynamic_span / self.exponent_tensor_quant_span
else:
exp_scale = scale
if lin_s.data is not None:
lin_s.data = lin_s.data.expand(self.linear_scale_view_shapes[-1]).reshape(self.scale_view_shapes[-1])
exp_scale = exp_scale / lin_s.data
exp_s = quantize_scale(
exp_scale,
quant_dtypes=self.exponent_scale_quant_dtypes,
quant_spans=self.exponent_scale_quant_spans,
view_shapes=self.exponent_scale_view_shapes,
)
assert exp_s.data is not None, "Exponential scale tensor is None."
assert not exp_s.data.isnan().any(), "Exponential scale tensor contains NaN."
assert not exp_s.data.isinf().any(), "Exponential scale tensor contains Inf."
s = exp_s if lin_s.data is None else lin_s.extend(exp_s)
else:
s = lin_s
# Before the final assertions, add debugging and validation
if s.data is None:
# Log debugging information
print(f"Linear scale dtypes: {self.linear_scale_quant_dtypes}")
print(f"Exponent scale dtypes: {self.exponent_scale_quant_dtypes}")
if hasattr(lin_s, 'data') and lin_s.data is not None:
print(f"Linear scale data shape: {lin_s.data.shape}")
raise RuntimeError("Scale computation failed - resulting scale is None")
assert s.data is not None, "Scale tensor is None."
assert not s.data.isnan().any(), "Scale tensor contains NaN."
assert not s.data.isinf().any(), "Scale tensor contains Inf."
# endregion
# region step 3: get the zero point
if self.has_zero_point:
if range_based:
if self.tensor_zero_domain == ZeroPointDomain.PreScale:
zero = self.tensor_quant_range.min - dynamic_range.min / s.data
else:
zero = self.tensor_quant_range.min * s.data - dynamic_range.min
assert isinstance(zero, torch.Tensor), "Zero point must be a tensor."
z = simple_quantize(zero, has_zero_point=True, quant_dtype=self.zero_quant_dtype)
else:
z = torch.tensor(0, dtype=s.data.dtype, device=s.data.device)
assert not z.isnan().any(), "Zero point tensor contains NaN."
assert not z.isinf().any(), "Zero point tensor contains Inf."
# endregion
return s, z
Potential fix: app.diffusion.ptq.py
def ptq( # noqa: C901
model: DiffusionModelStruct,
config: DiffusionQuantConfig,
cache: DiffusionPtqCacheConfig | None = None,
load_dirpath: str = "",
save_dirpath: str = "",
copy_on_save: bool = False,
save_model: bool = False,
) -> DiffusionModelStruct:
"""Post-training quantization of a diffusion model.
Args:
model (`DiffusionModelStruct`):
The diffusion model.
config (`DiffusionQuantConfig`):
The diffusion model post-training quantization configuration.
cache (`DiffusionPtqCacheConfig`, *optional*, defaults to `None`):
The diffusion model quantization cache path configuration.
load_dirpath (`str`, *optional*, defaults to `""`):
The directory path to load the quantization checkpoint.
save_dirpath (`str`, *optional*, defaults to `""`):
The directory path to save the quantization checkpoint.
copy_on_save (`bool`, *optional*, defaults to `False`):
Whether to copy the cache to the save directory.
save_model (`bool`, *optional*, defaults to `False`):
Whether to save the quantized model checkpoint.
Returns:
`DiffusionModelStruct`:
The quantized diffusion model.
"""
logger = tools.logging.getLogger(__name__)
if not isinstance(model, DiffusionModelStruct):
model = DiffusionModelStruct.construct(model)
assert isinstance(model, DiffusionModelStruct)
quant_wgts = config.enabled_wgts
quant_ipts = config.enabled_ipts
quant_opts = config.enabled_opts
quant_acts = quant_ipts or quant_opts
quant = quant_wgts or quant_acts
load_model_path, load_path, save_path = "", None, None
if load_dirpath:
load_path = DiffusionQuantCacheConfig(
smooth=os.path.join(load_dirpath, "smooth.pt"),
branch=os.path.join(load_dirpath, "branch.pt"),
wgts=os.path.join(load_dirpath, "wgts.pt"),
acts=os.path.join(load_dirpath, "acts.pt"),
)
load_model_path = os.path.join(load_dirpath, "model.pt")
if os.path.exists(load_model_path):
if config.enabled_wgts and config.wgts.enabled_low_rank:
if os.path.exists(load_path.branch):
load_model = True
else:
logger.warning(f"Model low-rank branch checkpoint {load_path.branch} does not exist")
load_model = False
else:
load_model = True
if load_model:
logger.info(f"* Loading model from {load_model_path}")
save_dirpath = "" # do not save the model if loading
else:
logger.warning(f"Model checkpoint {load_model_path} does not exist")
load_model = False
else:
load_model = False
if save_dirpath:
os.makedirs(save_dirpath, exist_ok=True)
save_path = DiffusionQuantCacheConfig(
smooth=os.path.join(save_dirpath, "smooth.pt"),
branch=os.path.join(save_dirpath, "branch.pt"),
wgts=os.path.join(save_dirpath, "wgts.pt"),
acts=os.path.join(save_dirpath, "acts.pt"),
)
else:
save_model = False
if quant and config.enabled_rotation:
logger.info("* Rotating model for quantization")
tools.logging.Formatter.indent_inc()
rotate_diffusion(model, config=config)
tools.logging.Formatter.indent_dec()
gc.collect()
torch.cuda.empty_cache()
# region smooth quantization
if quant and config.enabled_smooth:
logger.info("* Smoothing model for quantization")
tools.logging.Formatter.indent_inc()
load_from = ""
if load_path and os.path.exists(load_path.smooth):
load_from = load_path.smooth
elif cache and cache.path.smooth and os.path.exists(cache.path.smooth):
load_from = cache.path.smooth
if load_from:
logger.info(f"- Loading smooth scales from {load_from}")
smooth_cache = torch.load(load_from)
smooth_diffusion(model, config, smooth_cache=smooth_cache)
else:
logger.info("- Generating smooth scales")
smooth_cache = smooth_diffusion(model, config)
if cache and cache.path.smooth:
logger.info(f"- Saving smooth scales to {cache.path.smooth}")
os.makedirs(cache.dirpath.smooth, exist_ok=True)
torch.save(smooth_cache, cache.path.smooth)
load_from = cache.path.smooth
if save_path:
if not copy_on_save and load_from:
logger.info(f"- Linking smooth scales to {save_path.smooth}")
os.symlink(os.path.relpath(load_from, save_dirpath), save_path.smooth)
else:
logger.info(f"- Saving smooth scales to {save_path.smooth}")
torch.save(smooth_cache, save_path.smooth)
del smooth_cache
tools.logging.Formatter.indent_dec()
gc.collect()
torch.cuda.empty_cache()
# endregion
# region collect original state dict
if config.needs_acts_quantizer_cache:
if load_path and os.path.exists(load_path.acts):
orig_state_dict = None
elif cache and cache.path.acts and os.path.exists(cache.path.acts):
orig_state_dict = None
else:
orig_state_dict: dict[str, torch.Tensor] = {
name: param.detach().clone() for name, param in model.module.named_parameters() if param.ndim > 1
}
else:
orig_state_dict = None
# endregion
if load_model:
logger.info(f"* Loading model checkpoint from {load_model_path}")
load_diffusion_weights_state_dict(
model,
config,
state_dict=torch.load(load_model_path),
branch_state_dict=torch.load(load_path.branch) if os.path.exists(load_path.branch) else None,
)
gc.collect()
torch.cuda.empty_cache()
elif quant_wgts:
logger.info("* Ensuring model is on actual device before quantization")
# Check if model has meta tensors
has_meta_tensors = any(param.is_meta for param in model.module.parameters())
if has_meta_tensors:
logger.info("* Model contains meta tensors, materializing to actual device")
# Option 1: Use to_empty() and reload weights (recommended)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Store original state dict if available
try:
original_state_dict = model.module.state_dict()
model.module = model.module.to_empty(device=device)
model.module.load_state_dict(original_state_dict)
logger.info("* Successfully materialized model with original weights")
except Exception as e:
logger.warning(f"* Failed to preserve weights during materialization: {e}")
# Fallback: just move to empty device (weights will be zero)
model.module = model.module.to_empty(device=device)
logger.warning("* Model moved to device but weights may be uninitialized")
else:
# Model already has real tensors, just ensure it's on the right device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.module = model.module.to(device)
# Verify no meta tensors remain
remaining_meta = [name for name, param in model.module.named_parameters() if param.is_meta]
if remaining_meta:
raise RuntimeError(f"Parameters still on meta device: {remaining_meta}")
logger.info("* Model successfully prepared for quantization")
logger.info("* Quantizing weights")
tools.logging.Formatter.indent_inc()
quantizer_state_dict, quantizer_load_from = None, ""
if load_path and os.path.exists(load_path.wgts):
quantizer_load_from = load_path.wgts
elif cache and cache.path.wgts and os.path.exists(cache.path.wgts):
quantizer_load_from = cache.path.wgts
if quantizer_load_from:
logger.info(f"- Loading weight settings from {quantizer_load_from}")
quantizer_state_dict = torch.load(quantizer_load_from)
branch_state_dict, branch_load_from = None, ""
if load_path and os.path.exists(load_path.branch):
branch_load_from = load_path.branch
elif cache and cache.path.branch and os.path.exists(cache.path.branch):
branch_load_from = cache.path.branch
if branch_load_from:
logger.info(f"- Loading branch settings from {branch_load_from}")
branch_state_dict = torch.load(branch_load_from)
if not quantizer_load_from:
logger.info("- Generating weight settings")
if not branch_load_from:
logger.info("- Generating branch settings")
quantizer_state_dict, branch_state_dict, scale_state_dict = quantize_diffusion_weights(
model,
config,
quantizer_state_dict=quantizer_state_dict,
branch_state_dict=branch_state_dict,
return_with_scale_state_dict=bool(save_dirpath),
)
if not quantizer_load_from and cache and cache.dirpath.wgts:
logger.info(f"- Saving weight settings to {cache.path.wgts}")
os.makedirs(cache.dirpath.wgts, exist_ok=True)
torch.save(quantizer_state_dict, cache.path.wgts)
quantizer_load_from = cache.path.wgts
if not branch_load_from and cache and cache.dirpath.branch:
logger.info(f"- Saving branch settings to {cache.path.branch}")
os.makedirs(cache.dirpath.branch, exist_ok=True)
torch.save(branch_state_dict, cache.path.branch)
branch_load_from = cache.path.branch
if save_path:
if not copy_on_save and quantizer_load_from:
logger.info(f"- Linking weight settings to {save_path.wgts}")
os.symlink(os.path.relpath(quantizer_load_from, save_dirpath), save_path.wgts)
else:
logger.info(f"- Saving weight settings to {save_path.wgts}")
torch.save(quantizer_state_dict, save_path.wgts)
if not copy_on_save and branch_load_from:
logger.info(f"- Linking branch settings to {save_path.branch}")
os.symlink(os.path.relpath(branch_load_from, save_dirpath), save_path.branch)
else:
logger.info(f"- Saving branch settings to {save_path.branch}")
torch.save(branch_state_dict, save_path.branch)
if save_model:
logger.info(f"- Saving model to {save_dirpath}")
torch.save(scale_state_dict, os.path.join(save_dirpath, "scale.pt"))
torch.save(model.module.state_dict(), os.path.join(save_dirpath, "model.pt"))
del quantizer_state_dict, branch_state_dict, scale_state_dict
tools.logging.Formatter.indent_dec()
gc.collect()
torch.cuda.empty_cache()
if quant_acts:
logger.info(" * Quantizing activations")
tools.logging.Formatter.indent_inc()
if config.needs_acts_quantizer_cache:
load_from = ""
if load_path and os.path.exists(load_path.acts):
load_from = load_path.acts
elif cache and cache.path.acts and os.path.exists(cache.path.acts):
load_from = cache.path.acts
if load_from:
logger.info(f"- Loading activation settings from {load_from}")
quantizer_state_dict = torch.load(load_from)
quantize_diffusion_activations(
model, config, quantizer_state_dict=quantizer_state_dict, orig_state_dict=orig_state_dict
)
else:
logger.info("- Generating activation settings")
quantizer_state_dict = quantize_diffusion_activations(model, config, orig_state_dict=orig_state_dict)
if cache and cache.dirpath.acts and quantizer_state_dict is not None:
logger.info(f"- Saving activation settings to {cache.path.acts}")
os.makedirs(cache.dirpath.acts, exist_ok=True)
torch.save(quantizer_state_dict, cache.path.acts)
load_from = cache.path.acts
if save_dirpath:
if not copy_on_save and load_from:
logger.info(f"- Linking activation quantizer settings to {save_path.acts}")
os.symlink(os.path.relpath(load_from, save_dirpath), save_path.acts)
else:
logger.info(f"- Saving activation quantizer settings to {save_path.acts}")
torch.save(quantizer_state_dict, save_path.acts)
del quantizer_state_dict
else:
logger.info("- No need to generate/load activation quantizer settings")
quantize_diffusion_activations(model, config, orig_state_dict=orig_state_dict)
tools.logging.Formatter.indent_dec()
del orig_state_dict
gc.collect()
torch.cuda.empty_cache()
return model
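Both patches target the same root cause: scale computation hitting parameters that are still on the meta device. The error itself reproduces in isolation, and a quick pre-flight check can confirm the model is fully materialized before PTQ (sketch; pipeline is assumed from the earlier steps):
import torch

# Repro: item() needs real storage, which meta tensors lack
s = torch.empty(4, device="meta")
try:
    s.amax().item()
except RuntimeError as err:
    print(f"item() on meta tensor failed: {err}")

# Pre-flight check before running deepcompressor.app.diffusion.ptq
meta_params = [n for n, p in pipeline.transformer.named_parameters() if p.is_meta]
print(f"{len(meta_params)} parameters still on the meta device")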
- RuntimeError: Dataset scripts are no longer supported, but found COCO.py
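Potential fix (assumption, not yet verified in this setup): this looks like the datasets>=3.0 removal of script-based dataset loaders (COCO.py is such a loading script). If so, pinning datasets below 3.0, e.g. pip install "datasets<3.0.0", or switching the calibration set to a script-free dataset should avoid it.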
References
https://github.com/nunchaku-tech/nunchaku/commit/b99fb8be615bc98c6915bbe06a1e0092cbc074a5
https://github.com/nunchaku-tech/nunchaku/blob/main/examples/flux.1-kontext-dev.py
https://github.com/nunchaku-tech/deepcompressor/issues/91
https://deepwiki.com/nunchaku-tech/deepcompressor
https://huggingface.co/mit-han-lab/nunchaku-flux.1-kontext-dev/tree/main
Dependencies
https://github.com/Dao-AILab/flash-attention
https://github.com/facebookresearch/xformers
https://github.com/openai/CLIP
https://github.com/THUDM/ImageReward
Wheels
https://huggingface.co/datasets/siraxe/PrecompiledWheels_Torch-2.8-cu128-cp312
https://huggingface.co/lldacing/flash-attention-windows-wheel