| 
							 | 
						""" | 
					
					
						
						| 
							 | 
						    Author: Eric Lin (xihlin) | 
					
					
						
						| 
							 | 
						""" | 
					
					
						
						| 
							 | 
						""" | 
					
					
						
						| 
							 | 
						    ... note(bapatra):: | 
					
					
						
						| 
							 | 
						        This is written as one big file, instead of splitting into logical components because I was running into issues with transformers auto module | 
					
					
						
						| 
							 | 
						        imports when splitting into different files. I've tried keeping the logical partitions demarcated with comment blocks, but it is not ideal. | 
					
					
						
						| 
							 | 
						        In the future, would be really good to revisit this and refactor into a more readable file structure. | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						""" | 
					
					
						
						| 
							 | 
						from typing import TypeVar | 
					
					
						
						| 
							 | 
						from functools import lru_cache | 
					
					
						
						| 
							 | 
						import math | 
					
					
						
						| 
							 | 
						import pytest | 
					
					
						
						| 
							 | 
						import torch | 
					
					
						
						| 
							 | 
						import numpy as np | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import triton | 
					
					
						
						| 
							 | 
						import triton.language as tl | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import os | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import dataclasses | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						# Stand-in name used as the annotation for `config` arguments in this file
						# (e.g. `BlockSparseParams.from_config`). Presumably it avoids importing the
						# real config class into this module -- TODO confirm.
						Phi3SmallConfig = TypeVar('Phi3SmallConfig') | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						@dataclasses.dataclass
						class BlockSparseParams:
						    """Block-sparse attention settings bundled into a single record.

						    The fields are stored exactly as given. ``from_config`` fills them from the
						    ``blocksparse_*`` attributes of a config object.
						    """

						    block_size: int
						    kernel_block_size: int
						    num_local_blocks: int
						    vert_stride: int
						    homo_head_pattern: bool = False

						    @classmethod
						    def from_config(cls, config: Phi3SmallConfig) -> "BlockSparseParams":
						        """Create the parameters from the ``blocksparse_*`` attributes of ``config``."""
						        field_values = {
						            "block_size": config.blocksparse_block_size,
						            "kernel_block_size": config.blocksparse_triton_kernel_block_size,
						            "num_local_blocks": config.blocksparse_num_local_blocks,
						            "vert_stride": config.blocksparse_vert_stride,
						            "homo_head_pattern": config.blocksparse_homo_head_pattern,
						        }
						        return cls(**field_values)
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						def dense_to_crow_col(x):
						    '''Convert a dense 2D/3D torch tensor into CSR (crow_indices, col_indices).

						    A 3D input is treated as a batch: every slice along the first dimension is
						    converted to CSR on its own. The per-slice ``col_indices`` can differ in
						    length, so shorter ones are right-padded with -1 before being stacked.

						    :param x: dense tensor with 2 or 3 dimensions.
						    :return: ``(crows, cols)``. Both are 1D for a 2D input; for a 3D input both
						        are 2D with one row per batch entry.

						    TODO:
						        1. improve efficiency, is it faster if done in CPU, or customize a cuda kernel for it?
						    NOTE: col_indices padded -1
						    '''
						    pad_value = -1
						    n_dims = x.dim()
						    assert n_dims in (2, 3)
						    batch = x[None] if n_dims == 2 else x
						    csr_mats = [mat.to_sparse_csr() for mat in batch]
						    crows = torch.vstack([mat.crow_indices() for mat in csr_mats])
						    col_list = [mat.col_indices() for mat in csr_mats]
						    widest = max(len(c) for c in col_list)
						    padded = []
						    for c in col_list:
						        # Same dtype/device as `c`, filled with the pad marker.
						        filler = c.new_full((widest - c.shape[0],), pad_value)
						        padded.append(torch.cat([c, filler]))
						    cols = torch.vstack(padded)
						    if n_dims == 2:
						        # Undo the batch dimension that was added for the 2D case.
						        return crows[0], cols[0]
						    return crows, cols
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def crow_col_to_dense(crows, cols, dtype=torch.float16):
						    '''Expand CSR (crow_indices, col_indices) into a dense 0/1 tensor.

						    Accepts either a single layout (1D ``crows``/``cols``) or a batch of layouts
						    (2D, one row per entry). The number of dense columns is taken as
						    ``cols.max() + 1``. Only ``cols[crows[r]:crows[r + 1]]`` is read for each row,
						    so padding entries beyond ``crows[-1]`` are never touched.

						    :param crows: compressed row pointers.
						    :param cols: column indices.
						    :param dtype: dtype of the returned dense tensor.
						    :return: dense tensor with ones at the stored positions, placed on the
						        device ``crows`` came from.
						    '''
						    was_unbatched = crows.dim() == 1
						    if was_unbatched:
						        crows, cols = crows[None], cols[None]
						    target_device = crows.device
						    # The fill loop below runs on CPU copies; the result is moved back at the end.
						    crows_cpu = crows.cpu()
						    cols_cpu = cols.cpu()
						    n_batch = crows_cpu.shape[0]
						    n_rows = crows_cpu.shape[1] - 1
						    n_cols = cols_cpu.max() + 1
						    dense = torch.zeros((n_batch, n_rows, n_cols), dtype=dtype)
						    for b in range(n_batch):
						        for r in range(n_rows):
						            lo, hi = crows_cpu[b, r], crows_cpu[b, r + 1]
						            dense[b, r, cols_cpu[b, lo:hi]] = 1
						    if was_unbatched:
						        dense = dense[0]
						    return dense.to(target_device)
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def dense_to_ccol_row(x):
						    '''Convert a dense tensor to CSC-style (ccol_indices, row_indices).

						    Done by swapping the last two dimensions of ``x`` and reusing the CSR
						    conversion, so the returned pair follows the same layout and padding as
						    ``dense_to_crow_col``.
						    '''
						    transposed = x.transpose(-2, -1)
						    return dense_to_crow_col(transposed)
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def ccol_row_to_dense(ccol, rows, dtype=torch.float16):
						    '''Expand CSC-style (ccol_indices, row_indices) into a dense 0/1 tensor.

						    The indices are first expanded as if they were CSR, which yields the
						    transposed matrix; the last two dimensions are then swapped back.

						    :param ccol: compressed column pointers (1D, or 2D for a batch).
						    :param rows: row indices (same batching as ``ccol``).
						    :param dtype: dtype of the returned dense tensor.
						    :return: contiguous dense tensor; 2D for 1D inputs, 3D for batched inputs.
						    '''
						    # transpose(-2, -1) equals permute(0, 2, 1) for batched (3D) results and also
						    # works for the 2D result produced from unbatched (1D) indices, where
						    # permute(0, 2, 1) would raise because of the dimension count mismatch.
						    return crow_col_to_dense(ccol, rows, dtype).transpose(-2, -1).contiguous()
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def _get_sparse_attn_mask_homo_head(q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, return_dense=False):
						    '''
						    Build one block-sparse causal attention layout (no per-head variation).

						    Block (i, j) of the ``N_BLOCK x N_BLOCK`` grid (``N_BLOCK = cdiv(N_CTX, BLOCK)``)
						    is kept when ``j <= i`` and either ``i - j < local_blocks`` or
						    ``(j + 1) % vert_stride == 0``.

						    :param q_len: number of query tokens; they are the last ``q_len`` positions of the context.
						    :param N_CTX: total context length in tokens.
						    :param dtype: dtype of the returned dense masks.
						    :param device: device the masks are moved to.
						    :param BLOCK: block edge length in tokens.
						    :param local_blocks: number of most recent key blocks that are always kept.
						    :param vert_stride: every ``vert_stride``-th key block is kept as well.
						    :param return_dense: also build the token-level dense mask.
						    :return: a tuple of 3:
						        - tuple of crow_indices, col_indices representation of CSR format
						          (covering only the last ``cdiv(q_len, BLOCK)`` block rows).
						        - block dense mask
						        - all token dense mask (be aware that it can be OOM if it is too big) if `return_dense==True`, otherwise, None
						    '''
						    with torch.no_grad():
						        N_BLOCK = triton.cdiv(N_CTX, BLOCK)
						        q_pos = torch.arange(N_BLOCK)[:, None]
						        k_pos = torch.arange(N_BLOCK)[None]
						        mask_vert_strided = (torch.arange(N_BLOCK) + 1) % vert_stride == 0
						        block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).to(device).to(dtype)
						        N_BLOCK_Q = triton.cdiv(q_len, BLOCK)
						        block_mask_dense_output = block_mask_dense[-N_BLOCK_Q:].contiguous().to_sparse_csr()
						    if return_dense:
						        # Blow every block up to BLOCK x BLOCK tokens; the result is padded to
						        # N_BLOCK * BLOCK rows/cols, which may exceed N_CTX.
						        mask_dense = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK)))
						        # Query tokens are rows [N_CTX - q_len, N_CTX). Slicing from the padded end
						        # (``[-q_len:]``) would select the wrong rows whenever N_CTX is not a
						        # multiple of BLOCK, so slice by absolute position instead.
						        q_start = N_CTX - q_len
						        causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(mask_dense)[q_start:]
						        mask_dense = mask_dense[q_start:N_CTX, :N_CTX] * causal_mask
						        return (block_mask_dense_output.crow_indices(), block_mask_dense_output.col_indices()), block_mask_dense, mask_dense
						    else:
						        return (block_mask_dense_output.crow_indices(), block_mask_dense_output.col_indices()), block_mask_dense, None
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def _get_sparse_attn_mask(n_heads, q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, homo_head=True, return_dense=False):
						    '''
						    Build the block-sparse causal attention layout for ``n_heads`` heads.

						    With ``homo_head=True`` the single-head layout is computed once and expanded
						    (as a view) across heads. Otherwise each head ``h`` shifts the vertical-stride
						    pattern: key block ``j`` is a strided block when
						    ``(j + h * head_sliding_step + 1) % vert_stride == 0`` with
						    ``head_sliding_step = max(1, int(vert_stride / n_heads))``.

						    :param n_heads: number of attention heads.
						    :param q_len: number of query tokens; they are the last ``q_len`` positions of the context.
						    :param N_CTX: total context length in tokens.
						    :param dtype: dtype of the returned dense masks.
						    :param device: device the masks are moved to.
						    :param BLOCK: block edge length in tokens.
						    :param local_blocks: number of most recent key blocks that are always kept.
						    :param vert_stride: stride of the additionally kept key blocks.
						    :param homo_head: use the same pattern for every head.
						    :param return_dense: also build the token-level dense mask.
						    :return: a tuple of 3:
						        - tuple of crow_indices, col_indices representation of CSR format.
						        - block dense mask
						        - all token dense mask (be aware that it can be OOM if it is too big) if `return_dense==True`, otherwise, None
						    '''
						    if homo_head:
						        with torch.no_grad():
						            (crow, col), block_mask_dense, mask_dense = _get_sparse_attn_mask_homo_head(q_len, N_CTX, dtype, device, BLOCK, local_blocks, vert_stride, return_dense)
						            # Broadcast the shared layout over heads without copying.
						            crow = crow[None].expand(n_heads, crow.shape[0])
						            col = col[None].expand(n_heads, col.shape[0])
						            if return_dense:
						                mask_dense = mask_dense[None].expand(n_heads, *mask_dense.shape)
						            return (crow, col), block_mask_dense, mask_dense

						    with torch.no_grad():
						        N_BLOCK = triton.cdiv(N_CTX, BLOCK)
						        q_pos = torch.arange(N_BLOCK)[None, :, None]
						        k_pos = torch.arange(N_BLOCK)[None, None]
						        head_sliding_step = max(1, int(vert_stride / n_heads))
						        mask_vert_strided = [(torch.arange(N_BLOCK) + h * head_sliding_step + 1) % vert_stride == 0 for h in range(n_heads)]
						        mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
						        block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).to(device).to(dtype)
						        N_BLOCK_Q = triton.cdiv(q_len, BLOCK)
						        block_mask_dense_output = block_mask_dense[:, -N_BLOCK_Q:]
						    if return_dense:
						        # The kron result is padded to N_BLOCK * BLOCK rows/cols, which may exceed N_CTX.
						        mask_dense = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK)))
						        # Query tokens are rows [N_CTX - q_len, N_CTX). Slicing from the padded end
						        # (``[-q_len:]``) would select the wrong rows whenever N_CTX is not a
						        # multiple of BLOCK, so slice by absolute position instead.
						        q_start = N_CTX - q_len
						        causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(mask_dense)[q_start:]
						        mask_dense = mask_dense[..., q_start:N_CTX, :N_CTX] * causal_mask[None]
						        return dense_to_crow_col(block_mask_dense_output), block_mask_dense, mask_dense
						    else:
						        return dense_to_crow_col(block_mask_dense_output), block_mask_dense, None
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def get_sparse_attn_mask(q, N_CTX, *args, **kwargs):
						    '''Build the sparse attention layout using head count, length, dtype and device of ``q``.

						    ``q.size(1)`` is forwarded as ``n_heads`` and ``q.size(2)`` as ``q_len``; all
						    remaining positional and keyword arguments go to ``_get_sparse_attn_mask``
						    unchanged, and its result is returned as is.
						    '''
						    n_heads = q.size(1)
						    q_len = q.size(2)
						    return _get_sparse_attn_mask(n_heads, q_len, N_CTX, q.dtype, q.device, *args, **kwargs)
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						@triton.jit
						def _fwd_kernel(
						    Q, K, V, sm_scale,
						    layout_crow_ptr,
						    layout_col_ptr,
						    layout_crow_stride_h, layout_crow_stride_m,
						    layout_col_stride_h, layout_col_stride_m,
						    TMP, L, M,
						    Out,
						    stride_qz, stride_qh, stride_qm, stride_qd,
						    stride_kz, stride_kh, stride_kn, stride_kd,
						    stride_vz, stride_vh, stride_vn, stride_vd,
						    stride_oz, stride_oh, stride_om, stride_od,
						    Z, H, N_CTX,
						    PAST_LEN,
						    Q_ROUNDED_LEN,
						    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
						    BLOCK_N: tl.constexpr,
						    EVEN_M_BLOCK: tl.constexpr,
						    EVEN_N_BLOCK: tl.constexpr,
						    INFERENCE: tl.constexpr,
						    NUM_DBLOCKS: tl.constexpr,
						):
						    # Block-sparse causal attention forward pass.
						    #
						    # Program (start_m, off_hz) handles one BLOCK_M-row tile of queries for one
						    # (z, h) pair, where off_hz = z * H + h. The key/value blocks to visit are
						    # read from a CSR layout: row pointers at layout_crow_ptr, column (block)
						    # indices at layout_col_ptr.
						    #
						    # N_CTX is the key/value length and Q_LEN = N_CTX - PAST_LEN the query
						    # length; query row i is compared against key positions as i + PAST_LEN.
						    # When NUM_DBLOCKS >= 2 the feature dimension is processed as two chunks of
						    # BLOCK_DMODEL (q/q2, acc/acc2).
						    # TMP and Q_ROUNDED_LEN are not read by this kernel; they remain in the
						    # signature so existing launches keep working.
						    Q_LEN = N_CTX - PAST_LEN
						    start_m = tl.program_id(0)
						    off_hz = tl.program_id(1)
						    off_h = off_hz % H
						    off_z = off_hz // H
						    # Move the base pointers to this (z, h) slice.
						    Q += off_z * stride_qz + off_h * stride_qh
						    K += off_z * stride_kz + off_h * stride_kh
						    V += off_z * stride_vz + off_h * stride_vh

						    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
						    offs_n = tl.arange(0, BLOCK_N)
						    offs_d = tl.arange(0, BLOCK_DMODEL)
						    off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd
						    # K is addressed as a [BLOCK_DMODEL, BLOCK_N] tile so tl.dot(q, k) needs no transpose.
						    off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd
						    off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd

						    q_ptrs = Q + off_q
						    k_ptrs = K + off_k
						    v_ptrs = V + off_v

						    # Running row maximum (m_i), running normaliser (l_i) and output accumulator(s).
						    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
						    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
						    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
						    if NUM_DBLOCKS >= 2:
						        acc2 = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

						    # Masked loads pass other=0.0: without it the masked-out lanes are undefined
						    # and could carry NaN/Inf into the dot products below.
						    if EVEN_M_BLOCK:
						        q = tl.load(q_ptrs)
						        if NUM_DBLOCKS >= 2:
						            q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd)
						    else:
						        q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN, other=0.0)
						        if NUM_DBLOCKS >= 2:
						            q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=offs_m[:, None] < Q_LEN, other=0.0)

						    # CSR row pointers for this head and query block: [start_l, end_l) indexes
						    # into the column-index array.
						    layout_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + start_m * layout_crow_stride_m
						    start_l = tl.load(layout_ptr).to(tl.int32)
						    end_l = tl.load(layout_ptr + layout_crow_stride_m).to(tl.int32)

						    for col_idx_idx in range(start_l, end_l):
						        col_idx = tl.load(layout_col_ptr + off_h * layout_col_stride_h + col_idx_idx * layout_col_stride_m).to(tl.int32)
						        start_n = col_idx * BLOCK_N

						        # -- scores: qk = (q @ k) * sm_scale, summed over the feature chunks --
						        if EVEN_N_BLOCK:
						            k = tl.load(k_ptrs + start_n * stride_kn)
						        else:
						            k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_n[None, :] + start_n < N_CTX, other=0.0)
						        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
						        qk += tl.dot(q, k)

						        if NUM_DBLOCKS >= 2:
						            if EVEN_N_BLOCK:
						                k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd)
						            else:
						                k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd, mask=offs_n[None, :] + start_n < N_CTX, other=0.0)
						            qk += tl.dot(q2, k)

						        qk *= sm_scale
						        # Causal mask: a query at absolute position offs_m + PAST_LEN may only
						        # see keys at positions <= its own.
						        qk += tl.where(offs_m[:, None] + PAST_LEN >= (start_n + offs_n[None, :]), 0, float('-inf'))

						        # -- softmax statistics of this block --
						        m_ij = tl.max(qk, 1)
						        p = tl.exp(qk - m_ij[:, None])
						        l_ij = tl.sum(p, 1)

						        # -- merge with the running statistics --
						        m_i_new = tl.maximum(m_i, m_ij)
						        alpha = tl.exp(m_i - m_i_new)
						        beta = tl.exp(m_ij - m_i_new)
						        l_i_new = alpha * l_i + beta * l_ij

						        # Scale p so that it is already divided by the new normaliser.
						        p_scale = beta / l_i_new
						        p = p * p_scale[:, None]

						        # Rescale what was accumulated so far to the new max/normaliser.
						        acc_scale = l_i / l_i_new * alpha

						        acc = acc * acc_scale[:, None]
						        if NUM_DBLOCKS >= 2:
						            acc2 = acc2 * acc_scale[:, None]
						        p = p.to(Q.dtype.element_ty)

						        # -- accumulate p @ v --
						        if EVEN_N_BLOCK:
						            v = tl.load(v_ptrs + start_n * stride_vn)
						        else:
						            v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_n[:, None] + start_n < N_CTX, other=0.0)
						        acc += tl.dot(p, v)

						        if NUM_DBLOCKS >= 2:
						            if EVEN_N_BLOCK:
						                v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd)
						            else:
						                v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd, mask=offs_n[:, None] + start_n < N_CTX, other=0.0)
						            acc2 += tl.dot(p, v)

						        l_i = l_i_new
						        m_i = m_i_new

						    # Outside inference, write the per-row normaliser and maximum to L and M.
						    if not INFERENCE:
						        l_ptrs = L + off_hz * N_CTX + offs_m
						        m_ptrs = M + off_hz * N_CTX + offs_m
						        if EVEN_M_BLOCK:
						            tl.store(l_ptrs, l_i)
						            tl.store(m_ptrs, m_i)
						        else:
						            tl.store(l_ptrs, l_i, mask=offs_m < Q_LEN)
						            tl.store(m_ptrs, m_i, mask=offs_m < Q_LEN)

						    # Write the output tile; rows at or beyond Q_LEN are never stored.
						    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
						    out_ptrs = Out + off_o
						    tl.store(out_ptrs, acc, mask=offs_m[:, None] < Q_LEN)
						    if NUM_DBLOCKS >= 2:
						        tl.store(out_ptrs + BLOCK_DMODEL * stride_od, acc2, mask=offs_m[:, None] < Q_LEN)
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						@triton.heuristics(
						    {
						        'EVEN_M_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_M'] == 0,
						    }
						)
						@triton.jit
						def _bwd_preprocess(
						    Out, DO, L,
						    NewDO, Delta,
						    N_CTX,
						    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
						    EVEN_M_BLOCK: tl.constexpr,
						):
						    # Backward-pass preparation for one tile of BLOCK_M rows:
						    #   NewDO = DO / L           (row-wise division)
						    #   Delta = sum(Out * NewDO) (over the D_HEAD features of each row)
						    # Out, DO and NewDO are addressed as row * D_HEAD + feature, i.e. as
						    # contiguous [rows, D_HEAD] buffers.
						    # NOTE(review): the L load and the Delta store are unmasked even when
						    # EVEN_M_BLOCK is false -- confirm the buffers cover the rounded-up row count.
						    rows = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
						    feats = tl.arange(0, D_HEAD)
						    tile_offs = rows[:, None] * D_HEAD + feats[None, :]

						    # Load in float32; rows at or beyond N_CTX are masked when the tile is ragged.
						    if EVEN_M_BLOCK:
						        o = tl.load(Out + tile_offs).to(tl.float32)
						        do = tl.load(DO + tile_offs).to(tl.float32)
						    else:
						        o = tl.load(Out + tile_offs, mask=rows[:, None] < N_CTX).to(tl.float32)
						        do = tl.load(DO + tile_offs, mask=rows[:, None] < N_CTX).to(tl.float32)
						    denom = tl.load(L + rows).to(tl.float32)

						    do = do / denom[:, None]
						    delta = tl.sum(o * do, axis=1)

						    if EVEN_M_BLOCK:
						        tl.store(NewDO + tile_offs, do)
						    else:
						        tl.store(NewDO + tile_offs, do, mask=rows[:, None] < N_CTX)
						    tl.store(Delta + rows, delta)
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						@triton.heuristics(
						    {
						        'EVEN_M_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_M'] == 0,
						        'EVEN_N_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_N'] == 0,
						    }
						)
						@triton.jit
						def _bwd_kernel(
						    Q, K, V, sm_scale,
						    layout_ccol_ptr,
						    layout_row_ptr,
						    layout_ccol_stride_h, layout_ccol_stride_m,
						    layout_row_stride_h, layout_row_stride_m,
						    Out, DO,
						    DQ, DK, DV,
						    L, M,
						    D,
						    stride_qz, stride_qh, stride_qm, stride_qd,
						    stride_kz, stride_kh, stride_kn, stride_kd,
						    stride_vz, stride_vh, stride_vn, stride_vd,
						    stride_oz, stride_oh, stride_om, stride_od,
						    Z, H, N_CTX,
						    num_block,
						    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
						    BLOCK_N: tl.constexpr,
						    EVEN_M_BLOCK: tl.constexpr,
						    EVEN_N_BLOCK: tl.constexpr,
						    NUM_DBLOCKS: tl.constexpr,
						):
						    """Block-sparse attention backward pass for a single key/value block.

						    Program axis 0 selects the key/value block ``start_n``; program axis 1
						    selects the flattened (batch, head) index.  The program walks every query
						    block listed for its column in the column-compressed layout
						    (``layout_ccol_ptr`` holds the [start, end) range into ``layout_row_ptr``),
						    accumulates ``dk``/``dv`` in float32 and stores them once at the end, and
						    adds each query block's ``dq`` contribution with ``tl.atomic_add``.

						    Notes:
						      * ``DQ`` is only ever updated through atomic adds, never overwritten,
						        so the caller has to pre-fill it (zeros for a plain gradient).
						      * ``DO``, ``DQ``, ``DK`` and ``DV`` are all addressed with the ``Out``
						        strides (``stride_o*``), so they must share that memory layout.
						      * ``M`` and ``D`` are addressed as ``off_hz * N_CTX + row``.
						      * ``N_CTX`` bounds both the query rows and the key rows, and the causal
						        mask compares raw row/column positions, i.e. the kernel treats the
						        query and key sequences as having the same length.
						      * The head dimension is handled in chunks of ``BLOCK_DMODEL``; a second
						        chunk is processed when ``NUM_DBLOCKS >= 2``.
						      * ``Out``, ``L``, ``Z`` and ``num_block`` are accepted but not read by
						        the body; they are kept so the launch signature stays unchanged.
						      * Masked loads pass ``other=0.0``: without it the out-of-range lanes
						        hold unspecified values that would flow into the dot products below.
						        With zeros, padded rows/columns contribute exactly nothing.
						    """
						    start_n = tl.program_id(0)
						    off_hz = tl.program_id(1)
						    off_z = off_hz // H
						    off_h = off_hz % H

						    # Advance every base pointer to this program's (batch, head) slice.
						    Q += off_z * stride_qz + off_h * stride_qh
						    K += off_z * stride_kz + off_h * stride_kh
						    V += off_z * stride_vz + off_h * stride_vh
						    DO += off_z * stride_oz + off_h * stride_oh
						    DQ += off_z * stride_oz + off_h * stride_oh
						    DK += off_z * stride_oz + off_h * stride_oh
						    DV += off_z * stride_oz + off_h * stride_oh

						    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
						    offs_m = tl.arange(0, BLOCK_M)
						    offs_d = tl.arange(0, BLOCK_DMODEL)

						    # True for key/value rows that lie inside the sequence.
						    kv_mask = offs_n[:, None] < N_CTX

						    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd)
						    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd)

						    # Per-row statistics for this (batch, head).
						    D_ptrs = D + off_hz * N_CTX
						    m_ptrs = M + off_hz * N_CTX

						    # float32 accumulators for the first head-dim chunk.
						    dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
						    dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)

						    # K and V for this block stay resident for the whole loop.
						    if EVEN_N_BLOCK:
						        k = tl.load(k_ptrs)
						        v = tl.load(v_ptrs)
						    else:
						        k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
						        v = tl.load(v_ptrs, mask=kv_mask, other=0.0)

						    if NUM_DBLOCKS >= 2:
						        # Second head-dim chunk: same rows, columns shifted by BLOCK_DMODEL.
						        dv2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
						        dk2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
						        if EVEN_N_BLOCK:
						            k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd)
						            v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd)
						        else:
						            k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd, mask=kv_mask, other=0.0)
						            v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd, mask=kv_mask, other=0.0)

						    # [start_l, end_l) is the slice of layout_row_ptr that lists the query
						    # blocks attending to this key/value block.
						    layout_ptr = layout_ccol_ptr + off_h * layout_ccol_stride_h + start_n * layout_ccol_stride_m
						    start_l = tl.load(layout_ptr).to(tl.int32)
						    end_l = tl.load(layout_ptr + layout_ccol_stride_m).to(tl.int32)

						    for row_idx_idx in range(start_l, end_l):
						        row_idx = tl.load(layout_row_ptr + off_h * layout_row_stride_h + row_idx_idx * layout_row_stride_m).to(tl.int32)
						        start_m = row_idx * BLOCK_M

						        offs_m_curr = start_m + offs_m
						        # Bounds masks for the query rows of this block (1-D and 2-D forms).
						        row_mask = offs_m_curr < N_CTX
						        row_mask_2d = offs_m_curr[:, None] < N_CTX

						        q_ptrs = Q + (offs_m_curr[:, None] * stride_qm + offs_d[None, :] * stride_qd)
						        do_ptrs = DO + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od)
						        dq_ptrs = DQ + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od)

						        # Recompute the attention scores q @ k^T for this tile.
						        if EVEN_M_BLOCK:
						            q = tl.load(q_ptrs)
						        else:
						            q = tl.load(q_ptrs, mask=row_mask_2d, other=0.0)

						        qk = tl.dot(q, tl.trans(k))

						        if NUM_DBLOCKS >= 2:
						            if EVEN_M_BLOCK:
						                q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd)
						            else:
						                q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=row_mask_2d, other=0.0)
						            qk += tl.dot(q2, tl.trans(k2))

						        # Causal mask: a query row only sees key positions <= its own.
						        qk += tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), 0, float('-inf'))

						        if EVEN_M_BLOCK:
						            m = tl.load(m_ptrs + offs_m_curr)
						        else:
						            m = tl.load(m_ptrs + offs_m_curr, mask=row_mask, other=0.0)
						        p = tl.exp(qk * sm_scale - m[:, None])

						        if EVEN_M_BLOCK:
						            do = tl.load(do_ptrs)
						        else:
						            do = tl.load(do_ptrs, mask=row_mask_2d, other=0.0)

						        if NUM_DBLOCKS >= 2:
						            if EVEN_M_BLOCK:
						                do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od)
						            else:
						                do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od, mask=row_mask_2d, other=0.0)

						        # dv += p^T @ do, with p cast to the input dtype for the dot.
						        p_cast = p.to(Q.dtype.element_ty)
						        dv += tl.dot(tl.trans(p_cast), do)

						        if NUM_DBLOCKS >= 2:
						            dv2 += tl.dot(tl.trans(p_cast), do2)

						        # dp = do @ v^T - D  (D is the per-row term supplied by the caller).
						        if EVEN_M_BLOCK:
						            Di = tl.load(D_ptrs + offs_m_curr)
						        else:
						            Di = tl.load(D_ptrs + offs_m_curr, mask=row_mask, other=0.0)
						        dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
						        dp += tl.dot(do, tl.trans(v))

						        if NUM_DBLOCKS >= 2:
						            dp += tl.dot(do2, tl.trans(v2))

						        # Gradient w.r.t. the scaled scores.
						        ds = p * dp * sm_scale
						        ds_cast = ds.to(Q.dtype.element_ty)

						        # dk += ds^T @ q
						        dk += tl.dot(tl.trans(ds_cast), q)
						        if NUM_DBLOCKS >= 2:
						            dk2 += tl.dot(tl.trans(ds_cast), q2)

						        # dq += ds @ k.  Several key blocks update the same query rows, so the
						        # contribution is added atomically.
						        dq = tl.dot(ds_cast, k)
						        if EVEN_M_BLOCK:
						            tl.atomic_add(dq_ptrs, dq)
						        else:
						            tl.atomic_add(dq_ptrs, dq, mask=row_mask_2d)

						        if NUM_DBLOCKS >= 2:
						            dq2 = tl.dot(ds_cast, k2)
						            dq_ptrs2 = dq_ptrs + BLOCK_DMODEL * stride_od
						            if EVEN_M_BLOCK:
						                tl.atomic_add(dq_ptrs2, dq2)
						            else:
						                tl.atomic_add(dq_ptrs2, dq2, mask=row_mask_2d)

						    # Write back the accumulated dk/dv for this key/value block.
						    dv_ptrs = DV + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od)
						    dk_ptrs = DK + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od)
						    if EVEN_N_BLOCK:
						        tl.store(dv_ptrs, dv)
						        tl.store(dk_ptrs, dk)
						    else:
						        tl.store(dv_ptrs, dv, mask=kv_mask)
						        tl.store(dk_ptrs, dk, mask=kv_mask)

						    if NUM_DBLOCKS >= 2:
						        dv_ptrs2 = dv_ptrs + BLOCK_DMODEL * stride_od
						        dk_ptrs2 = dk_ptrs + BLOCK_DMODEL * stride_od
						        if EVEN_N_BLOCK:
						            tl.store(dv_ptrs2, dv2)
						            tl.store(dk_ptrs2, dk2)
						        else:
						            tl.store(dv_ptrs2, dv2, mask=kv_mask)
						            tl.store(dk_ptrs2, dk2, mask=kv_mask)
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N, num_warps=None, num_stages=1, inference=None, out=None):
						    '''
						    Launch the block-sparse attention forward kernel and record on ``ctx`` what the backward pass reads.

						    :param ctx: autograd context; receives the saved tensors plus BLOCK_M, BLOCK_N, BLOCK_DMODEL,
						        grid, sm_scale, num_warps and num_stages as attributes.
						    :param q, k, v: [batch, n_heads, seq_len, model_dim]. The length of q is allowed to differ from that of k/v.
						    :param layout_crow_indices, layout_col_indices: same as CSR.crow_indices and CSR.col_indices, used to represent a sparse tensor.
						        Each element represents a block, i.e., all elements in a block are attended to, or not attended to at all.
						        A 1-D layout is expanded (as a view) to one row per head.
						    :param sm_scale: scale forwarded unchanged to the kernel.
						    :param BLOCK_M, BLOCK_N: tile sizes along the query and key dimensions.
						    :param num_warps: must be a power of 2 when given; otherwise derived from the smallest tile dimension.
						    :param num_stages: forwarded to the kernel launch.
						    :param inference: when None, it is True iff none of q/k/v requires grad. In inference mode the
						        L and m buffers alias the scratch buffer and None is saved in their place.
						    :param out: optional preallocated output tensor; a contiguous tensor shaped like q is created otherwise.
						    :return: the output tensor.
						    '''
						    assert q.shape[-1] == k.shape[-1] == v.shape[-1]
						    assert k.shape[2] == v.shape[2]

						    n_batch_heads = q.shape[0] * q.shape[1]
						    kv_len = k.shape[2]

						    if out is not None:
						        o = out
						    else:
						        o = torch.empty_like(q).contiguous()

						    # One program per (query block, batch * head).
						    grid = (triton.cdiv(q.shape[2], BLOCK_M), n_batch_heads)
						    q_rounded_len = grid[0] * BLOCK_M

						    def _new_row_buffer():
						        # float32 buffer with one entry per (batch * head, padded query row).
						        return torch.empty((n_batch_heads, q_rounded_len), device=q.device, dtype=torch.float32)

						    tmp = _new_row_buffer()

						    if inference is None:
						        inference = not (q.requires_grad or k.requires_grad or v.requires_grad)

						    if inference:
						        # Nothing is kept for backward, so the kernel can write into the scratch buffer.
						        L = m = tmp
						    else:
						        L = _new_row_buffer()
						        m = _new_row_buffer()

						    if layout_col_indices.dim() == 1:
						        n_heads = q.shape[1]
						        layout_crow_indices = layout_crow_indices[None].expand(n_heads, -1)
						        layout_col_indices = layout_col_indices[None].expand(n_heads, -1)

						    head_dim = q.shape[-1]
						    assert head_dim in [64, 128]
						    # The head dimension is processed in chunks of 64 (one or two chunks).
						    BLOCK_DMODEL = 64

						    if num_warps is not None:
						        assert math.log2(num_warps) % 1 == 0, f'''"num_warps" should be power of 2, but got {num_warps}.'''
						    else:
						        MIN_D = min(BLOCK_M, BLOCK_N, BLOCK_DMODEL)
						        num_warps = max(1, 2 ** int(math.log2(MIN_D / 16)))

						    past_len = kv_len - q.shape[2]

						    _fwd_kernel[grid](
						        q, k, v, sm_scale,
						        layout_crow_indices,
						        layout_col_indices,
						        layout_crow_indices.stride(0), layout_crow_indices.stride(1),
						        layout_col_indices.stride(0), layout_col_indices.stride(1),
						        tmp, L, m,
						        o,
						        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
						        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
						        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
						        o.stride(0), o.stride(1), o.stride(2), o.stride(3),
						        q.shape[0], q.shape[1], kv_len,
						        past_len,
						        q_rounded_len,
						        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
						        BLOCK_DMODEL=BLOCK_DMODEL,
						        EVEN_M_BLOCK=q.shape[2] % BLOCK_M == 0,
						        EVEN_N_BLOCK=kv_len % BLOCK_N == 0,
						        INFERENCE=inference,
						        NUM_DBLOCKS=head_dim // BLOCK_DMODEL,
						        num_warps=num_warps,
						        num_stages=num_stages,
						    )

						    # In inference mode L/m only aliased the scratch buffer; do not keep them.
						    saved_L, saved_m = (None, None) if inference else (L, m)
						    ctx.save_for_backward(q, k, v, o, saved_L, saved_m, layout_crow_indices, layout_col_indices)

						    ctx.BLOCK_M = BLOCK_M
						    ctx.BLOCK_N = BLOCK_N
						    ctx.BLOCK_DMODEL = BLOCK_DMODEL
						    ctx.grid = grid
						    ctx.sm_scale = sm_scale
						    ctx.num_warps = num_warps
						    ctx.num_stages = num_stages
						    return o
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def _backward(ctx, do, layout_ccol_indices, layout_row_indices, dq=None, dk=None, dv=None):
						    '''
						    Compute the gradients of block-sparse attention w.r.t. q, k and v.

						    :param ctx: context populated by ``_forward`` (saved tensors, BLOCK_M, BLOCK_N, BLOCK_DMODEL,
						        grid, sm_scale, num_warps).
						    :param do: gradient w.r.t. the forward output; made contiguous if it is not.
						    :param layout_ccol_indices, layout_row_indices: column-compressed form of the block layout.
						        A 1-D layout is expanded (as a view) to one row per head.
						    :param dq, dk, dv: optional preallocated outputs. They must have the same strides as the forward
						        output. The kernel only *adds* into dq, so a caller-provided dq has to be pre-filled
						        (zeros for a plain gradient); by default a float32 zero tensor is created.
						    :return: ``(dq, dk, dv, None, None, None)``, matching the six inputs of the autograd ``forward``.
						    :raises RuntimeError: if the forward pass ran in inference mode, so L/m were not saved.
						    :raises ValueError: if q and k/v have different sequence lengths, or the saved output is not contiguous.
						    '''
						    # The saved layout tensors (CSR form) are not needed here; the caller
						    # passes the column-compressed form explicitly.
						    q, k, v, o, l, m, _, _ = ctx.saved_tensors

						    if l is None or m is None:
						        # Forward stores None for L/m in inference mode; fail with a clear
						        # message instead of an opaque TypeError further down.
						        raise RuntimeError(
						            'backward requires the L/m buffers from the forward pass, but the forward pass was run '
						            'in inference mode. Re-run forward with inference=False.'
						        )

						    if q.shape[2] != k.shape[2]:
						        # The launch grid, the bounds checks and the causal mask of the backward
						        # kernel are all derived from q's length only; with a longer k/v part
						        # of dk/dv would be returned uninitialized.
						        raise ValueError(
						            f'backward only supports q and k/v of the same length, but got '
						            f'{q.shape[2]} and {k.shape[2]}.'
						        )

						    if not do.is_contiguous():
						        do = do.contiguous()

						    if not o.is_contiguous():
						        raise ValueError(f'--> output is not contiguous: {o.stride()=}. This is maybe caused by q/k/v not being contiguous.')

						    if layout_ccol_indices.dim() == 1:
						        layout_ccol_indices = layout_ccol_indices[None].expand(q.shape[1], -1)
						        layout_row_indices = layout_row_indices[None].expand(q.shape[1], -1)

						    # dq is accumulated with atomic adds, hence zero-initialized float32.
						    dq = dq if dq is not None else torch.zeros_like(q, dtype=torch.float32)
						    dk = dk if dk is not None else torch.empty_like(k)
						    dv = dv if dv is not None else torch.empty_like(v)
						    do_scaled = torch.empty_like(do)
						    delta = torch.empty_like(l)

						    # The kernel addresses DO/DQ/DK/DV with the strides of the output.
						    assert o.stride() == dq.stride() == dk.stride() == dv.stride() == do_scaled.stride()

						    # Fills do_scaled and delta from o, do and l.
						    _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
						        o, do, l,
						        do_scaled, delta,
						        k.shape[2],
						        BLOCK_M=ctx.BLOCK_M, D_HEAD=q.shape[-1],
						    )

						    # One program per (key/value block, batch * head).
						    grid = (triton.cdiv(q.shape[2], ctx.BLOCK_N), ctx.grid[1])

						    _bwd_kernel[grid](
						        q, k, v, ctx.sm_scale,
						        layout_ccol_indices,
						        layout_row_indices,
						        layout_ccol_indices.stride(0), layout_ccol_indices.stride(1),
						        layout_row_indices.stride(0), layout_row_indices.stride(1),
						        o, do_scaled,
						        dq, dk, dv,
						        l, m,
						        delta,
						        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
						        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
						        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
						        o.stride(0), o.stride(1), o.stride(2), o.stride(3),
						        q.shape[0], q.shape[1], q.shape[2],
						        ctx.grid[0],
						        BLOCK_M=ctx.BLOCK_M,
						        BLOCK_N=ctx.BLOCK_N,
						        BLOCK_DMODEL=ctx.BLOCK_DMODEL,
						        NUM_DBLOCKS=q.shape[-1] // ctx.BLOCK_DMODEL,
						        num_warps=ctx.num_warps,
						        num_stages=1,
						    )
						    return dq, dk, dv, None, None, None
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						class _sparse_attention(torch.autograd.Function):
						    '''Autograd function for block-sparse attention using 128 x 128 tiles.'''

						    @staticmethod
						    def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
						        '''Run ``_forward`` with BLOCK_M = BLOCK_N = 128.'''
						        block_size = 128
						        return _forward(
						            ctx, q, k, v,
						            layout_crow_indices, layout_col_indices,
						            sm_scale,
						            block_size, block_size,
						        )

						    @staticmethod
						    def backward(ctx, do):
						        '''Convert the saved CSR layout to its ccol/row form and run ``_backward``.'''
						        # Only the two layout tensors at the end of the saved tuple are used here.
						        _q, _k, _v, _o, _l, _m, crow_indices, col_indices = ctx.saved_tensors
						        # The layout is rebuilt on every call by going through a dense mask.
						        dense_layout = crow_col_to_dense(crow_indices, col_indices)
						        ccol_indices, row_indices = dense_to_ccol_row(dense_layout)
						        return _backward(ctx, do, ccol_indices, row_indices)
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						class _sparse_attention_inference(_sparse_attention):
						    '''Variant of ``_sparse_attention`` whose forward uses BLOCK_M = 1 and BLOCK_N = 128.'''

						    @staticmethod
						    def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
						        '''Run ``_forward`` with a single-row query tile and 128-wide key tiles.'''
						        query_block = 1
						        key_block = 128
						        return _forward(
						            ctx, q, k, v,
						            layout_crow_indices, layout_col_indices,
						            sm_scale,
						            query_block, key_block,
						        )
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def sparse_attention_factory(BLOCK_M=128, BLOCK_N=128, **kwargs):
						    '''
						    Build a block-sparse attention op with fixed tile sizes.

						    :param BLOCK_M, BLOCK_N: tile sizes handed to ``_forward``.
						    :param kwargs: extra keyword arguments forwarded to ``_forward`` on every call
						        (e.g. num_warps, num_stages, inference, out).
						    :return: the ``apply`` callable of a new ``_sparse_attention`` subclass; call it as
						        ``op(q, k, v, layout_crow_indices, layout_col_indices, sm_scale)``.
						    '''
						    forward_options = kwargs

						    class _sparse_attention_config(_sparse_attention):
						        # Only forward is overridden; backward comes from _sparse_attention.
						        @staticmethod
						        def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
						            return _forward(
						                ctx, q, k, v,
						                layout_crow_indices, layout_col_indices,
						                sm_scale,
						                BLOCK_M, BLOCK_N,
						                **forward_options
						            )

						    configured_op = _sparse_attention_config.apply
						    return configured_op
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						@lru_cache(maxsize=8)
						def get_local_strided_sparse_attention_op(
						        n_heads: int,
						        max_seq_len: int,
						        sparse_block_size: int = 128,
						        local_blocks: int = 4,
						        vert_stride: int = 4,
						        homo_head: bool = False,
						        dtype=torch.bfloat16,
						        device='cuda',
						        active_head_range=None,
						        verbose=True,
						        **kwargs):
						    '''
						    Build a block-sparse attention op from a local + strided block pattern.

						    The block pattern is produced by `_get_sparse_attn_mask` and the op itself by
						    `get_sparse_attn_op`. The result is memoized with `lru_cache(maxsize=8)`, so
						    every argument (including the values in `kwargs`) has to be hashable.

						    :param n_heads: total number of attention heads (regardless of tensor/model parallel)
						    :param max_seq_len: max sequence length. Needs to be greater than or equal to the length of the sequences.
						    :param sparse_block_size: sparse block size. Defaults to 128.
						    :param local_blocks: number of nearest blocks to attend to. Defaults to 4, i.e., attention to
						                         the previous 4 x block_size tokens.
						    :param vert_stride: Defaults to 4. Forwarded unchanged to `_get_sparse_attn_mask`.
						                        NOTE(review): presumably the stride of the vertical (strided) part of the
						                        pattern -- confirm against `_get_sparse_attn_mask`.
						    :param homo_head: whether all heads share the same pattern. When True, `active_head_range` is ignored.
						    :param dtype: forwarded to `_get_sparse_attn_mask`.
						    :param device: forwarded to `_get_sparse_attn_mask`.
						    :param active_head_range: tuple of start & end of the heads, e.g., (8, 16). Defaults to using all heads.
						                              Mainly for tensor/model parallelization where heads are split across GPUs.
						    :param verbose: print the configuration whenever a new op is constructed (i.e. on a cache miss).
						    :param kwargs: extra keyword arguments forwarded to `get_sparse_attn_op`.
						    :return: whatever `get_sparse_attn_op` returns for the selected block pattern.
						    '''
						    if verbose:
						        config_summary = (
						            '> new block_sparse_attn op constructed with config: '
						            f'{n_heads=}, {max_seq_len=}, {sparse_block_size=}, {local_blocks=}, '
						            f'{vert_stride=}, {homo_head=}, {active_head_range=}, {kwargs=}'
						        )
						        print(config_summary)

						    # Only the block-level pattern (second element) is needed here.
						    mask_outputs = _get_sparse_attn_mask(
						        n_heads, max_seq_len, max_seq_len, dtype, device,
						        BLOCK=sparse_block_size,
						        local_blocks=local_blocks,
						        vert_stride=vert_stride,
						        homo_head=homo_head,
						        return_dense=False,
						    )
						    _, block_sparse_pattern, _ = mask_outputs

						    # Per-head patterns can be narrowed to the heads handled by this rank.
						    restrict_to_active_heads = (not homo_head) and (active_head_range is not None)
						    if restrict_to_active_heads:
						        assert isinstance(active_head_range, tuple)
						        assert len(active_head_range) == 2, '"active_head_range" should be a tuple of start/end index of the heads.'
						        first_head, last_head = active_head_range
						        block_sparse_pattern = block_sparse_pattern[first_head:last_head]

						    return get_sparse_attn_op(block_sparse_pattern, sparse_block_size, **kwargs)
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def get_sparse_attn_op(
						        sparse_pattern: torch.Tensor,
						        sparse_block_size: int = 128,
						        kernel_block_size=128,
						        qkv_format='q,k,v',
						        **kwargs):
						    '''
						    Create a block-sparse attention op with a fixed layout.

						    The CSR layout (forward) and CSC layout (backward) of `sparse_pattern` are computed once
						    here and reused by every call of the returned op. This avoids creating the CSR layout and
						    converting it to CSC every time, which is very inefficient (uses python loops on CPU.
						    PyTorch 1.13 supports CSR->CSC, which may help.)

						    :param sparse_pattern: sparse pattern of the blocks. Should be `num_blocks(q) x num_blocks(k)` or
						        `n_heads x num_blocks x num_blocks`. This tensor should have lower-triangular matrices on the
						        last 2 dimensions for causal attention.
						    :param sparse_block_size: sparse block size. Defaults to 128.
						    :param kernel_block_size: the tile/block size to launch a triton instance. Defaults to 128.
						        Pass None to use `sparse_block_size`. Otherwise it must divide `sparse_block_size`, be a power
						        of 2 and be at least 16.
						    :param qkv_format: Choices=['q,k,v', 'q, kv', 'qkv'], i.e., separated q,k,v, or kv packed, or qkv packed.
						        Currently, only 'q,k,v' is supported.
						    :param kwargs: keyword arguments passed to `_forward`. `kwargs['inference']` (default False) is also
						        read here to size the backward-layout cache.
						    :return: a function `fn(q, k, v, sm_scale)` wrapping a `torch.autograd.Function`. Dimension 2 of
						        q/k/v is treated as the sequence length. The pattern and the precomputed layouts are attached
						        to the function as attributes (`sparse_pattern`, `grand_layout_crow_indices`,
						        `grand_layout_col_indices`, `grand_layout_ccol_indices`, `grand_layout_row_indices`).
						    '''
						    assert qkv_format == 'q,k,v'

						    if kernel_block_size is None:
						        kernel_block_size = sparse_block_size
						    else:
						        assert sparse_block_size % kernel_block_size == 0, f"The sparse block size must be a multiple of {kernel_block_size}."
						        assert kernel_block_size >=16 and math.log2(kernel_block_size) % 1 == 0, f"block_size must be power of 2 and at least 16, but {kernel_block_size} is given"

						        if sparse_block_size // kernel_block_size > 1:
						            # Expand every sparse block into _mul x _mul kernel blocks, then re-apply the
						            # block-level causal mask so the expanded diagonal blocks stay lower-triangular.
						            _mul = sparse_block_size // kernel_block_size
						            sparse_pattern = torch.kron(sparse_pattern, sparse_pattern.new_ones(_mul, _mul))
						            num_sparse_blocks = sparse_pattern.size(-1)
						            block_causal_mask = torch.arange(0, num_sparse_blocks)[:, None] >= torch.arange(0, num_sparse_blocks)[None]
						            sparse_pattern *= block_causal_mask.type_as(sparse_pattern)

						    BLOCK_N = kernel_block_size
						    NUM_BLOCK = sparse_pattern.size(-1)
						    # Longest key sequence the precomputed layout can cover.
						    MAX_SEQ_LEN = kernel_block_size * NUM_BLOCK

						    # CSR layout, used by the forward pass.
						    grand_layout_crow_indices, grand_layout_col_indices = dense_to_crow_col(sparse_pattern)
						    # CSC layout, used by the backward pass.
						    grand_layout_ccol_indices, grand_layout_row_indices = dense_to_ccol_row(sparse_pattern)

						    # Cache backward layouts per block length: a single entry when `inference` is set, 8 otherwise.
						    # NOTE(review): presumably bounded to limit memory held by cached layouts -- confirm.
						    max_cache_size = 1 if kwargs.get('inference', False) else 8

						    @lru_cache(maxsize=max_cache_size)
						    def get_backward_layout_by_block_len(block_len):
						        assert block_len <= NUM_BLOCK
						        if block_len == NUM_BLOCK:
						            return (grand_layout_ccol_indices, grand_layout_row_indices)
						        return dense_to_ccol_row(sparse_pattern[..., :block_len, :block_len])

						    class _q_k_v_sparse_attention(torch.autograd.Function):
						        @staticmethod
						        def forward(ctx, q, k, v, sm_scale):
						            MIN_BLOCK_SIZE = 16
						            assert BLOCK_N >= MIN_BLOCK_SIZE
						            BLOCK_M = 16 if q.shape[2] <= 16 else BLOCK_N

						            # Select the rows of the CSR layout that correspond to the query blocks, which are
						            # aligned to the end of the key sequence.
						            K_BLOCKS = triton.cdiv(k.shape[2], kernel_block_size)
						            Q_START_BLOCKS = K_BLOCKS - triton.cdiv(q.shape[2], BLOCK_N)

						            # Mirror the bound check done in `backward`: outside this range the slice below
						            # would silently yield a truncated or wrapped-around row-pointer window.
						            assert K_BLOCKS <= NUM_BLOCK, \
						                f'k has sequence length {k.shape[2]}, but the sparse layout only covers {MAX_SEQ_LEN}.'
						            assert Q_START_BLOCKS >= 0, \
						                f'q (length {q.shape[2]}) spans more blocks than k (length {k.shape[2]}).'

						            layout_crow_indices = grand_layout_crow_indices[..., Q_START_BLOCKS:K_BLOCKS+1]
						            layout_col_indices = grand_layout_col_indices

						            return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N,
						                            **kwargs
						                        )

						        @staticmethod
						        def backward(ctx, do):
						            q, k = ctx.saved_tensors[:2]
						            assert q.shape[2] == k.shape[2], '> currently backward can only be done if q, k have same length. Contact @EricLin if you need it.'
						            block_len = triton.cdiv(do.shape[2], kernel_block_size)
						            backward_layout = get_backward_layout_by_block_len(block_len)
						            # Keep one gradient slot per forward input (q, k, v, sm_scale).
						            return _backward(ctx, do, *backward_layout)[:4]

						    def _q_k_v_sparse_attention_fn(*args):
						        return _q_k_v_sparse_attention.apply(*args)

						    # Expose the pattern and layouts on the returned function for inspection/reuse.
						    _q_k_v_sparse_attention_fn.sparse_pattern = sparse_pattern
						    _q_k_v_sparse_attention_fn.grand_layout_crow_indices = grand_layout_crow_indices
						    _q_k_v_sparse_attention_fn.grand_layout_col_indices = grand_layout_col_indices
						    _q_k_v_sparse_attention_fn.grand_layout_ccol_indices = grand_layout_ccol_indices
						    _q_k_v_sparse_attention_fn.grand_layout_row_indices = grand_layout_row_indices

						    return _q_k_v_sparse_attention_fn
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def blocksparse_flash_attn_padded_fwd(
						    q, k, v,
						    sm_scale,
						    sparse_layout,
						    *,
						    left_paddings = None,
						    seqlens = None,
						    block_size = 64,
						    max_seqlen = None
						):
						    '''
						    Forward-only block-sparse flash attention on padded batches.

						    q, k, v: (batch, tokens, n_heads/n_kv_heads, head_size). The number of q heads must be a
						        multiple of the number of k/v heads, and q must have either 1 token (decoding) or the
						        same number of tokens as k (prefilling).
						    sm_scale: softmax scale, forwarded to the kernel.
						    sparse_layout: pair `(layout_crow_indices, layout_col_indices)` describing the block-sparse pattern.
						    left_paddings: (batch, ), number of left paddings for each sample.
						    seqlens: can be used to specify right padding. No need to specify if left_paddings is used.
						    block_size: used as both BLOCK_M and BLOCK_N of the kernel launch.
						    max_seqlen: if set, only used to check that k is not longer than this value.

						    Returns a tensor with the same shape as q, allocated zero-filled before the kernel writes to it.
						    '''
						    # Validate ranks before unpacking shapes, so that wrong-rank inputs fail on this
						    # assertion instead of on an unrelated tuple-unpacking error.
						    assert q.dim() == k.dim() == v.dim() == 4
						    batches, q_len, n_heads, head_size = q.shape
						    k_len = k.size(1)

						    assert q.size(2) % k.size(2) == 0
						    assert q.size(0) == k.size(0) and q.size(3) == k.size(3)
						    assert k.shape == v.shape
						    assert q_len == 1 or q_len == k_len, \
						        'q length can only be 1 for decoding or the same as k length for prefilling.'

						    q_k_ratio = q.size(2) // k.size(2)

						    if max_seqlen:
						        assert k.size(1) <= max_seqlen, f'k has seqlen {k.size(1)} while max sequence length is set to {max_seqlen}.'

						    out = q.new_zeros(q.shape)

						    layout_crow_indices, layout_col_indices = sparse_layout
						    block_d = triton.next_power_of_2(head_size)

						    # Per-sample [start, end) token range of the valid keys.
						    if left_paddings is not None:
						        assert left_paddings.shape == (batches,)
						        k_batch_starts = left_paddings.to(q.device, dtype=torch.int32).contiguous()
						    else:
						        k_batch_starts = torch.zeros((batches,), dtype=torch.int32, device=q.device)

						    if seqlens is not None:
						        k_batch_ends = k_batch_starts + seqlens.type_as(k_batch_starts)
						        assert k_batch_ends.max() <= k_len, f'seqlens (+left_paddings if any) exceeds seqlen.'
						    else:
						        k_batch_ends = torch.zeros_like(k_batch_starts) + k_len

						    # Decoding uses a single query token per sample; prefilling shares the key ranges.
						    if q_len == 1:
						        q_batch_starts = torch.zeros_like(k_batch_starts)
						        q_batch_ends = q_batch_starts + 1
						    else:
						        q_batch_starts = k_batch_starts
						        q_batch_ends = k_batch_ends

						    # One kernel program per (query block, head): record for every query block the sample
						    # it belongs to and its starting token offset inside that sample.
						    q_lens = (q_batch_ends - q_batch_starts).cpu()
						    n_blocks = ((q_lens + block_size - 1) // block_size).tolist()

						    q_batch_ids = torch.tensor([i for i, n in enumerate(n_blocks) for _ in range(n)],
						                                dtype=q_batch_starts.dtype,
						                                device=q_batch_starts.device)
						    q_start_sids = torch.tensor([i * block_size for n in n_blocks for i in range(n)],
						                               dtype=q_batch_starts.dtype,
						                               device=q_batch_starts.device)

						    grid = (len(q_start_sids), n_heads)

						    _fwd_kernel_batch_inference[grid](
						        q, k, v, out,
						        sm_scale,
						        q_batch_starts,
						        q_batch_ends,
						        k_batch_starts,
						        k_batch_ends,
						        q_batch_ids,
						        q_start_sids,

						        *q.stride(),
						        *k.stride(),
						        *v.stride(),
						        *out.stride(),

						        layout_crow_indices,
						        layout_col_indices,
						        *layout_crow_indices.stride(),
						        *layout_col_indices.stride(),

						        q_k_ratio,
						        HAS_BATCH_DIM = True,
						        D_HEAD = head_size,
						        BLOCK_M = block_size,
						        BLOCK_N = block_size,
						        BLOCK_D = block_d,
						        BLOCK_M_LOADING = 16 if q_len == 1 else block_size,
						        EVEN_D = block_d == head_size,
						        num_warps = 1 if q_len == 1 else 4,
						        num_stages = 3
						    )

						    return out
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def blocksparse_flash_attn_varlen_fwd(
						    q, k, v,
						    cu_seqlens_k,
						    cu_seqlens_q,
						    sm_scale,
						    sparse_layout,
						    *,
						    block_size=64,
						    max_seqlen = None
						):
						    '''
						    Forward-only block-sparse flash attention on variable-length (packed) sequences.

						    q, k, v: 3-D tensors; dimension 0 is indexed by the cumulative sequence lengths, dimension 1
						        is heads (q heads must be a multiple of k/v heads) and dimension 2 is head_size.
						    cu_seqlens_k: 1-D tensor of cumulative key lengths, with batch_size + 1 entries.
						    cu_seqlens_q: cumulative query lengths, same number of entries as `cu_seqlens_k`. May be None
						        when the batch is decoding-only (q has one row per sample) or prefilling-only (q has as
						        many rows as k); a mix of both requires it to be given.
						    sm_scale: softmax scale, forwarded to the kernel.
						    sparse_layout: pair `(layout_crow_indices, layout_col_indices)` describing the block-sparse pattern.
						    block_size: used as both BLOCK_M and BLOCK_N of the kernel launch.
						    max_seqlen: if set, only used to check that no key sequence is longer than this value.

						    Returns a tensor with the same shape as q.
						    Raises ValueError if `cu_seqlens_q` is None and cannot be inferred from the shapes.
						    '''
						    # Validate ranks before unpacking shapes, so that wrong-rank inputs fail on these
						    # assertions instead of on an unrelated unpacking/indexing error.
						    assert q.dim() == k.dim() == v.dim() == 3
						    assert q.size(1) % k.size(1) == 0
						    assert q.size(2) == k.size(2)
						    assert k.shape == v.shape
						    assert cu_seqlens_k.dim() == 1

						    _, n_heads, head_size = q.shape
						    batch_size = cu_seqlens_k.size(0) - 1
						    q_k_ratio = q.size(1) // k.size(1)

						    if cu_seqlens_q is None:
						        if q.size(0) == batch_size:
						            # Decoding only: exactly one query token per sample.
						            cu_seqlens_q = torch.arange(0, batch_size + 1,
						                                        dtype=cu_seqlens_k.dtype,
						                                        device=cu_seqlens_k.device)
						        elif q.size(0) == k.size(0):
						            # Prefilling only: queries share the key boundaries.
						            cu_seqlens_q = cu_seqlens_k
						        else:
						            raise ValueError('cu_seqlens_q must be specified if it is mix of prefilling and decoding.')
						    else:
						        assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)

						    q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
						    k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()

						    assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), \
						        'length of q should either be 1 (decoding) or same as k (prefilling).'

						    if max_seqlen:
						        assert k_lens.max() <= max_seqlen

						    # One kernel program per (query block, head): record for every query block the sample
						    # it belongs to and its starting token offset inside that sample.
						    n_blocks = ((q_lens + block_size - 1) // block_size).tolist()

						    q_batch_ids = torch.tensor([i for i, n in enumerate(n_blocks) for _ in range(n)],
						                                dtype=cu_seqlens_q.dtype,
						                                device=cu_seqlens_q.device)
						    q_start_sids = torch.tensor([i * block_size for n in n_blocks for i in range(n)],
						                               dtype=cu_seqlens_q.dtype,
						                               device=cu_seqlens_q.device)

						    out = q.new_empty(q.shape)
						    cu_seqlens_q = cu_seqlens_q.contiguous()
						    cu_seqlens_k = cu_seqlens_k.contiguous()

						    layout_crow_indices, layout_col_indices = sparse_layout
						    block_d = triton.next_power_of_2(head_size)

						    decoding_only = bool((q_lens == 1).all())

						    grid = (len(q_start_sids), n_heads)

						    _fwd_kernel_batch_inference[grid](
						        q, k, v, out,
						        sm_scale,
						        cu_seqlens_q[:-1],
						        cu_seqlens_q[1:],
						        cu_seqlens_k[:-1],
						        cu_seqlens_k[1:],
						        q_batch_ids,
						        q_start_sids,

						        # There is no batch dimension in the packed layout, so its stride is passed as 0.
						        0, *q.stride(),
						        0, *k.stride(),
						        0, *v.stride(),
						        0, *out.stride(),

						        layout_crow_indices,
						        layout_col_indices,
						        *layout_crow_indices.stride(),
						        *layout_col_indices.stride(),

						        q_k_ratio,
						        HAS_BATCH_DIM = False,
						        D_HEAD = head_size,
						        BLOCK_M = block_size,
						        BLOCK_N = block_size,
						        BLOCK_D = block_d,
						        BLOCK_M_LOADING = 16 if decoding_only else block_size,
						        EVEN_D = block_d == head_size,
						        num_warps = 1 if decoding_only else 4,
						        num_stages = 3
						    )

						    return out
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						@triton.jit | 
					
					
						
						| 
							 | 
						def _fwd_kernel_inner( | 
					
					
						
						| 
							 | 
						    acc, l_i, m_i, | 
					
					
						
						| 
							 | 
						    q, Q, | 
					
					
						
						| 
							 | 
						    k_block_col_idx, | 
					
					
						
						| 
							 | 
						    layout_col_ptr, | 
					
					
						
						| 
							 | 
						    layout_col_stride_h, layout_col_stride_m, | 
					
					
						
						| 
							 | 
						    k_ptrs, | 
					
					
						
						| 
							 | 
						    v_ptrs, | 
					
					
						
						| 
							 | 
						    off_h, offs_m, offs_n, offs_d, | 
					
					
						
						| 
							 | 
						    stride_kt, stride_vt, | 
					
					
						
						| 
							 | 
						    sm_scale, | 
					
					
						
						| 
							 | 
						    k_seqlen, | 
					
					
						
						| 
							 | 
						    past_len, | 
					
					
						
						| 
							 | 
						    LAST_K_BLOCK: tl.constexpr, | 
					
					
						
						| 
							 | 
						    BLOCK_M_LOADING: tl.constexpr, | 
					
					
						
						| 
							 | 
						    BLOCK_N: tl.constexpr, | 
					
					
						
						| 
							 | 
						    D_HEAD: tl.constexpr, | 
					
					
						
						| 
							 | 
						    EVEN_D: tl.constexpr, | 
					
					
						
						| 
							 | 
						    M_LT_N: tl.constexpr | 
					
					
						
						| 
							 | 
						): | 
					
					
						
						| 
							 | 
						    k_block_id = tl.load(layout_col_ptr +  off_h * layout_col_stride_h + k_block_col_idx * layout_col_stride_m).to(tl.int32) | 
					
					
						
						| 
							 | 
						    start_n = k_block_id * BLOCK_N | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    if LAST_K_BLOCK: | 
					
					
						
						| 
							 | 
						        if EVEN_D: | 
					
					
						
						| 
							 | 
						            k = tl.load(k_ptrs + start_n * stride_kt, | 
					
					
						
						| 
							 | 
						                        mask=offs_n[None, :] + start_n < k_seqlen) | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            k = tl.load(k_ptrs + start_n * stride_kt, | 
					
					
						
						| 
							 | 
						                        mask=(offs_n[None, :] + start_n < k_seqlen) & (offs_d[:, None] < D_HEAD)) | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        if EVEN_D: | 
					
					
						
						| 
							 | 
						            k = tl.load(k_ptrs + start_n * stride_kt) | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            k = tl.load(k_ptrs + start_n * stride_kt, | 
					
					
						
						| 
							 | 
						                        mask=offs_d[:, None] < D_HEAD) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32) | 
					
					
						
						| 
							 | 
						    qk += tl.dot(q, k) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    qk *= sm_scale | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    if LAST_K_BLOCK | M_LT_N: | 
					
					
						
						| 
							 | 
						        qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf')) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    m_ij = tl.max(qk, 1) | 
					
					
						
						| 
							 | 
						    p = tl.exp(qk - m_ij[:, None]) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    l_ij = tl.sum(p, 1) | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    m_i_new = tl.maximum(m_i, m_ij) | 
					
					
						
						| 
							 | 
						    alpha = tl.exp(m_i - m_i_new) | 
					
					
						
						| 
							 | 
						    beta = tl.exp(m_ij - m_i_new) | 
					
					
						
						| 
							 | 
						    l_i_new = alpha * l_i + beta * l_ij | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    p_scale = beta / l_i_new | 
					
					
						
						| 
							 | 
						    p = p * p_scale[:, None] | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    acc_scale = l_i / l_i_new * alpha | 
					
					
						
						| 
							 | 
						    acc = acc * acc_scale[:, None] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    p = p.to(Q.dtype.element_ty) | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    if LAST_K_BLOCK: | 
					
					
						
						| 
							 | 
						        if EVEN_D: | 
					
					
						
						| 
							 | 
						            v = tl.load(v_ptrs + start_n * stride_vt, | 
					
					
						
						| 
							 | 
						                        mask=offs_n[:, None] + start_n < k_seqlen) | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            v = tl.load(v_ptrs + start_n * stride_vt, | 
					
					
						
						| 
							 | 
						                        mask=(offs_n[:, None] + start_n < k_seqlen) & (offs_d[None, :] < D_HEAD)) | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        if EVEN_D: | 
					
					
						
						| 
							 | 
						            v = tl.load(v_ptrs + start_n * stride_vt) | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            v = tl.load(v_ptrs + start_n * stride_vt, | 
					
					
						
						| 
							 | 
						                        mask=offs_d[None, :] < D_HEAD) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    acc += tl.dot(p, v) | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    l_i = l_i_new | 
					
					
						
						| 
							 | 
						    m_i = m_i_new | 
					
					
						
						| 
							 | 
						    return acc, l_i, m_i | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						@triton.heuristics(
						    {
						        # Compile-time flag: true when the query tile is smaller than the key tile.
						        'M_LT_N': lambda kwargs: kwargs['BLOCK_M'] < kwargs['BLOCK_N'],
						    }
						)
						@triton.jit
						def _fwd_kernel_batch_inference(
						    Q, K, V, Out,

						    sm_scale,
						    # Per-sequence [start, end) token offsets; the sequence lengths are
						    # computed below as end - start.
						    q_batch_starts,
						    q_batch_ends,
						    k_batch_starts,
						    k_batch_ends,
						    # Per-program lookup tables, indexed by program_id(0).
						    q_batch_ids,
						    q_start_sids,

						    # Element strides; suffixes are b=batch, t=token, h=head, d=head-dim.
						    stride_qb, stride_qt, stride_qh, stride_qd,
						    stride_kb, stride_kt, stride_kh, stride_kd,
						    stride_vb, stride_vt, stride_vh, stride_vd,
						    stride_ob, stride_ot, stride_oh, stride_od,

						    # Block-sparse layout: row pointers ("crow") and column block ids ("col").
						    layout_crow_ptr,
						    layout_col_ptr,
						    layout_crow_stride_h, layout_crow_stride_m,
						    layout_col_stride_h, layout_col_stride_m,

						    # Number of query heads sharing one K/V head.
						    q_k_ratio,

						    HAS_BATCH_DIM: tl.constexpr,
						    D_HEAD: tl.constexpr,
						    BLOCK_M: tl.constexpr,
						    BLOCK_N: tl.constexpr,
						    BLOCK_D: tl.constexpr,
						    BLOCK_M_LOADING: tl.constexpr,
						    EVEN_D: tl.constexpr,
						    M_LT_N: tl.constexpr
						):
						    '''
						    Block-sparse attention forward pass for batched inference.

						    Each program handles one query tile of one head:
						    program_id(0) selects an entry of `q_batch_ids` / `q_start_sids`
						    (which sequence, and which query row the tile starts at) and
						    program_id(1) selects the query head.

						    NOTATION:
						    pid: position id
						    sid: storage id
						    sbid: storage block id
						    pbid: position block id
						    offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)

						    Q and the blocks in K/V need to be contiguous.

						    Arguments (as used by the code below):
						    q_batch_starts / q_batch_ends: per-sequence token range of Q;
						        q_seqlen = end - start.
						    k_batch_starts / k_batch_ends: per-sequence token range of K/V;
						        k_seqlen = end - start, and past_len = k_seqlen - q_seqlen.
						    q_batch_ids: sequence index for each program along grid axis 0.
						    q_start_sids: first query row (within its sequence) for each program.
						    layout_crow_ptr / layout_col_ptr: sparse layout; for the query block
						        row `q_pbid` the kernel visits column entries
						        [crow[q_pbid], crow[q_pbid + 1]).
						    q_k_ratio: K/V head index is `off_h // q_k_ratio`.

						    TODO:
						    Optimize grouped-attn

						    CUDA graph support issue
						        1. The grid is dynamic: vllm sets up multiple CUDA graphs in the
						            decoding phase, with different max token sizes (16, 32, ...).
						            Since prompt and decoding phases are mixed here, it can be more
						            complex; different CUDA graphs would be needed for different
						            (off_zm, off_z).

						            q_batch_ids could be padded to the maximum size of grid[0],
						            i.e., assuming all-decoding.
						    '''
						    off_zm = tl.program_id(0)
						    off_h = tl.program_id(1)

						    off_h_for_kv = off_h // q_k_ratio
						    off_z = tl.load(q_batch_ids + off_zm).to(tl.int32)
						    q_start_sid = tl.load(q_start_sids + off_zm)
						    start_m = q_start_sid // BLOCK_M

						    # Move the base pointers to this sequence when the tensors carry a batch dim.
						    if HAS_BATCH_DIM:
						        Q += off_z * stride_qb
						        K += off_z * stride_kb
						        V += off_z * stride_vb
						        Out += off_z * stride_ob

						    # The query tile loads BLOCK_M_LOADING rows starting at a BLOCK_M-aligned row.
						    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
						    offs_n = tl.arange(0, BLOCK_N)
						    offs_d = tl.arange(0, BLOCK_D)

						    q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
						    q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start

						    k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
						    k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start

						    # Keys that precede the first query token of this sequence.
						    past_len = k_seqlen - q_seqlen

						    # Advance to this sequence's first token and this program's head.
						    Q += q_cu_start * stride_qt + off_h * stride_qh
						    K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
						    V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
						    Out += q_cu_start * stride_ot + off_h * stride_oh

						    # Block row of the sparse layout, in absolute positions (past_len included).
						    q_pbid = (past_len + q_start_sid) // BLOCK_M

						    # NOTE(review): the EVEN_D load passes no `other=`; rows at or beyond
						    # q_seqlen are excluded by the store mask at the end of the kernel.
						    if EVEN_D:
						        q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
						                    mask=offs_m[:, None] < q_seqlen)
						    else:
						        q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
						                    mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
						                    other=0)

						    sparse_crow_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + q_pbid * layout_crow_stride_m

						    # NOTE(review): `+ 1` ignores layout_crow_stride_m, so this presumably
						    # assumes the row-pointer array is contiguous along m -- confirm.
						    k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
						    k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)

						    # Running softmax state: row maximum, row normaliser, output accumulator.
						    m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float('inf')
						    l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
						    acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)

						    # K is addressed as (d, n) and V as (n, d).
						    k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
						    v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd

						    # All column entries except the last one; the literal False is
						    # presumably the LAST_K_BLOCK flag of _fwd_kernel_inner.
						    for k_block_col_idx in range(k_block_start, k_block_end - 1):
						        acc, l_i, m_i = _fwd_kernel_inner(
						            acc, l_i, m_i,
						            q, Q,
						            k_block_col_idx,
						            layout_col_ptr,
						            layout_col_stride_h, layout_col_stride_m,
						            k_ptrs,
						            v_ptrs,
						            off_h, offs_m, offs_n, offs_d,
						            stride_kt, stride_vt,
						            sm_scale,
						            k_seqlen,
						            past_len,
						            False,
						            BLOCK_M_LOADING,
						            BLOCK_N,
						            D_HEAD,
						            EVEN_D,
						            M_LT_N
						            )

						    # The last column entry is handled separately with the flag set to True.
						    # NOTE(review): this call is unconditional, so it presumably assumes every
						    # layout row holds at least one column entry -- confirm against the caller.
						    acc, l_i, m_i = _fwd_kernel_inner(
						        acc, l_i, m_i,
						        q, Q,
						        k_block_end - 1,
						        layout_col_ptr,
						        layout_col_stride_h, layout_col_stride_m,
						        k_ptrs,
						        v_ptrs,
						        off_h, offs_m, offs_n, offs_d,
						        stride_kt, stride_vt,
						        sm_scale,
						        k_seqlen,
						        past_len,
						        True,
						        BLOCK_M_LOADING,
						        BLOCK_N,
						        D_HEAD,
						        EVEN_D,
						        M_LT_N
						        )

						    # `acc` is written out as returned by the inner step; no further
						    # normalisation happens here.
						    if EVEN_D:
						        tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc,
						                mask=offs_m[:, None] < q_seqlen)
						    else:
						        tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc,
						                mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD))
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def torch_attention(q, k, v, attn_mask=None, sm_scale=None, block_attn_mask=None, block_size=128, do=None):
						    '''
						    Reference attention implemented with plain PyTorch ops.

						    q, k, v: shape=(batch, n_heads, seq, dim)
						    attn_mask: optional dense 0/1 mask broadcastable to
						        (batch, n_heads, q_seq, k_seq); entries equal to 1 are kept.
						    sm_scale: multiplier applied to the q.k scores before softmax.
						        Defaults to 1 / sqrt(dim).
						    block_attn_mask: optional block-level mask whose last two dims are
						        (n_blocks, n_blocks); non-zero blocks are kept. Causality is
						        applied inside each query block. Mutually exclusive with
						        `attn_mask`. This path assumes q and k cover the same positions
						        (no past keys).
						    block_size: edge length, in tokens, of one block of `block_attn_mask`.
						    do: optional upstream gradient; when given on the dense path, the
						        implied dv is printed for debugging.

						    Returns a tensor of shape (batch, n_heads, q_seq, dim).
						    '''
						    if sm_scale is None:
						        # Standard scaled dot-product attention divides by sqrt(dim).
						        sm_scale = 1.0 / math.sqrt(float(q.size(-1)))

						    if block_attn_mask is not None:
						        assert attn_mask is None
						        seq_len = q.size(2)
						        positions = torch.arange(seq_len, device=q.device)
						        block_keep = block_attn_mask.to(q.device) != 0
						        outs = []
						        for s in range(0, seq_len, block_size):
						            e = min(s + block_size, seq_len)
						            blk = s // block_size
						            attn = torch.einsum('bhmd,bhnd->bhmn', q[:, :, s:e], k[:, :, :e]).float() * sm_scale
						            # Expand this block row to token resolution and trim it to the
						            # keys actually present (the last block may be partial).
						            keep = block_keep[..., blk, :blk + 1].repeat_interleave(block_size, dim=-1)
						            keep = keep[..., None, :e]
						            # Causal constraint: a query may attend to itself and to
						            # earlier positions only.
						            keep = keep & (positions[s:e, None] >= positions[None, :e])
						            attn = attn.masked_fill(~keep, float('-inf')).softmax(-1)
						            outs.append(torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v[:, :, :e]))
						        return torch.cat(outs, dim=2)

						    attn = torch.einsum('bhmd,bhnd->bhmn', q, k).float() * sm_scale
						    if attn_mask is not None:
						        attn = attn.masked_fill((1 - attn_mask).bool(), float('-inf'))

						    attn = attn.softmax(-1)
						    if do is not None:
						        dv = torch.einsum('bhqk,bhqd->bhkd', attn.type_as(do), do)
						        print(f'> torch_attn computed dv: {dv=}')
						    return torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v)
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 8, 2048, 128), (1, 4, 4096, 64)])
						def test_op(Z, H, N_CTX, D_HEAD, Q_LEN=None, dtype=torch.bfloat16, homo_head=True, kernel_block_size=None, sparse_block_size=128, backward=True,
						            sparse_attention_fn=None, local_blocks=4, vert_stride=4, sm_scale=None, max_length=None):
						    '''
						    Compare the Triton block-sparse attention op with `torch_attention`.

						    Random q/k/v are created on CUDA, the dense reference is computed with
						    the dense mask returned by `get_sparse_attn_mask`, and the forward
						    output (plus dq/dk/dv when `backward` is true) must agree within the
						    tolerances asserted below.

						    Z, H, N_CTX, D_HEAD: batch size, head count, key length, head dim.
						    Q_LEN: query length; defaults to N_CTX.
						    sparse_attention_fn: op under test; built with
						        `get_local_strided_sparse_attention_op` when omitted.
						    sm_scale: softmax scale handed to both implementations. Defaults to
						        0.0078125, the value this test has always effectively used.
						    max_length: accepted for call compatibility; not used here.
						    '''
						    Q_LEN = Q_LEN or N_CTX
						    torch.manual_seed(20)
						    q = torch.empty((Z, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5)
						    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5)
						    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5)

						    # An explicit sm_scale is honoured; only the default falls back to the
						    # constant that used to overwrite every value.
						    if sm_scale is None:
						        sm_scale = 0.0078125

						    if backward:
						        q.requires_grad_()
						        k.requires_grad_()
						        v.requires_grad_()

						    dout = torch.randn_like(q).contiguous()

						    _, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=sparse_block_size,
						                            local_blocks=local_blocks, vert_stride=vert_stride, homo_head=homo_head, return_dense=True)

						    if sparse_attention_fn is None:
						        sparse_attention_fn = get_local_strided_sparse_attention_op(H, N_CTX,
						                                                                    sparse_block_size=sparse_block_size,
						                                                                    local_blocks=local_blocks,
						                                                                    vert_stride=vert_stride,
						                                                                    homo_head=homo_head,
						                                                                    device=q.device,
						                                                                    dtype=q.dtype,
						                                                                    kernel_block_size=kernel_block_size)

						    ref_out = torch_attention(q, k, v, mask_dense, sm_scale)

						    if backward:
						        # Capture the reference gradients, then clear .grad so the Triton
						        # backward pass below starts from a clean slate.
						        ref_out.backward(dout)
						        ref_dv, v.grad = v.grad.clone(), None
						        ref_dk, k.grad = k.grad.clone(), None
						        ref_dq, q.grad = q.grad.clone(), None

						    tri_out = sparse_attention_fn(q, k, v, sm_scale)

						    assert torch.allclose(ref_out.cpu(), tri_out.cpu(), atol=1e-2, rtol=0), f'>> {ref_out[0, 0, :, 0].tolist()=}\n\n{tri_out[0, 0, :, 0].tolist()=}'

						    if backward:
						        tri_out.backward(dout)
						        tri_dv, v.grad = v.grad.clone(), None
						        tri_dk, k.grad = k.grad.clone(), None
						        tri_dq, q.grad = q.grad.clone(), None

						        assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=1e-2)
						        assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0)
						        assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0)

						    print(f'> test passed: {Z=}, {H=}, {N_CTX=}, {D_HEAD=}, {Q_LEN=}, {dtype=}, {homo_head=}, {sparse_block_size=}')
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						if __name__ == '__main__': | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    GPU_TYPE = os.popen('nvidia-smi --query-gpu=name --format=csv | tail -n 1').read().strip() | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    support_backward = True  | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    HAS_DENSE_TRITON_FLASH = False | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    try: | 
					
					
						
						| 
							 | 
						        from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded_func | 
					
					
						
						| 
							 | 
						        HAS_FLASH = True | 
					
					
						
						| 
							 | 
						    except BaseException: | 
					
					
						
						| 
							 | 
						        HAS_FLASH = False | 
					
					
						
						| 
							 | 
						        print('> cannot import flash_attn') | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    # Problem sizes for the training-style (fwd/bwd) benchmark.
						    BATCH, N_HEADS, N_CTX, D_HEAD = 4, 32, 4096, 128

						    # Sparse-layout settings. The benchmark function reads BLOCK_SIZE,
						    # LOCAl_BLOCKS, VERT_STRIDE and HOMO_HEAD when it builds the sparse op.
						    BLOCK_SIZE = 64
						    LOCAl_BLOCKS = 8  # (sic) spelling kept: the name is referenced elsewhere in the file
						    VERT_STRIDE = 1
						    HOMO_HEAD = False
						    # Tag embedded in the plot name; previously misspelled 'home'.
						    sparse_type = 'homo' if HOMO_HEAD else 'hetero'
						    dtype = torch.bfloat16

						    # The backward pass is only benchmarked when `support_backward` is truthy.
						    modes = ['fwd', 'bwd'] if support_backward else ['fwd']

						    # One Benchmark (and therefore one plot) per mode. The x-axis sweeps
						    # SEQ_LEN over 2**8 .. 2**15; the dense providers are only included when
						    # their availability flags are set.
						    configs = [triton.testing.Benchmark(
						        x_names=['SEQ_LEN'],
						        x_vals=[2**i for i in range(8, 16)],
						        line_arg='provider',
						        line_vals=(['triton'] if HAS_DENSE_TRITON_FLASH else []) + (['flash'] if HAS_FLASH else []) + ['triton_sparse'],
						        line_names=(['Triton-Dense'] if HAS_DENSE_TRITON_FLASH else [])  + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse'],
						        styles=[('red', '-'), ('blue', '-'), ('green', '-')],
						        ylabel='ms',
						        plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}-{dtype}-{mode}',
						        args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': dtype, 'mode': mode}
						    ) for mode in modes]
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    @triton.testing.perf_report(configs)
						    def bench_flash_attention(BATCH, H, SEQ_LEN, D_HEAD, mode, provider, dtype=torch.bfloat16, device='cuda', sparse_attention_fn=None):
						        """Benchmark one attention implementation at a single sequence length.

						        Args:
						            BATCH, H, SEQ_LEN, D_HEAD: sizes used to allocate the random inputs.
						            mode: 'fwd' times the attention call itself; 'bwd' runs the call
						                once and times only ``o.backward(...)`` on that saved output.
						            provider: 'triton', 'triton_sparse' or 'flash'.
						            dtype: dtype of the random inputs.
						            device: device the random inputs are allocated on.
						            sparse_attention_fn: optional pre-built op for 'triton_sparse'; when
						                None, one is built with get_local_strided_sparse_attention_op.

						        Returns:
						            Whatever ``triton.testing.do_bench`` returns for the timed callable.

						        Raises:
						            ValueError: if ``provider`` is not one of the names listed above.
						        """
						        assert mode in ['fwd', 'bwd']
						        warmup = 25
						        rep = 100
						        N_CTX = SEQ_LEN

						        def _time(fn):
						            # For 'bwd', build the graph once and time only the backward pass;
						            # retain_graph=True lets do_bench call backward repeatedly.
						            if mode == 'bwd':
						                o = fn()
						                do = torch.randn_like(o)
						                fn = lambda: o.backward(do, retain_graph=True)
						            return triton.testing.do_bench(fn, warmup=warmup, rep=rep)

						        if provider in ('triton', 'triton_sparse'):
						            # Both Triton paths take separate q/k/v tensors of the same shape.
						            # They are allocated on `device` (previously hard-coded to 'cuda').
						            q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=device, requires_grad=True)
						            k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=device, requires_grad=True)
						            v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=device, requires_grad=True)
						            sm_scale = 1.3

						            if provider == 'triton':
						                attention_fn = triton_attention
						            else:
						                if sparse_attention_fn is None:
						                    # NOTE(review): LOCAl_BLOCKS, VERT_STRIDE, HOMO_HEAD and BLOCK_SIZE
						                    # are looked up when the benchmark runs, not when it is defined.
						                    # The same names appear to be reassigned further down the file,
						                    # so the values used here may differ from those in the plot
						                    # name -- confirm this is intended.
						                    sparse_attention_fn = get_local_strided_sparse_attention_op(H, SEQ_LEN,
						                                                                                local_blocks=LOCAl_BLOCKS,
						                                                                                vert_stride=VERT_STRIDE,
						                                                                                homo_head=HOMO_HEAD,
						                                                                                sparse_block_size=BLOCK_SIZE,
						                                                                                kernel_block_size=BLOCK_SIZE,
						                                                                                device=q.device)
						                attention_fn = sparse_attention_fn

						            return _time(lambda: attention_fn(q, k, v, sm_scale))

						        if provider == 'flash':
						            # Packed-QKV, variable-length layout: every sequence has length N_CTX,
						            # so cu_seqlens is simply 0, N_CTX, 2*N_CTX, ...
						            lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
						            cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
						            cu_seqlens[1:] = lengths.cumsum(0)
						            qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
						            return _time(lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True))

						        # Previously an unrecognised provider fell through and returned None.
						        raise ValueError(f'unknown provider: {provider!r}')
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    # Problem sizes for the inference (decoding) benchmark: Q_LEN new query
						    # tokens attend to PAST_LEN + Q_LEN keys/values.
						    BATCH, N_HEADS, N_CTX, D_HEAD, Q_LEN = 4, 32, 4096, 128, 1

						    # Sparse-layout settings read by the inference benchmark function.
						    BLOCK_SIZE = 64
						    LOCAl_BLOCKS = 8  # (sic) spelling kept: the name is referenced elsewhere in the file
						    VERT_STRIDE = 16
						    HOMO_HEAD = False
						    # Tag embedded in the plot name; previously misspelled 'home'.
						    sparse_type = 'homo' if HOMO_HEAD else 'hetero'
						    # NOTE(review): this local is not used below -- `args` passes torch.float16
						    # explicitly. Confirm which dtype the inference benchmark should use.
						    dtype = torch.bfloat16
						    # Context length the Triton sparse ops are built for.
						    MAX_N_CTX = 8192

						    # Single 'fwd' benchmark; the x-axis sweeps PAST_LEN over 2**8 - 1 .. 2**13 - 1.
						    configs = [triton.testing.Benchmark(
						        x_names=['PAST_LEN'],
						        x_vals=[2**i - 1 for i in range(8, 14)],
						        line_arg='provider',
						        line_vals=['torch'] + (['flash'] if HAS_FLASH else []) + ['triton_sparse', 'triton_dense'],
						        line_names=['Torch']  + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse', 'Triton-Dense'],
						        styles=[('red', '-'), ('blue', '-'), ('green', '-'), ('cyan', '-')],
						        ylabel='ms',
						        plot_name=f'fused-attention-inference-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}',
						        args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'Q_LEN': Q_LEN, 'dtype': torch.float16, 'mode': mode}
						    ) for mode in ['fwd']]
					
					
						
						| 
							 | 
						    @triton.testing.perf_report(configs)
						    def bench_flash_attention_inference(BATCH, H, PAST_LEN, D_HEAD, Q_LEN, mode, provider, dtype=torch.bfloat16, device='cuda'):
						        """Benchmark one attention implementation for a decoding step.

						        The query has Q_LEN tokens and the keys/values have PAST_LEN + Q_LEN.

						        Args:
						            BATCH, H, PAST_LEN, D_HEAD, Q_LEN: sizes used to allocate the inputs.
						            mode: must be 'fwd'.
						            provider: 'torch', 'triton_sparse', 'triton_dense' or 'flash'.
						            dtype: dtype of the random inputs.
						            device: device the random inputs are allocated on.

						        Returns:
						            Whatever ``triton.testing.do_bench`` returns for the timed callable.

						        Raises:
						            ValueError: if ``provider`` is not one of the names listed above.
						        """
						        assert mode in ['fwd']
						        warmup = 25
						        rep = 100
						        N_CTX = PAST_LEN + Q_LEN

						        def _random_qkv():
						            # Inputs shared by the torch / triton providers, allocated on `device`
						            # (previously hard-coded to 'cuda').
						            q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device=device, requires_grad=False)
						            k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=device, requires_grad=False)
						            v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=device, requires_grad=False)
						            return q, k, v

						        if provider == 'torch':
						            q, k, v = _random_qkv()
						            sm_scale = 1.3
						            # Fix: homo_head used to receive VERT_STRIDE (a truthy int), so this
						            # baseline used a different layout from the 'triton_sparse' provider.
						            _, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=BLOCK_SIZE,
						                                    local_blocks=LOCAl_BLOCKS, vert_stride=VERT_STRIDE, homo_head=HOMO_HEAD, return_dense=True)

						            fn = lambda: torch_attention(q, k, v, mask_dense, sm_scale=sm_scale, block_size=2048)
						            return triton.testing.do_bench(fn, warmup=warmup, rep=rep)

						        if provider in ('triton_sparse', 'triton_dense'):
						            q, k, v = _random_qkv()
						            sm_scale = 1.3
						            if provider == 'triton_sparse':
						                layout = dict(local_blocks=LOCAl_BLOCKS, vert_stride=VERT_STRIDE, homo_head=HOMO_HEAD)
						            else:
						                # Same op with local_blocks=1, vert_stride=1, homo_head=True;
						                # presumably this keeps every block, i.e. a dense baseline -- TODO confirm.
						                layout = dict(local_blocks=1, vert_stride=1, homo_head=True)
						            # The op is built for MAX_N_CTX rather than for this call's N_CTX.
						            sparse_attention_fn = get_local_strided_sparse_attention_op(H, MAX_N_CTX,
						                                                                        sparse_block_size=BLOCK_SIZE,
						                                                                        kernel_block_size=BLOCK_SIZE,
						                                                                        device=q.device,
						                                                                        inference=True,
						                                                                        **layout)

						            fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
						            return triton.testing.do_bench(fn, warmup=warmup, rep=rep)

						        if provider == 'flash':
						            # cu_seqlens_q below is arange(BATCH + 1), i.e. one query token per sequence.
						            assert Q_LEN == 1
						            lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
						            cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
						            cu_seqlens[1:] = lengths.cumsum(0)
						            cu_seqlens_q = torch.arange(BATCH + 1, device=device, dtype=torch.int32)

						            # Unpadded layout: tokens of all sequences are concatenated on dim 0.
						            q = torch.randn((BATCH, H, D_HEAD), dtype=dtype, device=device, requires_grad=False)
						            k = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device=device, requires_grad=False)
						            v = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device=device, requires_grad=False)

						            fn = lambda: flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens, 1, N_CTX, dropout_p=0, softmax_scale=1.3, causal=False)
						            return triton.testing.do_bench(fn, warmup=warmup, rep=rep)

						        # Previously an unrecognised provider fell through and returned None.
						        raise ValueError(f'unknown provider: {provider!r}')
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    # `test_op` is defined elsewhere in this file (presumably a correctness
						    # check of the sparse kernel against a reference -- verify there).
						    test_op(1, 4, 512, 128, dtype=torch.float16, homo_head=False, backward=support_backward)

						    bench_flash_attention_inference.run(save_path='.', print_data=True)
						    # NOTE(review): the script stops here, so none of the calls below run.
						    # This looks like a debugging shortcut; remove the early exit to re-enable
						    # the remaining checks and the training benchmark.
						    # `raise SystemExit` replaces `exit()`, which is a helper injected by the
						    # `site` module for interactive use and is not guaranteed to exist.
						    raise SystemExit

						    test_op(1, 2, 1024, 64, kernel_block_size=64, sparse_block_size=64,
						            dtype=torch.bfloat16, homo_head=False, backward=support_backward)

						    # Sequence length 224 is not a multiple of the block sizes used here.
						    test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, sparse_block_size=128,
						            kernel_block_size=64, local_blocks=8, vert_stride=8)
						    test_op(3, 2, 2047, 128, homo_head=False, backward=False)

						    test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, kernel_block_size=64)

						    test_op(1, 2, 1024, 128, kernel_block_size=128, sparse_block_size=128, dtype=torch.bfloat16, homo_head=False,
						            backward=support_backward, local_blocks=1, vert_stride=1)

						    test_op(1, 4, 512 + 256, 128, dtype=torch.float16, homo_head=False, backward=support_backward)

						    test_op(2, 4, 8192, 64, homo_head=False, backward=support_backward)
						    test_op(2, 4, 8192, 128, dtype=torch.bfloat16, homo_head=False, backward=support_backward)

						    test_op(3, 2, 2048, 64, homo_head=True, dtype=torch.bfloat16, backward=False)
						    test_op(3, 2, 2048, 64, homo_head=True, backward=support_backward)

						    bench_flash_attention.run(save_path='.', print_data=True)
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 |