import torch
from typing import Union, Tuple, List, Optional
import numpy as np

###### Thanks to the RIFLEx project (https://github.com/thu-ml/RIFLEx/) for this alternative positional embedding for long videos
def get_1d_rotary_pos_embed_riflex(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real: bool = False,
    k: Optional[int] = None,
    L_test: Optional[int] = None,
):
    """
    RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor
    contains complex values in complex64 data type.

    Args:
        dim (`int`): Dimension of the frequency tensor.
        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation.
        use_real (`bool`, *optional*):
            If True, return the real part and imaginary part separately. Otherwise, return complex numbers.
        k (`int`, *optional*, defaults to None): the index of the intrinsic frequency in RoPE
        L_test (`int`, *optional*, defaults to None): the number of frames for inference

    Returns:
        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
    """
    assert dim % 2 == 0
    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # type: ignore  # [S]

    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device=pos.device)[: (dim // 2)].float() / dim)
    )  # [D/2]

    # === RIFLEx modification start ===
    # Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)).
    # Empirical observations show that a few videos may exhibit repetition in the tail frames.
    # To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period.
    if k is not None:
        freqs[k - 1] = 0.9 * 2 * torch.pi / L_test
    # === RIFLEx modification end ===

    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
        return freqs_cis
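
# A minimal usage sketch (not part of the original module): building the time-axis
# rotary embedding with the RIFLEx-adjusted intrinsic frequency. The dimension (16),
# k=4 and L_test=66 below are illustrative values only.
#
#     freqs_cos, freqs_sin = get_1d_rotary_pos_embed_riflex(
#         dim=16, pos=66, use_real=True, k=4, L_test=66
#     )
#     # freqs_cos.shape == freqs_sin.shape == (66, 16)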

def identify_k(b: float, d: int, N: int):
    """
    This function identifies the index of the intrinsic frequency component in a RoPE-based pre-trained diffusion transformer.

    Args:
        b (`float`): The base frequency for RoPE.
        d (`int`): Dimension of the frequency tensor.
        N (`int`): The first observed repetition frame in latent space.

    Returns:
        k (`int`): The index of the intrinsic frequency component.
        N_k (`int`): The period of the intrinsic frequency component in latent space.

    Example:
        In HunyuanVideo, b=256 and d=16, and the repetition occurs at approximately 8s (N=48 in latent space).
            k, N_k = identify_k(b=256, d=16, N=48)
        In this case, the intrinsic frequency index k is 4, and the period N_k is 50.
    """
    # Compute the period of each frequency in RoPE according to Eq. (4)
    periods = []
    for j in range(1, d // 2 + 1):
        theta_j = 1.0 / (b ** (2 * (j - 1) / d))
        N_j = round(2 * torch.pi / theta_j)
        periods.append(N_j)

    # Identify the intrinsic frequency whose period is closest to N (see Eq. (7))
    diffs = [abs(N_j - N) for N_j in periods]
    k = diffs.index(min(diffs)) + 1
    N_k = periods[k - 1]
    return k, N_k
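
# A quick sanity-check sketch mirroring the HunyuanVideo numbers quoted in the
# docstring above (b=256, d=16, first repetition at N=48 latent frames):
#
#     k, N_k = identify_k(b=256, d=16, N=48)
#     # k == 4, N_k == 50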

def _to_tuple(x, dim=2):
    # Broadcast an int to a `dim`-tuple; pass through a sequence that already has length `dim`.
    if isinstance(x, int):
        return (x,) * dim
    elif len(x) == dim:
        return x
    else:
        raise ValueError(f"Expected length {dim} or int, but got {x}")

def get_meshgrid_nd(start, *args, dim=2):
    """
    Get n-D meshgrid with start, stop and num.

    Args:
        start (int or tuple): If len(args) == 0, start is num; if len(args) == 1, start is start, args[0] is stop,
            step is 1; if len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
            should be int or n-tuple. If an n-tuple is provided, the meshgrid will be stacked following the dim order
            of the n-tuples.
        *args: See above.
        dim (int): Dimension of the meshgrid. Defaults to 2.

    Returns:
        grid (torch.Tensor): [dim, ...]
    """
    if len(args) == 0:
        # start is grid_size
        num = _to_tuple(start, dim=dim)
        start = (0,) * dim
        stop = num
    elif len(args) == 1:
        # start is start, args[0] is stop, step is 1
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = [stop[i] - start[i] for i in range(dim)]
    elif len(args) == 2:
        # start is start, args[0] is stop, args[1] is num
        start = _to_tuple(start, dim=dim)   # Left-Top     e.g. 12,0
        stop = _to_tuple(args[0], dim=dim)  # Right-Bottom e.g. 20,32
        num = _to_tuple(args[1], dim=dim)   # Target Size  e.g. 32,124
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # PyTorch implementation of np.linspace(start[i], stop[i], num[i], endpoint=False)
    axis_grid = []
    for i in range(dim):
        a, b, n = start[i], stop[i], num[i]
        g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
        axis_grid.append(g)
    grid = torch.meshgrid(*axis_grid, indexing="ij")  # dim x [W, H, D]
    grid = torch.stack(grid, dim=0)  # [dim, W, H, D]

    return grid
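
# A minimal usage sketch: a 3-axis (time, height, width) grid of token positions.
# The sizes (8 latent frames, 30x52 latent tokens) are illustrative only.
#
#     grid = get_meshgrid_nd((8, 30, 52), dim=3)
#     # grid.shape == (3, 8, 30, 52); grid[0] holds the time index of every token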

#################################################################################
#                   Rotary Positional Embedding Functions                      #
#################################################################################
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
def reshape_for_broadcast(
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    x: torch.Tensor,
    head_first=False,
):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Notes:
        When using FlashMHAModified, head_first should be False.
        When using Attention, head_first should be True.

    Args:
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.
        head_first (bool): Whether the head dimension comes first (after the batch dim) or not.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim

    if isinstance(freqs_cis, tuple):
        # freqs_cis: (cos, sin) in real space
        if head_first:
            assert freqs_cis[0].shape == (
                x.shape[-2],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
            shape = [
                d if i == ndim - 2 or i == ndim - 1 else 1
                for i, d in enumerate(x.shape)
            ]
        else:
            assert freqs_cis[0].shape == (
                x.shape[1],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
    else:
        # freqs_cis: values in complex space
        if head_first:
            assert freqs_cis.shape == (
                x.shape[-2],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
            shape = [
                d if i == ndim - 2 or i == ndim - 1 else 1
                for i, d in enumerate(x.shape)
            ]
        else:
            assert freqs_cis.shape == (
                x.shape[1],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)

def rotate_half(x):
    # Rotate channel pairs: (x0, x1) -> (-x1, x0), used by the real-valued (cos, sin) RoPE path.
    x_real, x_imag = (
        x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    )  # [B, S, H, D//2]
    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)

def apply_rotary_emb(
    qklist,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the query and key tensors passed in 'qklist' using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors. 'qklist' is cleared in place so the original references can be freed early.

    Args:
        qklist (list of torch.Tensor): [xq, xk], the query and key tensors to apply rotary embeddings to. [B, S, H, D]
        freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
        head_first (bool): Whether the head dimension comes first (after the batch dim) or not.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    xq, xk = qklist
    qklist.clear()
    xk_out = None
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        # real * cos - imag * sin
        # imag * cos + real * sin
        xq_dtype = xq.dtype
        xq_out = xq.to(torch.float)
        xq = None
        xq_rot = rotate_half(xq_out)
        xq_out *= cos
        xq_rot *= sin
        xq_out += xq_rot
        del xq_rot
        xq_out = xq_out.to(xq_dtype)

        xk_out = xk.to(torch.float)
        xk = None
        xk_rot = rotate_half(xk_out)
        xk_out *= cos
        xk_rot *= sin
        xk_out += xk_rot
        del xk_rot
        xk_out = xk_out.to(xq_dtype)
    else:
        # view_as_complex will pack [..., D/2, 2] (real) to [..., D/2] (complex)
        xq_ = torch.view_as_complex(
            xq.float().reshape(*xq.shape[:-1], -1, 2)
        )  # [B, S, H, D//2]
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
            xq.device
        )  # [S, D//2] --> [1, S, 1, D//2]
        # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
        # view_as_real will expand [..., D/2] (complex) to [..., D/2, 2] (real)
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        xk_ = torch.view_as_complex(
            xk.float().reshape(*xk.shape[:-1], -1, 2)
        )  # [B, S, H, D//2]
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

    return xq_out, xk_out
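
# A minimal usage sketch: applying a precomputed (cos, sin) pair to query/key tensors
# laid out as [B, S, H, D] with head_first=False. All shapes are illustrative only.
#
#     B, S, H, D = 1, 66, 24, 128
#     q = torch.randn(B, S, H, D)
#     k = torch.randn(B, S, H, D)
#     cos, sin = get_1d_rotary_pos_embed(D, S, use_real=True)  # [S, D] each
#     q, k = apply_rotary_emb([q, k], (cos, sin), head_first=False)
#     # q.shape == k.shape == (B, S, H, D); note the input list is cleared by the call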

def get_nd_rotary_pos_embed_new(
    rope_dim_list,
    start,
    *args,
    theta=10000.0,
    use_real=False,
    theta_rescale_factor: Union[float, List[float]] = 1.0,
    interpolation_factor: Union[float, List[float]] = 1.0,
    concat_dict: Optional[dict] = None,
    k=4,
    L_test=66,
    enable_riflex=True,
):
    # Variant of get_nd_rotary_pos_embed that can prepend an extra time slice of positions
    # (controlled by concat_dict) and applies RIFLEx to the time axis when enable_riflex is True.
    grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))  # [3, W, H, D] / [2, W, H]

    if concat_dict:
        if concat_dict['mode'] == 'timecat':
            bias = grid[:, :1].clone()
            bias[0] = concat_dict['bias'] * torch.ones_like(bias[0])
            grid = torch.cat([bias, grid], dim=1)
        elif concat_dict['mode'] == 'timecat-w':
            bias = grid[:, :1].clone()
            bias[0] = concat_dict['bias'] * torch.ones_like(bias[0])
            bias[2] += start[-1]  # ref: https://github.com/Yuanshi9815/OminiControl/blob/main/src/generate.py#L178
            grid = torch.cat([bias, grid], dim=1)

    if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
        theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
    assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal len(rope_dim_list)"

    if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
        interpolation_factor = [interpolation_factor] * len(rope_dim_list)
    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
    assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal len(rope_dim_list)"

    # use 1/ndim of dimensions to encode grid_axis
    embs = []
    for i in range(len(rope_dim_list)):
        # === RIFLEx modification start ===
        # apply RIFLEx to the time dimension
        if i == 0 and enable_riflex:
            emb = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True, k=k, L_test=L_test)
        # === RIFLEx modification end ===
        else:
            emb = get_1d_rotary_pos_embed(
                rope_dim_list[i],
                grid[i].reshape(-1),
                theta,
                use_real=True,
                theta_rescale_factor=theta_rescale_factor[i],
                interpolation_factor=interpolation_factor[i],
            )  # 2 x [WHD, rope_dim_list[i]]
        embs.append(emb)

    if use_real:
        cos = torch.cat([emb[0] for emb in embs], dim=1)  # (WHD, D/2)
        sin = torch.cat([emb[1] for emb in embs], dim=1)  # (WHD, D/2)
        return cos, sin
    else:
        emb = torch.cat(embs, dim=1)  # (WHD, D/2)
        return emb
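
# A minimal usage sketch for the concat_dict variant: prepend one extra time slice of
# positions (e.g. for a reference-image token grid) ahead of the video grid. The head-dim
# split [16, 56, 56], the grid size and the bias value are illustrative assumptions only.
#
#     cos, sin = get_nd_rotary_pos_embed_new(
#         [16, 56, 56], (8, 30, 52),
#         use_real=True,
#         concat_dict={'mode': 'timecat', 'bias': 0},
#         k=4, L_test=8,
#     )
#     # cos/sin have one row per token: (1 + 8) * 30 * 52 rows, 16 + 56 + 56 = 128 columns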

def get_nd_rotary_pos_embed(
    rope_dim_list,
    start,
    *args,
    theta=10000.0,
    use_real=False,
    theta_rescale_factor: Union[float, List[float]] = 1.0,
    interpolation_factor: Union[float, List[float]] = 1.0,
    k=4,
    L_test=66,
    enable_riflex=True,
):
    """
    This is an n-d version of precompute_freqs_cis, i.e. a RoPE for tokens with an n-d structure.

    Args:
        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal n.
            sum(rope_dim_list) should equal the head_dim of the attention layer.
        start (int | tuple of int | list of int): If len(args) == 0, start is num; if len(args) == 1, start is start,
            args[0] is stop, step is 1; if len(args) == 2, start is start, args[0] is stop, args[1] is num.
        *args: See above.
        theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool): If True, return the real part and imaginary part separately. Otherwise, return complex numbers.
            Some libraries such as TensorRT do not support the complex64 data type, so it is useful to provide the real
            and imaginary parts separately.
        theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.

    Returns:
        pos_embed (torch.Tensor): [HW, D/2]
    """
    grid = get_meshgrid_nd(
        start, *args, dim=len(rope_dim_list)
    )  # [3, W, H, D] / [2, W, H]

    if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
        theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
    assert len(theta_rescale_factor) == len(
        rope_dim_list
    ), "len(theta_rescale_factor) should equal len(rope_dim_list)"

    if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
        interpolation_factor = [interpolation_factor] * len(rope_dim_list)
    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
    assert len(interpolation_factor) == len(
        rope_dim_list
    ), "len(interpolation_factor) should equal len(rope_dim_list)"

    # use 1/ndim of dimensions to encode grid_axis
    embs = []
    for i in range(len(rope_dim_list)):
        # emb = get_1d_rotary_pos_embed(
        #     rope_dim_list[i],
        #     grid[i].reshape(-1),
        #     theta,
        #     use_real=use_real,
        #     theta_rescale_factor=theta_rescale_factor[i],
        #     interpolation_factor=interpolation_factor[i],
        # )  # 2 x [WHD, rope_dim_list[i]]

        # === RIFLEx modification start ===
        # apply RIFLEx to the time dimension
        if i == 0 and enable_riflex:
            emb = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True, k=k, L_test=L_test)
        # === RIFLEx modification end ===
        else:
            emb = get_1d_rotary_pos_embed(
                rope_dim_list[i],
                grid[i].reshape(-1),
                theta,
                use_real=True,
                theta_rescale_factor=theta_rescale_factor[i],
                interpolation_factor=interpolation_factor[i],
            )
        embs.append(emb)

    if use_real:
        cos = torch.cat([emb[0] for emb in embs], dim=1)  # (WHD, D/2)
        sin = torch.cat([emb[1] for emb in embs], dim=1)  # (WHD, D/2)
        return cos, sin
    else:
        emb = torch.cat(embs, dim=1)  # (WHD, D/2)
        return emb
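
# A minimal usage sketch: 3-D RoPE over a (time, height, width) token grid with the
# RIFLEx-adjusted time axis enabled. The head-dim split [16, 56, 56] (sum = 128) and
# the grid size are illustrative values, not taken from any specific model config.
#
#     cos, sin = get_nd_rotary_pos_embed(
#         [16, 56, 56], (66, 30, 52),
#         use_real=True, k=4, L_test=66,
#     )
#     # cos.shape == sin.shape == (66 * 30 * 52, 128)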

def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.FloatTensor, int],
    theta: float = 10000.0,
    use_real: bool = False,
    theta_rescale_factor: float = 1.0,
    interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
    (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the position indices 'pos'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool, optional): If True, return the real part and imaginary part separately.
            Otherwise, return complex numbers.
        theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.

    Returns:
        freqs_cis: Precomputed frequency tensor with complex exponentials. [S, D/2]
        freqs_cos, freqs_sin: Precomputed frequency tensors with the real and imaginary parts separately. [S, D]
    """
    if isinstance(pos, int):
        pos = torch.arange(pos).float()

    # Proposed by reddit user bloc97, to rescale rotary embeddings to a longer sequence length
    # without fine-tuning; has some connection to the NTK literature.
    if theta_rescale_factor != 1.0:
        theta *= theta_rescale_factor ** (dim / (dim - 2))

    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
    )  # [D/2]
    # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
    freqs = torch.outer(pos * interpolation_factor, freqs)  # [S, D/2]
    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
        return freqs_cos, freqs_sin
    else:
        freqs_cis = torch.polar(
            torch.ones_like(freqs), freqs
        )  # complex64  # [S, D/2]
        return freqs_cis
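
# A minimal usage sketch of the two return modes; values are illustrative only.
#
#     freqs_cis = get_1d_rotary_pos_embed(64, 100)                # complex64, shape [100, 32]
#     cos, sin = get_1d_rotary_pos_embed(64, 100, use_real=True)  # float, shape [100, 64] each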