import hashlib
import types
from argparse import Namespace

import torch
from diffusers import LTXPipeline
from inference import load_q8_transformer
from q8_kernels.graph.graph import make_dynamic_graphed_callable


# Patched block forward to account for the type-casting in `ff_output` of
# `LTXVideoTransformerBlock` (the feed-forward output is cast back to the
# normed hidden states' dtype before the gated residual add).
def patched_ltx_transformer_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    temb: torch.Tensor,
    image_rotary_emb=None,
    encoder_attention_mask=None,
) -> torch.Tensor:
    batch_size = hidden_states.size(0)
    norm_hidden_states = self.norm1(hidden_states)

    # AdaLN: derive per-token shift/scale/gate values from the timestep embedding.
    num_ada_params = self.scale_shift_table.shape[0]
    ada_values = self.scale_shift_table[None, None] + temb.reshape(
        batch_size, temb.size(1), num_ada_params, -1
    )
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
    norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

    # Self-attention with rotary embeddings, gated residual connection.
    attn_hidden_states = self.attn1(
        hidden_states=norm_hidden_states,
        encoder_hidden_states=None,
        image_rotary_emb=image_rotary_emb,
    )
    hidden_states = hidden_states + attn_hidden_states * gate_msa

    # Cross-attention over the text encoder hidden states.
    attn_hidden_states = self.attn2(
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        image_rotary_emb=None,
        attention_mask=encoder_attention_mask,
    )
    hidden_states = hidden_states + attn_hidden_states
    norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp

    # The cast below is the actual patch: keep `ff_output` in the same dtype as
    # the normed hidden states before applying the MLP gate.
    ff_output = self.ff(norm_hidden_states).to(norm_hidden_states.dtype)
    hidden_states = hidden_states + ff_output * gate_mlp
    return hidden_states


def load_transformer():
    args = Namespace()
    args.q8_transformer_path = "sayakpaul/q8-ltx-video"
    transformer = load_q8_transformer(args)
    transformer.to(torch.bfloat16)

    # Run the Q8 transformer blocks in float32, except the `scale_shift_table`
    # parameters, which stay in bfloat16.
    for b in transformer.transformer_blocks:
        b.to(dtype=torch.float)
    for n, m in transformer.transformer_blocks.named_parameters():
        if "scale_shift_table" in n:
            m.data = m.data.to(torch.bfloat16)

    # Swap in the patched forward on every block, then wrap the top-level
    # forward in q8_kernels' dynamically graphed callable.
    for b in transformer.transformer_blocks:
        b.forward = types.MethodType(patched_ltx_transformer_forward, b)
    transformer.forward = make_dynamic_graphed_callable(transformer.forward)
    return transformer


def warmup_transformer(pipe):
    # A few warmup passes at the serving resolution/frame count, using
    # precomputed prompt embeddings (the pipeline has no text encoder).
    prompt_embeds = torch.load("prompt_embeds.pt", map_location="cuda", weights_only=True)
    for _ in range(5):
        _ = pipe(**prompt_embeds, output_type="latent", width=768, height=512, num_frames=121)


def prepare_pipeline():
    pipe = LTXPipeline.from_pretrained(
        "Lightricks/LTX-Video", text_encoder=None, torch_dtype=torch.bfloat16
    )
    pipe.transformer = load_transformer()
    pipe = pipe.to("cuda")
    pipe.transformer.compile()
    pipe.set_progress_bar_config(disable=True)
    warmup_transformer(pipe)
    return pipe


def compute_hash(text: str) -> str:
    # SHA-256 the UTF-8 encoded text and return the hexadecimal digest.
    text_bytes = text.encode("utf-8")
    hash_object = hashlib.sha256()
    hash_object.update(text_bytes)
    return hash_object.hexdigest()
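

# --- Usage sketch (not part of the original script; names are illustrative) ---
# One plausible way these pieces fit together: build the pipeline once, then cache
# generated latents on disk keyed by `compute_hash(prompt)`, so a repeated prompt
# skips the diffusion loop entirely. `PROMPT` and `CACHE_DIR` are hypothetical.
if __name__ == "__main__":
    import os

    pipe = prepare_pipeline()

    PROMPT = "a dog chasing waves on a beach"  # hypothetical example prompt
    CACHE_DIR = "latent_cache"
    os.makedirs(CACHE_DIR, exist_ok=True)
    cache_path = os.path.join(CACHE_DIR, f"{compute_hash(PROMPT)}.pt")

    if os.path.exists(cache_path):
        latents = torch.load(cache_path, map_location="cuda", weights_only=True)
    else:
        # The pipeline was built with `text_encoder=None`, so precomputed prompt
        # embeddings are passed in, exactly as `warmup_transformer` does.
        prompt_embeds = torch.load("prompt_embeds.pt", map_location="cuda", weights_only=True)
        out = pipe(**prompt_embeds, output_type="latent", width=768, height=512, num_frames=121)
        latents = out.frames  # with output_type="latent", `frames` holds the latents
        torch.save(latents, cache_path)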