# q8-ltx-video/q8_ltx.py
"""
References:
https://github.com/KONAKONA666/q8_kernels/blob/9cee3f3d4ca5ec8ab463179be32c8001e31f8f33/q8_kernels/utils/convert_weights.py
"""
import torch
import torch.nn as nn
from diffusers.models.activations import GELU
from diffusers.models.normalization import RMSNorm
from q8_kernels.modules.activations import GELU as QGELU
from q8_kernels.modules.linear import Q8Linear
from q8_kernels.modules.rms_norm import RMSNorm as QRMSNorm

from q8_attention_processors import LTXVideoQ8AttentionProcessor

# Layers kept in their original precision; these are skipped during Q8 conversion.
MODULES_TO_NOT_CONVERT = ["proj_in", "time_embed", "caption_projection", "proj_out"]

def replace_linear(model, current_key_name=None, replaced=False):
    """Recursively replaces `nn.Linear` layers with `Q8Linear`, skipping modules listed in `MODULES_TO_NOT_CONVERT`."""
    for name, child in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(child, nn.Linear) and name not in MODULES_TO_NOT_CONVERT:
            # Check that no part of the module's full dotted path is in `MODULES_TO_NOT_CONVERT`.
            current_key_name_str = ".".join(current_key_name)
            if not any(
                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in MODULES_TO_NOT_CONVERT
            ):
                new_linear = Q8Linear(
                    child.in_features, child.out_features, bias=child.bias is not None, device=child.weight.device
                )
                setattr(model, name, new_linear)
                replaced = True
        else:
            # Recurse into submodules and propagate whether anything was replaced.
            _, replaced = replace_linear(model=child, current_key_name=current_key_name, replaced=replaced)
        current_key_name.pop(-1)
    return model, replaced


def get_parent_module_and_attr(root, dotted_name: str):
    """
    Splits 'a.b.c' into:
    - parent module = root.a.b
    - attr_name = 'c'
    """
    parts = dotted_name.split(".")
    *parent_parts, attr_name = parts
    parent_module = root
    for p in parent_parts:
        parent_module = getattr(parent_module, p)
    return parent_module, attr_name
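# Usage sketch (the module path is illustrative for a diffusers LTX transformer): calling
# get_parent_module_and_attr(model, "transformer_blocks.0.attn1.norm_q") would return
# (model.transformer_blocks[0].attn1, "norm_q"), ready for the setattr-based swap done below.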


def replace_rms_norm(model):
    """Replaces every diffusers `RMSNorm` in `model` with its Q8 counterpart, `QRMSNorm`."""
    modules_to_replace = []
    for dotted_name, module in model.named_modules():
        if isinstance(module, RMSNorm):
            modules_to_replace.append((dotted_name, module))

    replaced = False
    for dotted_name, module in modules_to_replace:
        parent, attr_name = get_parent_module_and_attr(model, dotted_name)
        new_norm = QRMSNorm(
            dim=module.dim,
            elementwise_affine=module.elementwise_affine,
        )
        setattr(parent, attr_name, new_norm)
        replaced = True
    return model, replaced


def replace_gelu(model, replaced=False):
    """Recursively replaces diffusers `GELU` activations (projection + GELU) with `QGELU`."""
    for name, child in model.named_children():
        if isinstance(child, GELU):
            new_gelu = QGELU(
                dim_in=child.proj.in_features,
                dim_out=child.proj.out_features,
                approximate=child.approximate,
                bias=child.proj.bias is not None,
            )
            setattr(model, name, new_gelu)
            replaced = True
        else:
            # Recurse into submodules and propagate whether anything was replaced.
            _, replaced = replace_gelu(model=child, replaced=replaced)
    return model, replaced


def set_attn_processors(model, processor):
    """Sets `processor` on every attention module in `model` that exposes `set_processor`."""
    def fn_recursive_attn_processor(name, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            module.set_processor(processor)
        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in model.named_children():
        fn_recursive_attn_processor(name, module, processor)


def attn_processors(model) -> dict:
    """Returns a dict mapping `"<module path>.processor"` to the attention processor currently in use."""
    processors = {}

    def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict):
        if hasattr(module, "get_processor"):
            processors[f"{name}.processor"] = module.get_processor()
        for sub_name, child in module.named_children():
            fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
        return processors

    for name, module in model.named_children():
        fn_recursive_add_processors(name, module, processors)
    return processors
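# For example (key names assume the diffusers LTX transformer layout used in the checks below),
# the returned dict looks roughly like {"transformer_blocks.0.attn1.processor": <processor>, ...}.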


def check_transformer_replaced_correctly(model):
    """Sanity-checks that linears, norms, and GELUs were swapped, then assigns the Q8 attention processor."""
    for block in model.transformer_blocks:
        assert isinstance(block.attn1.to_q, Q8Linear), f"{type(block.attn1.to_q)=} not linear."
        assert isinstance(block.attn2.to_q, Q8Linear), f"{type(block.attn2.to_q)=} not linear."
        assert block.attn1.to_q.weight.dtype == torch.int8, f"{block.attn1.to_q.weight.dtype=}."
        assert block.attn2.to_q.weight.dtype == torch.int8, f"{block.attn2.to_q.weight.dtype=}."

    for name, module in model.named_modules():
        if "norm" in name and "norm_out" not in name:
            assert isinstance(module, QRMSNorm), f"{name=}, {type(module)=}"

    for block in model.transformer_blocks:
        assert isinstance(block.ff.net[0], QGELU), f"{type(block.ff.net[0])=}"
        if getattr(block.ff.net[0], "proj", None) is not None:
            assert block.ff.net[0].proj.weight.dtype == torch.int8, f"{block.ff.net[0].proj.weight.dtype=}."

    set_attn_processors(model, LTXVideoQ8AttentionProcessor())
    all_attn_processors = attn_processors(model)
    for k, v in all_attn_processors.items():
        assert isinstance(v, LTXVideoQ8AttentionProcessor), f"{k} is not of type LTXVideoQ8AttentionProcessor."
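

# A minimal end-to-end sketch of how these helpers might be combined. The checkpoint id,
# the `LTXVideoTransformer3DModel` class, and the dtype choice are assumptions for
# illustration; the actual conversion flow lives in the referenced `convert_weights.py`.
#
#   from diffusers import LTXVideoTransformer3DModel
#
#   transformer = LTXVideoTransformer3DModel.from_pretrained(
#       "Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16
#   )
#   transformer, _ = replace_linear(transformer)
#   transformer, _ = replace_rms_norm(transformer)
#   transformer, _ = replace_gelu(transformer)
#   # ...load/quantize the Q8 weights into the swapped modules, then verify:
#   check_transformer_replaced_correctly(transformer)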