What changes were needed to support this model?

#6 · opened by aw632

What changes were made to https://github.com/Tencent-Hunyuan/HunyuanVideo/blob/main/hyvideo/modules/fp8_optimization.py#L100 to support the model here? I tried to run the quantization script, but I was unable to get videos with the same quality as the ones here.

This was my script; I don't know if it's optimal, especially the e5m2 part:

from safetensors.torch import save_file  # load_file is replaced by the helper defined below
import safetensors.torch
import torch

import gc

# Target fp8 format: e4m3fn (3 mantissa bits) or e5m2 (2 mantissa bits).
out_dtype = torch.float8_e4m3fn
if out_dtype == torch.float8_e4m3fn:
    bits = 8
    mantissa_bit = 3
    sign_bits = 1
    dtype_string = "e4m3fn"
elif out_dtype == torch.float8_e5m2:
    bits = 8
    mantissa_bit = 2
    sign_bits = 1
    dtype_string = "e5m2"
else:
    raise ValueError(f"Unsupported fp8 dtype: {out_dtype}")

def fp8_tensor_quant(x, scale):
    # Broadcast the per-tensor scale over all dims, then simulate fp8 rounding
    # on the scaled-down values; the caller casts the result to the real fp8
    # dtype and stores the scale alongside it.
    scale = scale.reshape([-1] + [1] * (x.dim() - 1))
    quant_dequant_x = quantize_to_fp8(x / scale)
    return quant_dequant_x, scale

def quantize_to_fp8(x):
    # Simulate an fp8 round-trip: clamp to the representable range, then snap
    # each value to its nearest point on the fp8 grid, returned in x's dtype.
    device = x.device
    dtype = x.dtype
    M = max(1, min(mantissa_bit, bits - sign_bits))  # mantissa bits
    E = bits - sign_bits - M                         # exponent bits
    bias = 2 ** (E - 1) - 1
    # Largest representable mantissa value (e.g. 1.75 for e4m3fn).
    mantissa = 1
    for i in range(mantissa_bit - 1):
        mantissa += 1 / (2 ** (i + 1))
    maxval = torch.tensor(mantissa * 2 ** (2 ** E - 1 - bias), device=device, dtype=dtype)
    minval = -maxval if sign_bits == 1 else torch.tensor(0, device=device, dtype=dtype)
    input_clamp = torch.clamp(x, minval.item(), maxval.item())
    eps = 1e-6
    # Per-element quantization step 2^(exponent - M), clamped below so tiny
    # values snap to the subnormal grid rather than an ever-finer one.
    log_scales = torch.floor(torch.log2(torch.abs(input_clamp) + eps) + bias)
    log_scales = torch.clamp_min(log_scales, 1.0)
    log_scales = 2.0 ** (log_scales - M - bias)
    qdq_out = torch.round(input_clamp / log_scales) * log_scales
    return qdq_out
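
# Optional sanity check (my addition, not part of the original script): with
# the e4m3fn config above, the simulated round-trip should closely match
# PyTorch's native fp8 cast on in-range values.
if out_dtype == torch.float8_e4m3fn:
    _x = torch.randn(1024)
    _native = _x.to(torch.float8_e4m3fn).to(_x.dtype)
    print("max fp8 round-trip deviation:", (quantize_to_fp8(_x) - _native).abs().max().item())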


def load_file(path):
    # Replaces safetensors' load_file so plain torch checkpoints also work.
    print(f"Loading {path}...")
    if not path.endswith(".safetensors"):
        loaded = torch.load(path, map_location="cpu")
    else:
        loaded = safetensors.torch.load_file(path)
    return loaded


sd_pruned = dict()
# Keys containing any of these substrings are kept in high precision;
# every other "weight" tensor gets quantized to fp8.
params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "text_embedding", "time_", "img_emb", "modulation"}
model_type = "Wan2_2-T2V-A14B-HIGH"

for i in range(1, 7):
    file_path = f"diffusion_pytorch_model-{i:05d}-of-00006.safetensors"
    current_sd = load_file(file_path)

    for k, v in current_sd.items():
        if k not in sd_pruned:
            if any(keyword in k for keyword in params_to_keep):
                # Sensitive layers stay unquantized; patch_embedding keeps its
                # original dtype, the rest are upcast to fp32.
                if "patch_embedding" in k:
                    sd_pruned[k] = v
                else:
                    sd_pruned[k] = v.to(torch.float32)
            elif "weight" in k:
                if out_dtype == torch.float8_e4m3fn:
                    # e4m3fn: per-tensor absmax scaling.
                    scale = torch.max(torch.abs(v))
                elif out_dtype == torch.float8_e5m2:
                    # e5m2: 99.9th-percentile scaling to clip outliers; sample
                    # very large tensors because torch.quantile rejects
                    # oversized inputs.
                    abs_v = torch.abs(v.cpu().flatten())
                    if abs_v.numel() > 1000000:
                        sample = abs_v[torch.randint(0, abs_v.numel(), (1000000,))]
                        scale = torch.quantile(sample, 0.999)
                    else:
                        scale = torch.quantile(abs_v, 0.999)
                qdq_out, scale_tensor = fp8_tensor_quant(v.cuda(), scale.cuda())
                print(f"scale_tensor for key {k}: {scale_tensor}")
                # Store the scaled-down weight as fp8 plus its per-tensor scale.
                sd_pruned[k] = qdq_out.to(out_dtype)
                sd_pruned[k.replace(".weight", ".scale_weight")] = scale_tensor[0]
            else:
                sd_pruned[k] = v  # keep original dtype

    # Clear memory
    del current_sd
    gc.collect()
    torch.cuda.empty_cache()

sd_pruned["scaled_fp8"] = torch.tensor([0., 0.], dtype=out_dtype)

for k, v in sd_pruned.items():
    if isinstance(v, torch.Tensor):
        print(f"{k}: {v.shape} {v.dtype}")
    else:
        print(f"{k}: {type(v)}")
save_file(sd_pruned, f"{model_type}_fp8_{dtype_string}_scaled_KJ.safetensors", metadata={"format": "pt", "model_type": model_type})
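
For what it's worth, here is a minimal sketch of reading the result back, assuming the ".scale_weight" convention produced above (the block/key name is a hypothetical example, not taken from the actual checkpoint): each fp8 weight is dequantized by multiplying it with its stored per-tensor scale.

import torch
from safetensors.torch import load_file

sd = load_file("Wan2_2-T2V-A14B-HIGH_fp8_e4m3fn_scaled_KJ.safetensors")
w = sd["blocks.0.self_attn.q.weight"]        # hypothetical key name, stored as fp8
s = sd["blocks.0.self_attn.q.scale_weight"]  # matching per-tensor scale saved above
w_half = w.to(torch.float16) * s.to(torch.float16)  # dequantize before use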
