""" | |
References: | |
https://github.com/KONAKONA666/q8_kernels/blob/9cee3f3d4ca5ec8ab463179be32c8001e31f8f33/q8_kernels/utils/convert_weights.py | |
""" | |

import torch
import torch.nn as nn
from diffusers.models.activations import GELU
from diffusers.models.normalization import RMSNorm
from q8_kernels.modules.activations import GELU as QGELU
from q8_kernels.modules.linear import Q8Linear
from q8_kernels.modules.rms_norm import RMSNorm as QRMSNorm

from q8_attention_processors import LTXVideoQ8AttentionProcessor

# Module names kept in their original precision (skipped during Q8Linear replacement).
MODULES_TO_NOT_CONVERT = ["proj_in", "time_embed", "caption_projection", "proj_out"]


def replace_linear(model, current_key_name=None, replaced=False):
    """Recursively swap eligible nn.Linear layers for Q8Linear."""
    for name, child in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(child, nn.Linear) and name not in MODULES_TO_NOT_CONVERT:
            # Skip the layer if any part of its dotted path is in `MODULES_TO_NOT_CONVERT`.
            current_key_name_str = ".".join(current_key_name)
            if not any(
                (key + "." in current_key_name_str) or (key == current_key_name_str)
                for key in MODULES_TO_NOT_CONVERT
            ):
                # Weights are not copied here; the quantized (int8) state dict is expected
                # to be loaded into the new layer separately.
                new_linear = Q8Linear(
                    child.in_features, child.out_features, bias=child.bias is not None, device=child.weight.device
                )
                setattr(model, name, new_linear)
                replaced = True
        else:
            # Recurse into submodules and propagate the replacement flag back up.
            _, replaced = replace_linear(model=child, current_key_name=current_key_name, replaced=replaced)
        current_key_name.pop(-1)
    return model, replaced


def get_parent_module_and_attr(root, dotted_name: str):
    """
    Splits 'a.b.c' into:
      - parent module = root.a.b
      - attr_name = 'c'
    """
    parts = dotted_name.split(".")
    *parent_parts, attr_name = parts
    parent_module = root
    for p in parent_parts:
        parent_module = getattr(parent_module, p)
    return parent_module, attr_name
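
# Illustrative usage (toy modules for the sketch, not the actual transformer):
#
#     toy = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4)))
#     parent, attr = get_parent_module_and_attr(toy, "1.0")
#     # parent is toy[1] and attr == "0"; setattr(parent, attr, replacement) swaps the leaf in place.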


def replace_rms_norm(model):
    # Collect targets first so the module tree is not mutated while iterating `named_modules()`.
    modules_to_replace = []
    for dotted_name, module in model.named_modules():
        if isinstance(module, RMSNorm):
            modules_to_replace.append((dotted_name, module))

    replaced = False
    for dotted_name, module in modules_to_replace:
        parent, attr_name = get_parent_module_and_attr(model, dotted_name)
        new_norm = QRMSNorm(
            dim=module.dim,
            elementwise_affine=module.elementwise_affine,
        )
        setattr(parent, attr_name, new_norm)
        replaced = True
    return model, replaced


def replace_gelu(model, replaced=False):
    for name, child in model.named_children():
        if isinstance(child, GELU):
            new_gelu = QGELU(
                dim_in=child.proj.in_features,
                dim_out=child.proj.out_features,
                approximate=child.approximate,
                bias=child.proj.bias is not None,
            )
            setattr(model, name, new_gelu)
            replaced = True
        else:
            # Recurse into submodules and propagate the replacement flag back up.
            _, replaced = replace_gelu(model=child, replaced=replaced)
    return model, replaced


def set_attn_processors(model, processor):
    def fn_recursive_attn_processor(name, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            module.set_processor(processor)
        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in model.named_children():
        fn_recursive_attn_processor(name, module, processor)


def attn_processors(model) -> dict:
    processors = {}

    def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict):
        if hasattr(module, "get_processor"):
            processors[f"{name}.processor"] = module.get_processor()
        for sub_name, child in module.named_children():
            fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
        return processors

    for name, module in model.named_children():
        fn_recursive_add_processors(name, module, processors)
    return processors


def check_transformer_replaced_correctly(model):
    # Attention projections should now be Q8Linear layers with int8 weights.
    for block in model.transformer_blocks:
        assert isinstance(block.attn1.to_q, Q8Linear), f"{type(block.attn1.to_q)=} is not Q8Linear."
        assert isinstance(block.attn2.to_q, Q8Linear), f"{type(block.attn2.to_q)=} is not Q8Linear."
        assert block.attn1.to_q.weight.dtype == torch.int8, f"{block.attn1.to_q.weight.dtype=}."
        assert block.attn2.to_q.weight.dtype == torch.int8, f"{block.attn2.to_q.weight.dtype=}."

    # All norms except the final `norm_out` should have been swapped for QRMSNorm.
    for name, module in model.named_modules():
        if "norm" in name and "norm_out" not in name:
            assert isinstance(module, QRMSNorm), f"{name=}, {type(module)=}"

    # Feed-forward activations should be QGELU with int8 projection weights.
    for block in model.transformer_blocks:
        assert isinstance(block.ff.net[0], QGELU), f"{type(block.ff.net[0])=}"
        if getattr(block.ff.net[0], "proj", None) is not None:
            assert block.ff.net[0].proj.weight.dtype == torch.int8, f"{block.ff.net[0].proj.weight.dtype=}."

    # Install the Q8 attention processors and confirm every attention module picked them up.
    set_attn_processors(model, LTXVideoQ8AttentionProcessor())
    all_attn_processors = attn_processors(model)
    for k, v in all_attn_processors.items():
        assert isinstance(v, LTXVideoQ8AttentionProcessor), f"{k} is not of type LTXVideoQ8AttentionProcessor."
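

if __name__ == "__main__":
    # Minimal end-to-end sketch of how the helpers above fit together. The model class and
    # checkpoint id below are assumptions (any LTX-Video transformer loadable via diffusers
    # should work); loading the quantized int8 state dict into the Q8 modules is a separate
    # step (see the referenced convert_weights.py) and is not shown here.
    from diffusers import LTXVideoTransformer3DModel  # assumed model class

    transformer = LTXVideoTransformer3DModel.from_pretrained(
        "Lightricks/LTX-Video", subfolder="transformer"  # assumed checkpoint id
    )
    transformer, _ = replace_linear(transformer)
    transformer, _ = replace_rms_norm(transformer)
    transformer, _ = replace_gelu(transformer)
    # After loading the quantized weights, verify the swap and install the Q8 attention processors:
    # check_transformer_replaced_correctly(transformer)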