"""
    Author: Eric Lin (xihlin)
"""
"""
    ... note(bapatra)::
        This is written as one big file, instead of being split into logical components, because I was running into issues with
        transformers auto-module imports when splitting it across different files. I have tried to keep the logical partitions
        demarcated with comment blocks, but it is not ideal. In the future, it would be good to revisit this and refactor it
        into a more readable file structure.

"""
from typing import TypeVar
from functools import lru_cache
import math
import pytest
import torch
import numpy as np

import triton
import triton.language as tl

import os

import dataclasses

Phi3SmallConfig = TypeVar('Phi3SmallConfig')

# triton 2.0.0: backward fails on A100 for the examples if h_dim=128.

# Done:
#  1. strided access to qkv
#  2. seq len that is not a power of 2
#  3. bf16 (with the May 2023 Triton build)

# TODO:
#  1. wip: support non-contiguous backward; also helps reduce memory allocation in training (q, k, v split)
#  2. block sparse with different BLOCK_M, BLOCK_N?
#  3. for Lq not divisible by BLOCK_M/BLOCK_N, apply the mask to K/V only on the last batch; the mask on Q is still needed.
#     Attempted, but it failed to compile.
#  4. for the 2nd iteration of inference, BLOCK_M=1: how to make things work? K/V may not be divisible by BLOCK_N.
#  5. the inner loop can also be parallelized, either via a bigger num_stages (better) or on different thread-blocks
#     (via m/l and atomic updates, with no communication/sync between blocks)


###########################################################
################### Kernel Parameters #####################
###########################################################

@dataclasses.dataclass
class BlockSparseParams(object):
    block_size: int
    kernel_block_size: int
    num_local_blocks: int
    vert_stride: int
    homo_head_pattern: bool = False

    @classmethod
    def from_config(cls, config: Phi3SmallConfig) -> "BlockSparseParams":
        return cls(
            block_size=config.blocksparse_block_size,
            kernel_block_size=config.blocksparse_triton_kernel_block_size,
            num_local_blocks=config.blocksparse_num_local_blocks,
            vert_stride=config.blocksparse_vert_stride,
            homo_head_pattern=config.blocksparse_homo_head_pattern,
        )
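

# A minimal construction sketch (illustrative only): the values below are
# hypothetical and are not taken from any released Phi3SmallConfig; real
# configs supply them via the `blocksparse_*` attributes read by `from_config`.
def _example_block_sparse_params() -> "BlockSparseParams":
    return BlockSparseParams(
        block_size=64,
        kernel_block_size=64,
        num_local_blocks=16,
        vert_stride=8,
        homo_head_pattern=False,
    )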


###########################################################
###########################################################

###########################################################
################### Utility Functions #####################
###########################################################

# helper functions for the 3D sparse pattern
# these functions are not optimized and are very inefficient; avoid calling them too frequently.
# currently, they are only called within `get_local_strided_sparse_attention_op`, which is cached.
def dense_to_crow_col(x):
    ''' Convert a 2D/3D dense torch tensor (x) to CSR crow/col indexing.
    NOTE: col_indices is padded with -1 so that heads with different numbers of
        non-zero blocks can be stacked into one tensor.
    TODO:
        1. improve efficiency: would this be faster on CPU, or with a custom CUDA kernel?
    '''
    pad = -1
    dim = x.dim()
    assert x.dim() in (2, 3)
    if x.dim() == 2:
        x = x[None]
    x = [xi.to_sparse_csr() for xi in x]
    crows = torch.vstack([xi.crow_indices() for xi in x])
    cols = [xi.col_indices() for xi in x]
    max_cols = max(len(xi) for xi in cols)
    cols = [torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) for xi in cols]
    cols = torch.vstack(cols)
    if dim == 2:
        crows = crows[0]
        cols = cols[0]
    return crows, cols
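

# A tiny self-check sketch for `dense_to_crow_col` (illustrative only, not
# called anywhere): for a 3x3 causal (lower-triangular) mask, the result is the
# standard CSR crow/col indexing; -1 padding only appears for batched inputs
# whose elements have different numbers of non-zeros.
def _example_dense_to_crow_col():
    mask = torch.tril(torch.ones(3, 3))
    crows, cols = dense_to_crow_col(mask)
    assert crows.tolist() == [0, 1, 3, 6]
    assert cols.tolist() == [0, 0, 1, 0, 1, 2]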


def crow_col_to_dense(crows, cols, dtype=torch.float16):
    dim = crows.dim()
    if dim == 1:
        crows = crows[None]
        cols = cols[None]
    device = crows.device
    crows, cols = crows.cpu(), cols.cpu()  # faster on CPU
    shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1)
    x = torch.zeros(shape, dtype=dtype)
    for i in range(shape[0]):
        for j in range(shape[1]):
            x[i, j, cols[i, crows[i, j]:crows[i, j+1]]] = 1
    if dim == 1:
        x = x[0]
    return x.to(device)


def dense_to_ccol_row(x):
    '''Similar to `dense_to_crow_col`, but converts to CSC format.
    '''
    x = x.transpose(-2, -1)
    return dense_to_crow_col(x)


def ccol_row_to_dense(ccol, rows, dtype=torch.float16):
    return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous()
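

# Round-trip sanity sketch (illustrative only, not called anywhere): since
# `dense_to_ccol_row` is just CSR on the transpose, converting to CSC and back
# must reproduce the original per-head block mask.
def _example_ccol_row_round_trip():
    mask = torch.tril(torch.ones(2, 4, 4))  # 2 heads, 4x4 block mask
    ccol, rows = dense_to_ccol_row(mask)
    assert torch.equal(ccol_row_to_dense(ccol, rows, mask.dtype), mask)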


def _get_sparse_attn_mask_homo_head(q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, return_dense=False):
    '''
    :return: a tuple of 3:
        - a tuple of crow_indices and col_indices (CSR format over blocks).
        - the block-level dense mask
        - the token-level dense mask if `return_dense==True` (beware: this can OOM if it is too big), otherwise None
    '''
    with torch.no_grad():
        N_BLOCK = triton.cdiv(N_CTX, BLOCK)
        q_pos = torch.arange(N_BLOCK)[:, None]
        k_pos = torch.arange(N_BLOCK)[None]
        mask_vert_strided = (torch.arange(N_BLOCK) + 1) % vert_stride == 0
        block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).to(device).to(dtype)
        N_BLOCK_Q = triton.cdiv(q_len, BLOCK)
        block_mask_dense_output = block_mask_dense[-N_BLOCK_Q:].contiguous().to_sparse_csr()
    if return_dense:
        mask_dense = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK)))
        causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(mask_dense)[-q_len:]
        mask_dense = mask_dense[-q_len:, :N_CTX] * causal_mask
        return (block_mask_dense_output.crow_indices(), block_mask_dense_output.col_indices()), block_mask_dense, mask_dense
    else:
        return (block_mask_dense_output.crow_indices(), block_mask_dense_output.col_indices()), block_mask_dense, None
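

# A small illustration of the homogeneous local+strided pattern (sketch only,
# not called anywhere). With 8 blocks, local_blocks=2 and vert_stride=4,
# block-row 5 attends to block-columns {3, 4, 5}: the two nearest diagonals
# plus the strided column 3 (since (3 + 1) % 4 == 0).
def _example_homo_head_pattern():
    _, block_mask, _ = _get_sparse_attn_mask_homo_head(
        8 * 16, 8 * 16, torch.float32, 'cpu', BLOCK=16, local_blocks=2, vert_stride=4)
    assert block_mask[5].nonzero().flatten().tolist() == [3, 4, 5]
    return block_mask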


def _get_sparse_attn_mask(n_heads, q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, homo_head=True, return_dense=False):
    '''
    :return: a tuple of 3:
        - a tuple of crow_indices and col_indices (CSR format over blocks).
        - the block-level dense mask
        - the token-level dense mask if `return_dense==True` (beware: this can OOM if it is too big), otherwise None
    '''
    if homo_head:
        with torch.no_grad():
            (crow, col), block_mask_dense, mask_dense = _get_sparse_attn_mask_homo_head(q_len, N_CTX, dtype, device, BLOCK, local_blocks, vert_stride, return_dense)
            crow = crow[None].expand(n_heads, crow.shape[0])
            col = col[None].expand(n_heads, col.shape[0])
            if return_dense:
                mask_dense = mask_dense[None].expand(n_heads, *mask_dense.shape)
            return (crow, col), block_mask_dense, mask_dense

    with torch.no_grad():
        N_BLOCK = triton.cdiv(N_CTX, BLOCK)
        q_pos = torch.arange(N_BLOCK)[None, :, None]
        k_pos = torch.arange(N_BLOCK)[None, None]
        head_sliding_step = max(1, int(vert_stride / n_heads))  # if vert_stride <= n_heads, rotate the strided pattern across heads
        mask_vert_strided = [(torch.arange(N_BLOCK) + h * head_sliding_step + 1) % vert_stride == 0 for h in range(n_heads)]
        mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
        block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).to(device).to(dtype)
        N_BLOCK_Q = triton.cdiv(q_len, BLOCK)
        block_mask_dense_output = block_mask_dense[:, -N_BLOCK_Q:]
    if return_dense:
        mask_dense = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK)))
        causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(mask_dense)[-q_len:]
        mask_dense = mask_dense[..., -q_len:, :N_CTX] * causal_mask[None]
        return dense_to_crow_col(block_mask_dense_output), block_mask_dense, mask_dense
    else:
        return dense_to_crow_col(block_mask_dense_output), block_mask_dense, None


def get_sparse_attn_mask(q, N_CTX, *args, **kwargs):
    return _get_sparse_attn_mask(q.size(1), q.size(2), N_CTX, q.dtype, q.device, *args, **kwargs)
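

# Usage sketch for `get_sparse_attn_mask` (illustrative only; the shapes are
# hypothetical and a CUDA device is assumed). The layout tuple feeds the
# Triton kernels, while the dense masks are mainly for verification.
def _example_get_sparse_attn_mask():
    q = torch.empty(1, 8, 1024, 64, dtype=torch.float16, device='cuda')
    (crow, col), block_mask, dense_mask = get_sparse_attn_mask(
        q, 1024, BLOCK=128, local_blocks=4, vert_stride=4,
        homo_head=False, return_dense=True)
    # crow/col index 128x128 blocks (one row of indices per head);
    # dense_mask is the per-head 1024x1024 token-level mask.
    return (crow, col), block_mask, dense_mask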

###########################################################
###########################################################

###########################################################
###################### Training Kernels ###################
###########################################################

# TODO: only apply the load/store mask on the last iteration when EVEN_N_BLOCK; useful for the 1st iteration of inference.
#    An experiment applying this inside the loop failed.
#    Another idea: mask only on stores? Load even out of bounds (would that cause an illegal-access error)?
@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    layout_crow_ptr,
    layout_col_ptr,
    layout_crow_stride_h, layout_crow_stride_m,
    layout_col_stride_h, layout_col_stride_m,
    TMP, L, M,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug. TMP, L, M are assumed to have contiguous layouts
    Out,
    stride_qz, stride_qh, stride_qm, stride_qd,
    stride_kz, stride_kh, stride_kn, stride_kd,
    stride_vz, stride_vh, stride_vn, stride_vd,
    stride_oz, stride_oh, stride_om, stride_od,
    Z, H, N_CTX,
    PAST_LEN,
    Q_ROUNDED_LEN,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    EVEN_M_BLOCK: tl.constexpr,
    EVEN_N_BLOCK: tl.constexpr,
    INFERENCE: tl.constexpr,
    NUM_DBLOCKS: tl.constexpr,
):
    Q_LEN = N_CTX - PAST_LEN
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_h = off_hz % H
    off_z = off_hz // H
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd
    # off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd
    off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd
    off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # initialize pointer to m and l
    t_ptrs = TMP + off_hz * Q_ROUNDED_LEN + offs_m
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    if NUM_DBLOCKS >= 2:
        acc2 = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # load q: it will stay in SRAM throughout
    if EVEN_M_BLOCK:
        q = tl.load(q_ptrs)
        if NUM_DBLOCKS >= 2:
            q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd)
    else:
        q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN)
        if NUM_DBLOCKS >= 2:
            q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=offs_m[:, None] < Q_LEN)

    layout_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + start_m * layout_crow_stride_m
    start_l = tl.load(layout_ptr).to(tl.int32)
    end_l = tl.load(layout_ptr + layout_crow_stride_m).to(tl.int32)

    # loop over k, v and update accumulator
    for col_idx_idx in range(start_l, end_l):
        col_idx = tl.load(layout_col_ptr + off_h * layout_col_stride_h + col_idx_idx * layout_col_stride_m).to(tl.int32)
        start_n = col_idx * BLOCK_N
        # -- compute qk ----
        if EVEN_N_BLOCK:
            k = tl.load(k_ptrs + start_n * stride_kn)
        else:
            k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_n[None, :] + start_n < N_CTX)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)

        if NUM_DBLOCKS >= 2:
            if EVEN_N_BLOCK:
                k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd, mask=offs_n[None, :] + start_n < N_CTX)
            qk += tl.dot(q2, k)

        qk *= sm_scale
        qk += tl.where(offs_m[:, None] + PAST_LEN >= (start_n + offs_n[None, :]), 0, float('-inf'))
        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        # tl.store(t_ptrs, acc_scale)
        # acc_scale = tl.load(t_ptrs)  # BUG: have to store and immediately load
        acc = acc * acc_scale[:, None]
        if NUM_DBLOCKS >= 2:
            acc2 = acc2 * acc_scale[:, None]
        p = p.to(Q.dtype.element_ty)
        # update acc
        if EVEN_N_BLOCK:
            v = tl.load(v_ptrs + start_n * stride_vn)
        else:
            v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_n[:, None] + start_n < N_CTX)
        acc += tl.dot(p, v)

        if NUM_DBLOCKS >= 2:
            if EVEN_N_BLOCK:
                v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd, mask=offs_n[:, None] + start_n < N_CTX)
            acc2 += tl.dot(p, v)

        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # rematerialize offsets to save registers
    # start_m = tl.program_id(0)
    # offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    if not INFERENCE:
        l_ptrs = L + off_hz * N_CTX + offs_m
        m_ptrs = M + off_hz * N_CTX + offs_m
        if EVEN_M_BLOCK:
            tl.store(l_ptrs, l_i)
            tl.store(m_ptrs, m_i)
        else:
            tl.store(l_ptrs, l_i,  mask=offs_m < Q_LEN)
            tl.store(m_ptrs, m_i,  mask=offs_m < Q_LEN)
    # initialize pointers to output
    # offs_n = tl.arange(0, BLOCK_DMODEL)
    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc,  mask=offs_m[:, None] < Q_LEN)
    if NUM_DBLOCKS >= 2:
        tl.store(out_ptrs + BLOCK_DMODEL * stride_od, acc2,  mask=offs_m[:, None] < Q_LEN)
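

# The streaming-softmax recurrence that `_fwd_kernel` implements, written out
# in plain PyTorch as an illustrative sketch (it is not called anywhere, and
# causal masking inside each block is omitted for brevity). `acc` is kept
# normalized at every step, so the final value equals
# softmax(q k^T * sm_scale) @ v over the visited blocks.
def _online_softmax_reference(q_block, k_blocks, v_blocks, sm_scale):
    m = torch.full((q_block.size(0),), float('-inf'))        # running row max
    l = torch.zeros(q_block.size(0))                         # running normalizer
    acc = torch.zeros(q_block.size(0), v_blocks[0].size(1))  # normalized output
    for k, v in zip(k_blocks, v_blocks):
        qk = (q_block @ k.t()).float() * sm_scale
        m_ij = qk.max(dim=1).values
        p = torch.exp(qk - m_ij[:, None])
        l_ij = p.sum(dim=1)
        m_new = torch.maximum(m, m_ij)
        alpha, beta = torch.exp(m - m_new), torch.exp(m_ij - m_new)
        l_new = alpha * l + beta * l_ij
        # rescale the old accumulator and add the newly weighted V block
        acc = acc * (l / l_new * alpha)[:, None] + (p * (beta / l_new)[:, None]) @ v
        m, l = m_new, l_new
    return acc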


## backward
@triton.heuristics(
    {
        'EVEN_M_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_M'] == 0,
    }
)
@triton.jit
def _bwd_preprocess(
    Out, DO, L, # assume contiguous for Out, DO, L, NewDO, Delta layout.
    NewDO, Delta,
    N_CTX,
    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
    EVEN_M_BLOCK: tl.constexpr,
):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_d = tl.arange(0, D_HEAD)
    # load
    if EVEN_M_BLOCK:
        o = tl.load(Out + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32)
        do = tl.load(DO + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32)
    else:
        o = tl.load(Out + off_m[:, None] * D_HEAD + off_d[None, :], mask=off_m[:, None] < N_CTX).to(tl.float32)
        do = tl.load(DO + off_m[:, None] * D_HEAD + off_d[None, :], mask=off_m[:, None] < N_CTX).to(tl.float32)
    denom = tl.load(L + off_m).to(tl.float32)
    # compute
    do = do / denom[:, None]
    delta = tl.sum(o * do, axis=1)
    # write-back
    if EVEN_M_BLOCK:
        tl.store(NewDO + off_m[:, None] * D_HEAD + off_d[None, :], do)
    else:
        tl.store(NewDO + off_m[:, None] * D_HEAD + off_d[None, :], do,  mask=off_m[:, None] < N_CTX)
    tl.store(Delta + off_m, delta)
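

# What `_bwd_preprocess` computes, as a plain PyTorch sketch (illustrative
# only, not called anywhere): rescale the output gradient by the saved softmax
# normalizer `l`, and form delta = sum(o * do_scaled) per row; both quantities
# are consumed by `_bwd_kernel`.
def _bwd_preprocess_reference(o, do, l):
    do_scaled = do.float() / l.float()[..., None]
    delta = (o.float() * do_scaled).sum(dim=-1)
    return do_scaled, delta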


# Does not support unequal seqlen(q) and seqlen(k)
@triton.heuristics(
    {
        'EVEN_M_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_M'] == 0,
        'EVEN_N_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_N'] == 0,
    }
)
@triton.jit
def _bwd_kernel(
    Q, K, V, sm_scale,
    layout_ccol_ptr,
    layout_row_ptr,
    layout_ccol_stride_h, layout_ccol_stride_m,
    layout_row_stride_h, layout_row_stride_m,
    Out, DO,  # assume contiguous: Out, DO, DQ, DK, DV, L, M, D; seq(q) == seq(k), with stride_oz, stride_oh, stride_om, stride_od,
    DQ, DK, DV,
    L, M,
    D,
    stride_qz, stride_qh, stride_qm, stride_qd,
    stride_kz, stride_kh, stride_kn, stride_kd,
    stride_vz, stride_vh, stride_vn, stride_vd,
    stride_oz, stride_oh, stride_om, stride_od,
    # stride_dz, stride_dh, stride_dm, stride_dd,
    Z, H, N_CTX,
    num_block,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    EVEN_M_BLOCK: tl.constexpr,
    EVEN_N_BLOCK: tl.constexpr,
    NUM_DBLOCKS: tl.constexpr,
):
    start_n = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    # offset pointers for batch/head
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh
    DO += off_z * stride_oz + off_h * stride_oh
    DQ += off_z * stride_oz + off_h * stride_oh
    DK += off_z * stride_oz + off_h * stride_oh
    DV += off_z * stride_oz + off_h * stride_oh
    # Looks like this loop could be parallelized
    # for start_n in range(0, num_block):

    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    # initialize pointers to value-like data
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd)
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd)

    # pointer to row-wise quantities in value-like data
    D_ptrs = D + off_hz * N_CTX
    m_ptrs = M + off_hz * N_CTX
    # initialize dv and dk
    dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
    # k and v stay in SRAM throughout
    if EVEN_N_BLOCK:
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
    else:
        k = tl.load(k_ptrs, mask=offs_n[:, None] < N_CTX)
        v = tl.load(v_ptrs, mask=offs_n[:, None] < N_CTX)

    if NUM_DBLOCKS >= 2:
        dv2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
        dk2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
        if EVEN_N_BLOCK:
            k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd)
            v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd)
        else:
            k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd, mask=offs_n[:, None] < N_CTX)
            v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd, mask=offs_n[:, None] < N_CTX)

    # loop over rows

    layout_ptr = layout_ccol_ptr + off_h * layout_ccol_stride_h + start_n * layout_ccol_stride_m
    start_l = tl.load(layout_ptr).to(tl.int32)
    end_l = tl.load(layout_ptr + layout_ccol_stride_m).to(tl.int32)

    for row_idx_idx in range(start_l, end_l):
        row_idx = tl.load(layout_row_ptr + off_h * layout_row_stride_h + row_idx_idx * layout_row_stride_m).to(tl.int32)
        start_m = row_idx * BLOCK_M

        # offs_qm = start_m + tl.arange(0, BLOCK_M)
        offs_m_curr = start_m + offs_m
        q_ptrs =   Q + (offs_m_curr[:, None] * stride_qm + offs_d[None, :] * stride_qd)
        do_ptrs = DO + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od)
        dq_ptrs = DQ + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od)

        # load q, k, v, do on-chip
        if EVEN_M_BLOCK:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < N_CTX)
        # re-compute p = softmax(qk, dim=-1).T
        # NOTE: `do` is pre-divided by `l`; no normalization here
        qk = tl.dot(q, tl.trans(k))

        if NUM_DBLOCKS >= 2:
            if EVEN_M_BLOCK:
                q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd)
            else:
                q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=offs_m_curr[:, None] < N_CTX)
            qk += tl.dot(q2, tl.trans(k2))

        qk += tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), 0, float('-inf'))

        if EVEN_M_BLOCK:
            m = tl.load(m_ptrs + offs_m_curr)
        else:
            m = tl.load(m_ptrs + offs_m_curr, mask=offs_m_curr < N_CTX)
        p = tl.exp(qk * sm_scale - m[:, None])

        # compute dv
        if EVEN_M_BLOCK:
            do = tl.load(do_ptrs)
        else:
            do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < N_CTX)

        if NUM_DBLOCKS >= 2:
            if EVEN_M_BLOCK:
                do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od)
            else:
                do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od, mask=offs_m_curr[:, None] < N_CTX)

        dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)

        if NUM_DBLOCKS >= 2:
            dv2 += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do2)

        # compute dp = dot(v, do)
        if EVEN_M_BLOCK:
            Di = tl.load(D_ptrs + offs_m_curr)
        else:
            Di = tl.load(D_ptrs + offs_m_curr, mask=offs_m_curr < N_CTX)
        dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
        dp += tl.dot(do, tl.trans(v))

        if NUM_DBLOCKS >= 2:
            dp += tl.dot(do2, tl.trans(v2))

        # compute ds = p * dp * sm_scale (delta was already subtracted via Di when initializing dp)
        ds = p * dp * sm_scale
        # compute dk = dot(ds.T, q)
        dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
        if NUM_DBLOCKS >= 2:
            dk2 += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q2)

        # compute dq
        dq = tl.dot(ds.to(Q.dtype.element_ty), k)
        if EVEN_M_BLOCK:
            tl.atomic_add(dq_ptrs, dq)
        else:
            tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < N_CTX)

        if NUM_DBLOCKS >= 2:
            dq2 = tl.dot(ds.to(Q.dtype.element_ty), k2)
            dq_ptrs2 = dq_ptrs + BLOCK_DMODEL * stride_od
            if EVEN_M_BLOCK:
                tl.atomic_add(dq_ptrs2, dq2)
            else:
                tl.atomic_add(dq_ptrs2, dq2, mask=offs_m_curr[:, None] < N_CTX)

    # write-back
    dv_ptrs = DV + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od)
    dk_ptrs = DK + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od)
    if EVEN_N_BLOCK:
        tl.store(dv_ptrs, dv)
        tl.store(dk_ptrs, dk)
    else:
        tl.store(dv_ptrs, dv, mask=offs_n[:, None] < N_CTX)
        tl.store(dk_ptrs, dk, mask=offs_n[:, None] < N_CTX)

    if NUM_DBLOCKS >= 2:
        dv_ptrs2 = dv_ptrs + BLOCK_DMODEL * stride_od
        dk_ptrs2 = dk_ptrs + BLOCK_DMODEL * stride_od
        if EVEN_N_BLOCK:
            tl.store(dv_ptrs2, dv2)
            tl.store(dk_ptrs2, dk2)
        else:
            tl.store(dv_ptrs2, dv2, mask=offs_n[:, None] < N_CTX)
            tl.store(dk_ptrs2, dk2, mask=offs_n[:, None] < N_CTX)



def _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N, num_warps=None, num_stages=1, inference=None, out=None):
    '''
    :param q, k, v: [batch, n_heads, seq_len, model_dim]. The length of q is allowed to differ from that of k/v.
    :param layout_crow_indices, layout_col_indices: same as CSR.crow_indices and CSR.col_indices, used to represent a sparse tensor.
        Each element represents a block: all elements in a block are either attended to, or not attended to at all.
    '''
    assert q.shape[-1] == k.shape[-1] == v.shape[-1]
    assert k.shape[2] == v.shape[2]
    o = out if out is not None else torch.empty_like(q).contiguous()
    grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1])

    q_rounded_len = grid[0] * BLOCK_M
    tmp = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32)

    if inference is None:
        inference = (not q.requires_grad) and (not k.requires_grad)  and (not v.requires_grad)

    if inference:
        L, m = tmp, tmp  # no need to create new tensors
    else:
        L = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32)
        m = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32)

    if layout_col_indices.dim() == 1:
        layout_crow_indices = layout_crow_indices[None].expand(q.shape[1], -1)
        layout_col_indices = layout_col_indices[None].expand(q.shape[1], -1)

    assert q.shape[-1] in [64, 128]
    BLOCK_DMODEL = 64

    if num_warps is None:
        MIN_D = min(BLOCK_M, BLOCK_N, BLOCK_DMODEL)
        num_warps = max(1, 2 ** int(math.log2(MIN_D / 16)))
        # print(f'> {BLOCK_M=}, {BLOCK_N=}, {BLOCK_DMODEL=}, {num_warps=}, {num_stages=}')
    else:
        assert math.log2(num_warps) % 1 == 0, f'''"num_warps" should be a power of 2, but got {num_warps}.'''

    ## For debugging:
    # print(f'>> {q.shape=}, {k.shape=}, {BLOCK_M=}, {BLOCK_N=}, {num_warps=}, {BLOCK_DMODEL=}, {q.stride()=}, {k.stride()=}')
    # print(f'>> {layout_crow_indices=}\n{layout_col_indices=}\n {layout_crow_indices.stride()=}, {layout_crow_indices.stride()=}')
    # print(f'> {q.shape=}, {k.shape=}, {layout_crow_indices.shape}, {layout_col_indices.shape}, {layout_crow_indices.stride()}, \
    #   {layout_col_indices.stride()}, {layout_crow_indices=}, {layout_col_indices=}')

    with torch.cuda.device(q.device.index):
        _fwd_kernel[grid](
            q, k, v, sm_scale,
            layout_crow_indices,
            layout_col_indices,
            layout_crow_indices.stride(0), layout_crow_indices.stride(1),
            layout_col_indices.stride(0), layout_col_indices.stride(1),
            tmp, L, m,
            o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1], k.shape[2],
            k.shape[2] - q.shape[2],
            q_rounded_len,
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
            BLOCK_DMODEL=BLOCK_DMODEL,
            EVEN_M_BLOCK=q.shape[2] % BLOCK_M == 0,
            EVEN_N_BLOCK=k.shape[2] % BLOCK_N == 0,
            INFERENCE=inference,
            NUM_DBLOCKS=q.shape[-1] // BLOCK_DMODEL,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    if inference:
        L, m = None, None

    ctx.save_for_backward(q, k, v, o, L, m, layout_crow_indices, layout_col_indices)
    ctx.BLOCK_M = BLOCK_M
    ctx.BLOCK_N = BLOCK_N
    ctx.BLOCK_DMODEL = BLOCK_DMODEL
    # ctx.BLOCK = BLOCK
    ctx.grid = grid
    ctx.sm_scale = sm_scale
    ctx.num_warps = num_warps
    ctx.num_stages = num_stages
    return o


def _backward(ctx, do, layout_ccol_indices, layout_row_indices, dq=None, dk=None, dv=None):
    # q, k, v, o, l, m = ctx.saved_tensors
    q, k, v, o, l, m, layout_crow_indices, layout_col_indices = ctx.saved_tensors

    ## the following is too slow to do online, so get it from the inputs, which are cached.
    # layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(ctx.layout_crow_indices, ctx.layout_col_indices))
    # layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(layout_crow_indices, layout_col_indices))

    if not do.is_contiguous():
        do = do.contiguous()
        ## for debugging
        # print(f'----> do is not contiguous: {do.stride()=}')
        # raise ValueError(f'>>>> output grad is not contiguous: {do.stride()=}')

    if not o.is_contiguous():
        # TODO: currently only work with contiguous q/k/v.
        raise ValueError(f'--> output is not contiguous: {o.stride()=}. This is maybe caused by q/k/v not being contiguous.')


    if layout_ccol_indices.dim() == 1:
        layout_ccol_indices = layout_ccol_indices[None].expand(q.shape[1], -1)
        layout_row_indices = layout_row_indices[None].expand(q.shape[1], -1)

    # do = do.contiguous()
    dq = dq if dq is not None else torch.zeros_like(q, dtype=torch.float32)
    dk = dk if dk is not None else torch.empty_like(k)
    dv = dv if dv is not None else torch.empty_like(v)
    do_scaled = torch.empty_like(do)
    delta = torch.empty_like(l)

    assert o.stride() == dq.stride() == dk.stride() == dv.stride() == do_scaled.stride()

    _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
        o, do, l,
        do_scaled, delta,
        k.shape[2],
        BLOCK_M=ctx.BLOCK_M, D_HEAD=q.shape[-1],
    )

    grid = (triton.cdiv(q.shape[2], ctx.BLOCK_N), ctx.grid[1])

    _bwd_kernel[grid](
        q, k, v, ctx.sm_scale,
        layout_ccol_indices,
        layout_row_indices,
        layout_ccol_indices.stride(0), layout_ccol_indices.stride(1),
        layout_row_indices.stride(0), layout_row_indices.stride(1),
        o, do_scaled,
        dq, dk, dv,
        l, m,
        delta,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),
        q.shape[0], q.shape[1], q.shape[2],
        ctx.grid[0],
        BLOCK_M=ctx.BLOCK_M,
        BLOCK_N=ctx.BLOCK_N,
        BLOCK_DMODEL=ctx.BLOCK_DMODEL,
        NUM_DBLOCKS=q.shape[-1] // ctx.BLOCK_DMODEL,
        num_warps=ctx.num_warps,
        num_stages=1,
    )
    return dq, dk, dv, None, None, None


class _sparse_attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
        BLOCK = 128
        # shape constraints
        return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK, BLOCK)

    @staticmethod
    def backward(ctx, do):
        # q, k, v, o, l, m = ctx.saved_tensors
        q, k, v, o, l, m, layout_crow_indices, layout_col_indices = ctx.saved_tensors
        # TODO: the following is very inefficient.
        # layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(ctx.layout_crow_indices, ctx.layout_col_indices))
        layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(layout_crow_indices, layout_col_indices))
        return _backward(ctx, do, layout_ccol_indices, layout_row_indices)



# suppressed (currently unused)
class _sparse_attention_inference(_sparse_attention):
    # TODO: does not work for now: BLOCK_M=1 is not possible, as shapes for tl.dot cannot be smaller than 16.
    @staticmethod
    def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
        BLOCK = 128
        return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, 1, BLOCK)



def sparse_attention_factory(BLOCK_M=128, BLOCK_N=128, **kwargs):
    class _sparse_attention_config(_sparse_attention):
        @staticmethod
        def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
            # shape constraints
            return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N,
                            **kwargs
                        )
    return _sparse_attention_config.apply
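

# Usage sketch for the factory (illustrative only; hypothetical shapes, and a
# CUDA device is assumed). The layout tensors come from `dense_to_crow_col`
# applied to a per-head block mask on the same device as q/k/v.
def _example_sparse_attention_factory():
    q = torch.randn(1, 2, 256, 64, dtype=torch.float16, device='cuda')
    k, v = torch.randn_like(q), torch.randn_like(q)
    block_mask = torch.tril(torch.ones(2, 2, 2, device='cuda'))  # 2 heads, 2x2 blocks of 128 tokens
    layout_crow, layout_col = dense_to_crow_col(block_mask)
    attn = sparse_attention_factory()  # BLOCK_M = BLOCK_N = 128
    return attn(q, k, v, layout_crow, layout_col, 1. / math.sqrt(q.size(-1)))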


@lru_cache(maxsize=8)
def get_local_strided_sparse_attention_op(
        n_heads: int,
        max_seq_len:int,
        sparse_block_size: int=128,
        local_blocks: int=4,
        vert_stride: int=4,
        homo_head: bool=False,
        dtype=torch.bfloat16,
        device='cuda',
        active_head_range=None,
        verbose=True,
        **kwargs):
    '''
    :param n_heads: total number of attention heads (regardless of tensor/model parallelism)
    :param max_seq_len: max sequence length. Must be greater than or equal to the length of the sequences.
    :param sparse_block_size: sparse block size. Defaults to 128.
    :param local_blocks: number of nearest blocks to attend to. Defaults to 4, i.e., attend to the previous 4 x block_size tokens.
    :param vert_stride: stride of the vertical (strided) part of the pattern: every `vert_stride`-th block column is attended to. Defaults to 4.
    :param homo_head: whether all heads share the same pattern.
    :param active_head_range: tuple of the start and end indices of the heads, e.g., (8, 16). Defaults to using all heads.
                              Mainly for tensor/model parallelism, where heads are split across different GPUs.
    '''

    if verbose:
        print((f'> new block_sparse_attn op constructed with config: '
            f'{n_heads=}, {max_seq_len=}, {sparse_block_size=}, {local_blocks=}, '
            f'{vert_stride=}, {homo_head=}, {active_head_range=}, {kwargs=}'))
    # assert math.log2(max_seq_len) % 2 == 0, f"max_seq_len should be power of 2 to be more efficient"
    _, block_sparse_pattern, _ = _get_sparse_attn_mask(n_heads, max_seq_len, max_seq_len, dtype, device,
                                                       BLOCK=sparse_block_size, local_blocks=local_blocks,
                                                       vert_stride=vert_stride, homo_head=homo_head,
                                                       return_dense=False)
    if (not homo_head) and (active_head_range is not None):
        assert isinstance(active_head_range, tuple)
        assert len(active_head_range) == 2, '"active_head_range" should be a tuple of start/end index of the heads.'
        h_start, h_end = active_head_range
        block_sparse_pattern = block_sparse_pattern[h_start:h_end]
    # print(block_sparse_pattern)
    return get_sparse_attn_op(block_sparse_pattern, sparse_block_size, **kwargs)
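

# End-to-end usage sketch (illustrative only; hypothetical shapes, and a CUDA
# device is assumed). The op is cached on its arguments, so repeated calls with
# the same configuration reuse the same layout.
def _example_local_strided_op():
    q = torch.randn(1, 8, 2048, 128, dtype=torch.bfloat16, device='cuda')
    k, v = torch.randn_like(q), torch.randn_like(q)
    attn = get_local_strided_sparse_attention_op(
        n_heads=8, max_seq_len=2048, sparse_block_size=128,
        local_blocks=4, vert_stride=4, homo_head=False,
        dtype=q.dtype, device=q.device)
    return attn(q, k, v, 1. / math.sqrt(q.size(-1)))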


def get_sparse_attn_op(
        sparse_pattern: torch.tensor,
        sparse_block_size: int=128,
        kernel_block_size=128,
        qkv_format='q,k,v',
          **kwargs):
    '''
    Create a block-sparse op with a fixed layout. This avoids the need to create a CSR layout and convert it to CSC layout every time,
        which is very inefficient (it uses Python loops on CPU; PyTorch 1.13 supports CSR->CSC conversion, which may help).

    :param sparse_pattern: sparse pattern of the blocks. Should be `num_blocks(q) x num_blocks(k)` or `n_heads x num_blocks x num_blocks`.
        This tensor should have lower-triangular matrices on the last 2 dimensions for causal attention.
    :param sparse_block_size: sparse block size. Defaults to 128.
    :param kernel_block_size: the tile/block size used to launch a Triton kernel instance. If None, `sparse_block_size` is used.
    :param qkv_format: choices=['q,k,v', 'q, kv', 'qkv'], i.e., separate q/k/v, kv packed, or qkv packed. Currently, only 'q,k,v' is supported.

    :param kwargs: keyword arguments passed to `_forward`
    '''
    # assert qkv_format in ('q,k,v', 'q, kv', 'qkv')  # to save from running `concat` at forward/backward

    assert qkv_format == 'q,k,v'

    if kernel_block_size is None:
        kernel_block_size = sparse_block_size
    else:
        assert sparse_block_size % kernel_block_size == 0, f"The sparse block size must be a multiple of {kernel_block_size}."
        assert kernel_block_size >= 16 and math.log2(kernel_block_size) % 1 == 0, f"kernel_block_size must be a power of 2 and at least 16, but {kernel_block_size} was given"


        # print(f'>> {sparse_pattern.shape=}')
        # print(f'{sparse_pattern=}')
        if sparse_block_size // kernel_block_size > 1:
            _mul = sparse_block_size // kernel_block_size
            # need to consider if block_m and block_n are different
            sparse_pattern = torch.kron(sparse_pattern, sparse_pattern.new_ones(_mul, _mul))
            num_sparse_blocks = sparse_pattern.size(-1)
            block_causal_mask = torch.arange(0, num_sparse_blocks)[:, None] >= torch.arange(0, num_sparse_blocks)[None]
            sparse_pattern *= block_causal_mask.type_as(sparse_pattern)
            # print(f'>> after: {sparse_pattern.shape=}')
            # print(f'{sparse_pattern=}')

    BLOCK_N = kernel_block_size
    NUM_BLOCK =  sparse_pattern.size(-1)
    MAX_SEQ_LEN = kernel_block_size * NUM_BLOCK

    grand_layout_crow_indices, grand_layout_col_indices = dense_to_crow_col(sparse_pattern)
    # sparse csc layout for backward
    grand_layout_ccol_indices, grand_layout_row_indices = dense_to_ccol_row(sparse_pattern)


    # Cache the GPU backward layout. Limit the cache size to avoid OOM over time.
    # For inference, only one entry needs to be cached, as the sequence length always increases;
    # the cache is therefore reconstructed every `block_size` steps.
    # For training/finetuning, the size is set to 8 to increase the cache hit rate.
    # For a given input, block_len is the same across layers, so the cache is very effective.

    max_cache_size = 1 if kwargs.get('inference', False) else 8

    @lru_cache(maxsize=max_cache_size)
    def get_backward_layout_by_block_len(block_len):
        assert block_len <= NUM_BLOCK
        if block_len == NUM_BLOCK:
            return (grand_layout_ccol_indices, grand_layout_row_indices)
        return dense_to_ccol_row(sparse_pattern[..., :block_len, :block_len])

    # for debugging
    # if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
    #     print(f'> {sparse_pattern.cpu().tolist()=}')
    #     print('----')
    #     print(f'> {grand_layout_crow_indices.cpu().tolist()=}\n{grand_layout_col_indices.cpu().tolist()=}')


    # q, k, v separated
    class _q_k_v_sparse_attention(torch.autograd.Function):
        @staticmethod
        def forward(ctx, q, k, v, sm_scale):
            # assert q.shape[2] == 1 or q.shape[2] == k.shape[2]
            # shape constraints
            MIN_BLOCK_SIZE = 16
            assert BLOCK_N >= MIN_BLOCK_SIZE
            BLOCK_M = 16 if q.shape[2] <= 16 else BLOCK_N  # BLOCK_M has to be power of 2

            # this following code only works for causal attention
            K_BLOCKS = triton.cdiv(k.shape[2],  kernel_block_size)
            # Q_START_BLOCKS = K_BLOCKS - 1 if q.shape[2] == 1 else 0
            Q_START_BLOCKS = K_BLOCKS - triton.cdiv(q.shape[2], BLOCK_N)
            # print(Q_START_BLOCKS, K_BLOCKS)

            layout_crow_indices = grand_layout_crow_indices[..., Q_START_BLOCKS:K_BLOCKS+1]
            layout_col_indices = grand_layout_col_indices
            # print(BLOCK_M, BLOCK_N, Q_START_BLOCKS, K_BLOCKS+1, layout_crow_indices, layout_col_indices)

            return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N,
                            **kwargs
                        )
        @staticmethod
        def backward(ctx, do):
            q, k = ctx.saved_tensors[:2]
            assert q.shape[2] == k.shape[2], '> currently backward can only be done if q, k have same length. Contact @EricLin if you need it.'
            # assume q, k have same length
            block_len = triton.cdiv(do.shape[2], kernel_block_size)
            backward_layout = get_backward_layout_by_block_len(block_len)
            return _backward(ctx, do, *backward_layout)[:4]


    def _q_k_v_sparse_attention_fn(*args):
        return _q_k_v_sparse_attention.apply(*args)

    _q_k_v_sparse_attention_fn.sparse_pattern = sparse_pattern
    _q_k_v_sparse_attention_fn.grand_layout_crow_indices = grand_layout_crow_indices
    _q_k_v_sparse_attention_fn.grand_layout_col_indices = grand_layout_col_indices
    _q_k_v_sparse_attention_fn.grand_layout_ccol_indices = grand_layout_ccol_indices
    _q_k_v_sparse_attention_fn.grand_layout_row_indices = grand_layout_row_indices

    return _q_k_v_sparse_attention_fn
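

# The block-subdivision step inside `get_sparse_attn_op`, isolated as an
# illustrative sketch (not called anywhere): splitting one sparse block into
# smaller kernel tiles replicates the block mask with a Kronecker product,
# then re-applies causality at the tile level.
def _example_kernel_block_subdivision():
    pattern = torch.tril(torch.ones(2, 2))               # 2x2 sparse blocks
    tiles = torch.kron(pattern, pattern.new_ones(2, 2))  # 4x4 kernel tiles
    n = tiles.size(-1)
    causal = (torch.arange(n)[:, None] >= torch.arange(n)[None]).type_as(tiles)
    return tiles * causal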

###########################################################
###########################################################

###########################################################
################ Inference Kernels ########################
###########################################################

def blocksparse_flash_attn_padded_fwd(
    q, k, v, # (batch, tokens, n_heads, head_size)
    sm_scale,
    sparse_layout,
    *,
    left_paddings = None,
    seqlens = None,
    block_size = 64,
    max_seqlen = None
):
    '''
    q, k, v: (batch, tokens, n_heads/n_kv_heads, head_size)
    left_paddings: (batch, ), number of left paddings for each sample.
    seqlens: can be used to specify right padding. No need to specify if left_paddings is used.
    '''
    batches, q_len, n_heads, head_size = q.shape
    _, k_len, n_kv_heads, _ = k.shape


    assert q.dim() == k.dim() == v.dim() == 4
    assert q.size(2) % k.size(2) == 0
    assert q.size(0) == k.size(0) and q.size(3) == k.size(3)
    assert k.shape == v.shape # TODO: allow diff head_size for k, v
    assert q_len == 1 or q_len == k_len, \
        'q length can only be 1 (for decoding) or the same as k length (for prefilling).'

    q_k_ratio = q.size(2) // k.size(2)

    if max_seqlen:
        assert k.size(1) <= max_seqlen, f'k has seqlen {k.size(1)} while max sequence length is set to {max_seqlen}.'

    # padded positions always produce zero output; a little slower than using empty
    out = q.new_zeros(q.shape)

    layout_crow_indices, layout_col_indices = sparse_layout
    block_d = triton.next_power_of_2(head_size)

    if left_paddings is not None:
        assert left_paddings.shape == (batches,)
        k_batch_starts = left_paddings.to(q.device, dtype=torch.int32).contiguous()
    else:
        k_batch_starts = torch.zeros((batches,), dtype=torch.int32, device=q.device)

    if seqlens is not None:
        k_batch_ends = k_batch_starts + seqlens.type_as(k_batch_starts)
        assert k_batch_ends.max() <= k_len, 'seqlens (+ left_paddings, if any) exceed the k length.'
    else:
        k_batch_ends = torch.zeros_like(k_batch_starts) + k_len

    if q_len == 1:
        q_batch_starts = torch.zeros_like(k_batch_starts)
        q_batch_ends = q_batch_starts + 1
    else:
        q_batch_starts = k_batch_starts
        q_batch_ends = k_batch_ends

    # switch to CPU to avoid launching too many kernels when iterating
    q_lens = (q_batch_ends - q_batch_starts).cpu()
    n_blocks = (q_lens + block_size - 1) // block_size

    q_batch_ids = torch.tensor([i for i, n in enumerate(n_blocks) for _ in range(n)],
                                dtype=q_batch_starts.dtype,
                                device=q_batch_starts.device)
    q_start_sids = torch.tensor([i * block_size for n in n_blocks for i in range(n)],
                               dtype=q_batch_starts.dtype,
                               device=q_batch_starts.device)

    grid = (len(q_start_sids), n_heads)

    with torch.cuda.device(q.device.index):
        _fwd_kernel_batch_inference[grid](
            q, k, v, out,
            sm_scale,
            q_batch_starts,
            q_batch_ends,
            k_batch_starts,
            k_batch_ends,
            q_batch_ids,
            q_start_sids,

            *q.stride(),
            *k.stride(),
            *v.stride(),
            *out.stride(),

            layout_crow_indices,
            layout_col_indices,
            *layout_crow_indices.stride(),
            *layout_col_indices.stride(),

            q_k_ratio,
            HAS_BATCH_DIM=True,
            D_HEAD=head_size,
            BLOCK_M=block_size,
            BLOCK_N=block_size,
            BLOCK_D=block_d,
            BLOCK_M_LOADING=16 if q_len == 1 else block_size,  # smaller for decoding
            EVEN_D=block_d == head_size,
            num_warps=1 if q_len == 1 else 4,
            num_stages=1,
        )


    return out
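

# Decoding-time usage sketch (illustrative only; hypothetical shapes, and a
# CUDA device is assumed): one new query token per sequence attends to a
# left-padded KV cache. The sparse layout is the CSR block layout built over
# the full K length.
def _example_padded_fwd():
    batch, k_len, n_heads, head_size = 2, 256, 8, 128
    q = torch.randn(batch, 1, n_heads, head_size, dtype=torch.float16, device='cuda')
    k = torch.randn(batch, k_len, n_heads, head_size, dtype=torch.float16, device='cuda')
    v = torch.randn_like(k)
    sparse_layout, _, _ = _get_sparse_attn_mask(
        n_heads, k_len, k_len, q.dtype, q.device, BLOCK=64, homo_head=False)
    left_paddings = torch.tensor([0, 64], dtype=torch.int32, device=q.device)
    return blocksparse_flash_attn_padded_fwd(
        q, k, v, 1. / math.sqrt(head_size), sparse_layout,
        left_paddings=left_paddings, block_size=64)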


def blocksparse_flash_attn_varlen_fwd(
    q, k, v, # (#tokens, n_heads, head_size)
    cu_seqlens_k,
    cu_seqlens_q,
    sm_scale,
    sparse_layout,
    *,
    block_size=64,
    max_seqlen = None
):
    # split q to blocks
    _, n_heads, head_size = q.shape
    batch_size = cu_seqlens_k.size(0) - 1


    # print(f'> {q.shape=}, {k.shape=}')
    assert q.dim() == k.dim() == v.dim() == 3
    assert q.size(1) % k.size(1) == 0
    assert q.size(2) == k.size(2)
    assert k.shape == v.shape # TODO: allow diff head_size for k, v
    assert cu_seqlens_k.dim() == 1

    q_k_ratio = q.size(1) // k.size(1)

    if cu_seqlens_q is None:
        if q.size(0) == batch_size: # decoding only
            cu_seqlens_q = torch.arange(0, batch_size + 1,
                                        dtype=cu_seqlens_k.dtype,
                                        device=cu_seqlens_k.device)
        elif q.size(0) == k.size(0):
            cu_seqlens_q = cu_seqlens_k
        else:
            raise ValueError('cu_seqlens_q must be specified for a mix of prefilling and decoding.')
    else:
        assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)

    # switch to CPU to avoid launching too many kernels when iterating
    q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
    k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()

    assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), \
        'length of q should either be 1 (decoding) or same as k (prefilling).'

    if max_seqlen:
        assert k_lens.max() <= max_seqlen

    n_blocks = (q_lens + block_size - 1) // block_size

    q_batch_ids = torch.tensor([i for i, n in enumerate(n_blocks) for _ in range(n)],
                                dtype=cu_seqlens_q.dtype,
                                device=cu_seqlens_q.device)
    q_start_sids = torch.tensor([i * block_size for n in n_blocks for i in range(n)],
                               dtype=cu_seqlens_q.dtype,
                               device=cu_seqlens_q.device)


    out = q.new_empty(q.shape)
    cu_seqlens_q = cu_seqlens_q.contiguous()
    cu_seqlens_k = cu_seqlens_k.contiguous()

    layout_crow_indices, layout_col_indices = sparse_layout
    block_d = triton.next_power_of_2(head_size)

    decoding_only = (q_lens == 1).all()

    grid = (len(q_start_sids), n_heads)

    with torch.cuda.device(q.device.index):
        _fwd_kernel_batch_inference[grid](
            q, k, v, out,
            sm_scale,
            cu_seqlens_q[:-1],
            cu_seqlens_q[1:],
            cu_seqlens_k[:-1],
            cu_seqlens_k[1:],
            q_batch_ids,
            q_start_sids,

            0, *q.stride(),
            0, *k.stride(),
            0, *v.stride(),
            0, *out.stride(),

            layout_crow_indices,
            layout_col_indices,
            *layout_crow_indices.stride(),
            *layout_col_indices.stride(),

            q_k_ratio,
            HAS_BATCH_DIM=False,
            D_HEAD=head_size,
            BLOCK_M=block_size,
            BLOCK_N=block_size,
            BLOCK_D=block_d,
            BLOCK_M_LOADING=16 if decoding_only else block_size,  # smaller for decoding
            EVEN_D=block_d == head_size,
            num_warps=1 if decoding_only else 4,
            num_stages=3,
        )

    return out
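

# Variable-length prefill sketch (illustrative only; hypothetical shapes, and
# a CUDA device is assumed): two packed sequences of lengths 192 and 320 share
# the token dimension, delimited by cumulative sequence lengths as in
# flash-attn's varlen interface.
def _example_varlen_fwd():
    n_heads, head_size, max_len = 8, 128, 512
    cu_seqlens = torch.tensor([0, 192, 512], dtype=torch.int32, device='cuda')
    q = torch.randn(512, n_heads, head_size, dtype=torch.float16, device='cuda')
    k, v = torch.randn_like(q), torch.randn_like(q)
    sparse_layout, _, _ = _get_sparse_attn_mask(
        n_heads, max_len, max_len, q.dtype, q.device, BLOCK=64, homo_head=False)
    return blocksparse_flash_attn_varlen_fwd(
        q, k, v, cu_seqlens, cu_seqlens, 1. / math.sqrt(head_size),
        sparse_layout, block_size=64)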


@triton.jit
def _fwd_kernel_inner(
    acc, l_i, m_i,
    q, Q,
    k_block_col_idx,
    layout_col_ptr,
    layout_col_stride_h, layout_col_stride_m,
    k_ptrs,
    v_ptrs,
    off_h, offs_m, offs_n, offs_d,
    stride_kt, stride_vt,
    sm_scale,
    k_seqlen,
    past_len,
    LAST_K_BLOCK: tl.constexpr,
    BLOCK_M_LOADING: tl.constexpr,
    BLOCK_N: tl.constexpr,
    D_HEAD: tl.constexpr,
    EVEN_D: tl.constexpr,
    M_LT_N: tl.constexpr
):
    k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h + k_block_col_idx * layout_col_stride_m).to(tl.int32)
    start_n = k_block_id * BLOCK_N
    # -- compute qk ----
    if LAST_K_BLOCK:
        if EVEN_D:
            k = tl.load(k_ptrs + start_n * stride_kt,
                        mask=offs_n[None, :] + start_n < k_seqlen)
        else:
            # mask = mask & (offs_d[:, ])
            k = tl.load(k_ptrs + start_n * stride_kt,
                        mask=(offs_n[None, :] + start_n < k_seqlen) & (offs_d[:, None] < D_HEAD))
    else:
        if EVEN_D:
            k = tl.load(k_ptrs + start_n * stride_kt)
        else:
            k = tl.load(k_ptrs + start_n * stride_kt,
                        mask=offs_d[:, None] < D_HEAD)


    qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
    qk += tl.dot(q, k)

    qk *= sm_scale

    # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
    if LAST_K_BLOCK | M_LT_N:
        qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf'))

    # -- compute m_ij, p, l_ij
    m_ij = tl.max(qk, 1)
    p = tl.exp(qk - m_ij[:, None])

    l_ij = tl.sum(p, 1)
    # -- update m_i and l_i
    m_i_new = tl.maximum(m_i, m_ij)
    alpha = tl.exp(m_i - m_i_new)
    beta = tl.exp(m_ij - m_i_new)
    l_i_new = alpha * l_i + beta * l_ij
    # -- update output accumulator --
    # scale p
    p_scale = beta / l_i_new
    p = p * p_scale[:, None]
    # scale acc
    acc_scale = l_i / l_i_new * alpha
    acc = acc * acc_scale[:, None]

    p = p.to(Q.dtype.element_ty)
    # update acc
    if LAST_K_BLOCK:
        if EVEN_D:
            v = tl.load(v_ptrs + start_n * stride_vt,
                        mask=offs_n[:, None] + start_n < k_seqlen)
        else:
            v = tl.load(v_ptrs + start_n * stride_vt,
                        mask=(offs_n[:, None] + start_n < k_seqlen) & (offs_d[None, :] < D_HEAD))
    else:
        if EVEN_D:
            v = tl.load(v_ptrs + start_n * stride_vt)
        else:
            v = tl.load(v_ptrs + start_n * stride_vt,
                        mask=offs_d[None, :] < D_HEAD)

    acc += tl.dot(p, v)
    # update m_i and l_i
    l_i = l_i_new
    m_i = m_i_new
    return acc, l_i, m_i


@triton.heuristics(
    {
        'M_LT_N': lambda kwargs: kwargs['BLOCK_M'] < kwargs['BLOCK_N'],
    }
)
@triton.jit
def _fwd_kernel_batch_inference(
    Q, K, V, Out,

    sm_scale,
    q_batch_starts,
    q_batch_ends,
    k_batch_starts,
    k_batch_ends,
    q_batch_ids,
    q_start_sids,

    stride_qb, stride_qt, stride_qh, stride_qd,
    stride_kb, stride_kt, stride_kh, stride_kd,
    stride_vb, stride_vt, stride_vh, stride_vd,
    stride_ob, stride_ot, stride_oh, stride_od,

    layout_crow_ptr,
    layout_col_ptr,
    layout_crow_stride_h, layout_crow_stride_m,
    layout_col_stride_h, layout_col_stride_m,

    q_k_ratio,

    HAS_BATCH_DIM: tl.constexpr,
    D_HEAD: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_M_LOADING: tl.constexpr,
    EVEN_D: tl.constexpr,
    M_LT_N: tl.constexpr
):
    '''
    NOTATION:
    pid: position id
    sid: storage id
    sbid: storage block id
    pbid: position block id
    offs_m, offs_n: storage offsets of m-dim (q, row) and n-dim (k, col)

    q and blocks in KV need to be contiguous

    Arguments:
    kv_seq_lens: used to compute past_len
    kv_storage_offsets: similar to block_tables in vllm, except it is dynamic.
        TODO: fix this

    TODO:
    Optimize grouped-attn

    CUDA graph support issues:
        1. the grid is dynamic: vllm sets up multiple CUDA graphs in the decoding phase, with different max token sizes (16, 32, ...).
            Since we mix the prompt and decoding phases here, it can be more complex:
            a different CUDA graph would be needed per (off_zm, off_z).

            Indeed, q_batch_ids can be padded to the maximum grid[0], i.e., assuming all decoding;
            the same would apply to cu_seqlens_q and kv_seq_lens.
    '''
    off_zm = tl.program_id(0)
    off_h = tl.program_id(1)

    off_h_for_kv = off_h // q_k_ratio
    off_z = tl.load(q_batch_ids + off_zm).to(tl.int32)   # [0, 0, 0, 1]
    q_start_sid = tl.load(q_start_sids + off_zm)
    start_m = q_start_sid // BLOCK_M

    if HAS_BATCH_DIM:
        Q += off_z * stride_qb
        K += off_z * stride_kb
        V += off_z * stride_vb
        Out += off_z * stride_ob

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)

    q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
    q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start

    k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
    k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start

    past_len = k_seqlen - q_seqlen

    Q += q_cu_start * stride_qt + off_h * stride_qh
    K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
    V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
    Out += q_cu_start * stride_ot + off_h * stride_oh

    q_pbid = (past_len + q_start_sid) // BLOCK_M

    if EVEN_D:
        q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
                    mask=offs_m[:, None] < q_seqlen)
    else:
        q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
                    mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
                    other=0)

    sparse_crow_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + q_pbid * layout_crow_stride_m

    # TODO: load at once, supported in new Triton
    k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
    k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)

    m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float('inf')
    l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)

    k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
    v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd

    for k_block_col_idx in range(k_block_start, k_block_end - 1):
        acc, l_i, m_i = _fwd_kernel_inner(
            acc, l_i, m_i,
            q, Q,
            k_block_col_idx,
            layout_col_ptr,
            layout_col_stride_h, layout_col_stride_m,
            k_ptrs,
            v_ptrs,
            off_h, offs_m, offs_n, offs_d,
            stride_kt, stride_vt,
            sm_scale,
            k_seqlen,
            past_len,
            False,
            BLOCK_M_LOADING,
            BLOCK_N,
            D_HEAD,
            EVEN_D,
            M_LT_N
            )

    acc, l_i, m_i = _fwd_kernel_inner(
        acc, l_i, m_i,
        q, Q,
        k_block_end - 1,
        layout_col_ptr,
        layout_col_stride_h, layout_col_stride_m,
        k_ptrs,
        v_ptrs,
        off_h, offs_m, offs_n, offs_d,
        stride_kt, stride_vt,
        sm_scale,
        k_seqlen,
        past_len,
        True,
        BLOCK_M_LOADING,
        BLOCK_N,
        D_HEAD,
        EVEN_D,
        M_LT_N
        )

    # write output
    if EVEN_D:
        tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc,
                mask=offs_m[:, None] < q_seqlen)
    else:
        tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc,
                mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD))


###########################################################
###########################################################

###########################################################
################## Testing Utilities ######################
###########################################################


def torch_attention(q, k, v, attn_mask=None, sm_scale=None, block_attn_mask=None, block_size=128, do=None):
    '''
    q, k, v: shape=(batch, n_heads, seq, dim)
    '''
    # for verification
    if sm_scale is None:
        sm_scale = 1. / math.sqrt(float(q.size(-1)))  # standard 1/sqrt(d) softmax scaling

    if block_attn_mask is not None:
        assert attn_mask is None
        outs = []
        for s in range(0, q.size(2), block_size):
            e = min(s + block_size, q.size(2))
            q_block = q[:, :, s:e]
            attn = torch.einsum('bhmd,bhnd->bhmn', q_block, k[:, :, :e]).float() * sm_scale
            mask = block_attn_mask[..., s // block_size, : (s // block_size + 1)]
            mask = torch.kron(mask, torch.ones(block_size, block_size, device=mask.device))
            mask[..., :, s:].masked_fill_(torch.arange(0, block_size)[:, None] <= torch.arange(0, block_size)[None, :], 0)
            attn = attn.masked_fill((1 - mask).bool(), float('-inf'))
            attn = attn.softmax(-1)
            out = torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v[:, :, :e])
            outs.append(out)
        torch_output = torch.cat(outs, dim=2)
    else:
        attn = torch.einsum('bhmd,bhnd->bhmn', q, k).float() * sm_scale
        # import ipdb; ipdb.set_trace()
        if attn_mask is not None:
            attn = attn.masked_fill((1 - attn_mask).bool(), float('-inf'))
        # print(f'> torch attn: {attn.exp().sum(-1)=}')

        attn = attn.softmax(-1)
        if do is not None:
            dv = torch.einsum('bhqk,bhqd->bhkd', attn.type_as(do), do)
            print(f'> torch_attn computed dv: {dv=}')
        torch_output = torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v)
    return torch_output

###########################################################
###########################################################

###########################################################
#################### Unit Tests ###########################
###########################################################


@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 8, 2048, 128), (1, 4, 4096, 64)])
def test_op(Z, H, N_CTX, D_HEAD, Q_LEN=None, dtype=torch.bfloat16, homo_head=True, kernel_block_size=None, sparse_block_size=128, backward=True,
            sparse_attention_fn=None, local_blocks=4, vert_stride=4, sm_scale=None, max_length=None):
    Q_LEN = Q_LEN or N_CTX
    torch.manual_seed(20)
    q = torch.empty((Z, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) # .requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) # .requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) # .requires_grad_()

    if sm_scale is None:
        sm_scale = 1. / math.sqrt(D_HEAD)

    # for debugging
    # print(f'>> {q.shape=}, {k.shape=}, {v.shape=}, {homo_head=}, {kernel_block_size=}, {sparse_block_size=}, {local_blocks=}, {vert_stride=}')
    if backward:
        q.requires_grad_(), k.requires_grad_(), v.requires_grad_()

    # qkv = torch.empty((Z, N_CTX, 3*H*D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5)
    # q = qkv[..., :H*D_HEAD]
    # k = qkv[..., H*D_HEAD:2*H*D_HEAD]
    # v = qkv[..., 2*H*D_HEAD:]
    # q = q.view(Z, N_CTX, H, -1).permute(0, 2, 1, 3)
    # k = k.view(Z, N_CTX, H, -1).permute(0, 2, 1, 3)
    # v = v.view(Z, N_CTX, H, -1).permute(0, 2, 1, 3)

    # if Q_LEN and Q_LEN < N_CTX:
    #     q = q[:, :, -Q_LEN:] # .contiguous()

    # q = q.requires_grad_()
    # k = k.requires_grad_()
    # v = v.requires_grad_()

    dout = torch.randn_like(q).contiguous()

    # dout = torch.eye(N_CTX)[:, :D_HEAD][None, None].expand_as(q).type_as(q).contiguous()
    # print(dout)

    mask_csr, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=sparse_block_size,
                            local_blocks=local_blocks, vert_stride=vert_stride, homo_head=homo_head, return_dense=True)

    if sparse_attention_fn is None:
        sparse_attention_fn = get_local_strided_sparse_attention_op(H, N_CTX,
                                                                    sparse_block_size=sparse_block_size,
                                                                    local_blocks=local_blocks,
                                                                    vert_stride=vert_stride,
                                                                    homo_head=homo_head,
                                                                    device=q.device,
                                                                    dtype=q.dtype,
                                                                    kernel_block_size=kernel_block_size)
    # reference implementation
    ref_out = torch_attention(q, k, v, mask_dense, sm_scale)

    # lengths = torch.full((Z,), fill_value=N_CTX, device='cuda')
    # cu_seqlens = torch.zeros((Z + 1,), device='cuda', dtype=torch.int32)
    # cu_seqlens[1:] = lengths.cumsum(0)
    # # qkv = torch.randn((Z * N_CTX, 3, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)

    # qkv_list = list(map(lambda x: x.permute(0, 2, 1, 3).contiguous().view(Z * N_CTX, 1, H, D_HEAD), [q, k, v]))
    # qkv = torch.cat(qkv_list, dim=1)
    # ref_out0 = flash_attn_func(qkv, cu_seqlens, dropout_p=0, max_s=N_CTX, softmax_scale=sm_scale, causal=True)
    # ref_out = ref_out0.view(Z, N_CTX, H, D_HEAD).permute(0, 2, 1, 3).contiguous()


    if backward:
        ref_out.backward(dout)
        ref_dv, v.grad = v.grad.clone(), None
        ref_dk, k.grad = k.grad.clone(), None
        ref_dq, q.grad = q.grad.clone(), None

    tri_out = sparse_attention_fn(q, k, v, sm_scale)

    assert torch.allclose(ref_out.cpu(), tri_out.cpu(), atol=1e-2, rtol=0), f'>> {ref_out[0, 0, :, 0].tolist()=}\n\n{tri_out[0, 0, :, 0].tolist()=}'

    if backward:
        tri_out.backward(dout)
        tri_dv, v.grad = v.grad.clone(), None
        tri_dk, k.grad = k.grad.clone(), None
        tri_dq, q.grad = q.grad.clone(), None

    if backward:
        assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=1e-2)
        assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0)
        assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0)

    print(f'> test passed: {Z=}, {H=}, {N_CTX=}, {D_HEAD=}, {Q_LEN=}, {dtype=}, {homo_head=}, {sparse_block_size=}')
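
# Illustrative only: a toy construction of the local + vertical-stride block
# pattern that `local_blocks` / `vert_stride` select (it mirrors the
# commented-out mask math in the benchmark section below). The kernels
# themselves consume the CSR layout produced by `get_sparse_attn_mask`.
def _toy_block_mask(num_blocks, local_blocks=4, vert_stride=4):
    q_pos = torch.arange(num_blocks)[:, None]
    k_pos = torch.arange(num_blocks)[None, :]
    # every vert_stride-th key block stays visible to all later query blocks
    strided = (k_pos % vert_stride) == vert_stride - 1
    keep = (q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | strided)
    return keep.to(torch.int64)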

###########################################################

if __name__ == '__main__':

    GPU_TYPE = os.popen('nvidia-smi --query-gpu=name --format=csv | tail -n 1').read().strip()
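    # dependency-free alternative (sketch): GPU_TYPE = torch.cuda.get_device_name(0)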
    # print(GPU_TYPE)
    support_backward = True # 'A100' in GPU_TYPE. Wasn't supported on the consumer A1000.

    ###############
    # benchmarking

    HAS_DENSE_TRITON_FLASH = False
    # try:
    #     from triton.ops.flash_attention import attention as triton_attention
    #     HAS_DENSE_TRITON_FLASH = True
    # except:
    #     HAS_DENSE_TRITON_FLASH = False
    #     print('> cannot import Trition flash attn')

    try:
        from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded_func
        HAS_FLASH = True
    except BaseException:
        HAS_FLASH = False
        print('> cannot import flash_attn')


    # BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
    BATCH, N_HEADS, N_CTX, D_HEAD = 4, 32, 4096, 128  # 6.7B model, with 4k len
    # BATCH, N_HEADS, N_CTX, D_HEAD = 4, 16, 4096, 128  # 204m model

    BLOCK_SIZE = 64
    LOCAl_BLOCKS = 8 # 4
    VERT_STRIDE = 1 # 16 # 8
    HOMO_HEAD = False
    sparse_type = 'homo' if HOMO_HEAD else 'hetero'
    dtype = torch.bfloat16


    modes = ['fwd', 'bwd'] if support_backward else ['fwd']

    configs = [triton.testing.Benchmark(
        x_names=['SEQ_LEN'],
        x_vals=[2**i for i in range(8, 16)],
        line_arg='provider',
        line_vals=(['triton'] if HAS_DENSE_TRITON_FLASH else []) + (['flash'] if HAS_FLASH else []) + ['triton_sparse'],
        line_names=(['Triton-Dense'] if HAS_DENSE_TRITON_FLASH else [])  + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse'],
        styles=[('red', '-'), ('blue', '-'), ('green', '-')],
        ylabel='ms',
        plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}-{dtype}-{mode}',
        args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': dtype, 'mode': mode}
    ) for mode in modes]


    @triton.testing.perf_report(configs)
    def bench_flash_attention(BATCH, H, SEQ_LEN, D_HEAD, mode, provider, dtype=torch.bfloat16, device='cuda', sparse_attention_fn=None):
        assert mode in ['fwd', 'bwd']
        warmup = 25
        rep = 100
        N_CTX = SEQ_LEN
        if provider == 'triton':
            q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
            k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
            v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
            sm_scale = 1.3
            fn = lambda: triton_attention(q, k, v, sm_scale)
            if mode == 'bwd':
                o = fn()
                do = torch.randn_like(o)
                fn = lambda: o.backward(do, retain_graph=True)
            ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
            return ms
        if provider == 'triton_sparse':
            q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
            k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
            v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
            sm_scale = 1.3
            # q_pos = torch.arange(N_CTX // BLOCK, device='cuda')[:, None]
            # k_pos = torch.arange(N_CTX // BLOCK, device='cuda')[None]
            # local_blocks = 4 # num_block per attn, block_size is tied to BLOCK
            # vert_stride =N_CTX + 1 # 4
            # mask_vert_strided = torch.arange(N_CTX // BLOCK, device='cuda') % vert_stride == vert_stride - 1
            # mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).type_as(q)
            # mask = mask_dense.to_sparse_csr()
            # mask_csr, _ = get_sparse_attn_mask(q, N_CTX, BLOCK=BLOCK, local_blocks=LOCAl_BLOCKS, vert_stride=VERT_STRIDE, homo_head=HOMO_HEAD)

            if sparse_attention_fn is None:
                # sparse_attention_fn = sparse_attention
                sparse_attention_fn = get_local_strided_sparse_attention_op(H, SEQ_LEN,
                                                                            local_blocks=LOCAl_BLOCKS,
                                                                            vert_stride=VERT_STRIDE,
                                                                            homo_head=HOMO_HEAD,
                                                                            sparse_block_size=BLOCK_SIZE,
                                                                            kernel_block_size=BLOCK_SIZE,
                                                                            device=q.device)
            # sparse_attention_fn = sparse_attention_factory(128, 128, num_warps=8)

            # fn = lambda: sparse_attention_fn(q, k, v, mask_csr[0], mask_csr[1], sm_scale)
            fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
            if mode == 'bwd':
                o = fn()
                do = torch.randn_like(o)
                fn = lambda: o.backward(do, retain_graph=True)
            ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
            return ms
        if provider == 'flash':
            lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
            cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
            cu_seqlens[1:] = lengths.cumsum(0)
            qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
            fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
            if mode == 'bwd':
                o = fn()
                do = torch.randn_like(o)
                fn = lambda: o.backward(do, retain_graph=True)
            ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
            return ms

        # if provider == 'torch':
        #     q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
        #     k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
        #     v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
        #     sm_scale = 1.3
        #     causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(q)
        #     fn = lambda:  torch_attention(q, k, v, causal_mask, sm_scale)
        #     ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
        #     return ms


    BATCH, N_HEADS, N_CTX, D_HEAD, Q_LEN = 4, 32, 4096, 128, 1  # 6.7B model, with 4k len

    BLOCK_SIZE = 64
    LOCAl_BLOCKS = 8 # 4
    VERT_STRIDE = 16 # 8
    HOMO_HEAD = False
    sparse_type = 'homo' if HOMO_HEAD else 'hetero'
    dtype = torch.bfloat16
    MAX_N_CTX = 8192

    configs = [triton.testing.Benchmark(
        x_names=['PAST_LEN'],
        x_vals=[2**i - 1 for i in range(8, 14)],
        line_arg='provider',
        line_vals=['torch'] + (['flash'] if HAS_FLASH else []) + ['triton_sparse', 'triton_dense'],
        line_names=['Torch']  + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse', 'Triton-Dense'],
        styles=[('red', '-'), ('blue', '-'), ('green', '-'), ('cyan', '-')],
        ylabel='ms',
        plot_name=f'fused-attention-inference-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}',
        args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'Q_LEN': Q_LEN, 'dtype': torch.float16, 'mode': mode}
    ) for mode in ['fwd']]
    @triton.testing.perf_report(configs)
    def bench_flash_attention_inference(BATCH, H, PAST_LEN, D_HEAD, Q_LEN, mode, provider, dtype=torch.bfloat16, device='cuda'):
        assert mode in ['fwd']
        warmup = 25
        rep = 100
        N_CTX = PAST_LEN + Q_LEN
        if provider == 'torch':
            q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            sm_scale = 1.3
            mask_csr, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=BLOCK_SIZE,
                                    local_blocks=LOCAl_BLOCKS, vert_stride=VERT_STRIDE, homo_head=HOMO_HEAD, return_dense=True)

            fn = lambda: torch_attention(q, k, v, mask_dense, sm_scale=sm_scale, block_size=2048)
            ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
            return ms
        if provider == 'triton_sparse':
            q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            sm_scale = 1.3
            sparse_attention_fn = get_local_strided_sparse_attention_op(H, MAX_N_CTX,
                                                                        local_blocks=LOCAl_BLOCKS,
                                                                        vert_stride=VERT_STRIDE,
                                                                        homo_head=HOMO_HEAD,
                                                                        sparse_block_size=BLOCK_SIZE,
                                                                        kernel_block_size=BLOCK_SIZE,
                                                                        device=q.device,
                                                                        inference=True)

            fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
            ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
            return ms
        if provider == 'triton_dense':
            q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            sm_scale = 1.3
            sparse_attention_fn = get_local_strided_sparse_attention_op(H, MAX_N_CTX,
                                                                        local_blocks=1,
                                                                        vert_stride=1,
                                                                        homo_head=True,
                                                                        sparse_block_size=BLOCK_SIZE,
                                                                        kernel_block_size=BLOCK_SIZE,
                                                                        device=q.device,
                                                                        inference=True)

            fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
            ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
            return ms
        if provider == 'flash':
            assert Q_LEN == 1
            lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
            cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
            cu_seqlens[1:] = lengths.cumsum(0)
            cu_seqlens_q = torch.arange(BATCH + 1, device=device, dtype=torch.int32)

            # (total_q, nheads, headdim),
            q = torch.randn((BATCH, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            k = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            v = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)

            fn = lambda: flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens, 1, N_CTX, dropout_p=0, softmax_scale=1.3, causal=False)
            ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
            return ms


    test_op(1, 4, 512, 128, dtype=torch.float16, homo_head=False, backward=support_backward)
    # bench_flash_attention.run(save_path='.', print_data=True)

    bench_flash_attention_inference.run(save_path='.', print_data=True)
    exit()  # NOTE: early exit for debugging; remove to run the remaining tests and benchmarks below
    # head_dim=64
    test_op(1, 2, 1024, 64, kernel_block_size=64, sparse_block_size=64,
            dtype=torch.bfloat16, homo_head=False, backward=support_backward)
    # uneven length, bf16
    test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, sparse_block_size=128,
            kernel_block_size=64, local_blocks=8, vert_stride=8)
    test_op(3, 2, 2047, 128, homo_head=False, backward=False)

    # diff kernel/sparse block size
    test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, kernel_block_size=64)
    # inference
    # test_op(1, 4, 512 + 256, 128, Q_LEN=1,  dtype=torch.bfloat16, homo_head=False, backward=support_backward)

    # dense flash attn
    test_op(1, 2, 1024, 128, kernel_block_size=128, sparse_block_size=128, dtype=torch.bfloat16, homo_head=False,
            backward=support_backward, local_blocks=1, vert_stride=1)

    # fp16
    test_op(1, 4, 512 + 256, 128, dtype=torch.float16, homo_head=False, backward=support_backward)

    # longer sequence
    test_op(2, 4, 8192, 64, homo_head=False, backward=support_backward)
    test_op(2, 4, 8192, 128, dtype=torch.bfloat16, homo_head=False, backward=support_backward)

    # homo head
    test_op(3, 2, 2048, 64, homo_head=True, dtype=torch.bfloat16, backward=False)
    test_op(3, 2, 2048, 64, homo_head=True, backward=support_backward)

    # sparse_attention_fn = sparse_attention_factory(16, 128, num_warps=1, INFERENCE=True)
    # test_op(8, 1, 2047, 128, 1, backward=False, sparse_attention_fn=None)
    # test_op_inference(3, 2, 2048, 128, 2048)
    # test_op_inference(3, 2, 2047, 64, 2047)
    # test_op_inference(3, 2, 256, 64, 128)
    # test_op_inference(3, 2, 2048, 64, 1)

    bench_flash_attention.run(save_path='.', print_data=True)
    # bench_flash_attention_inference.run(save_path='.', print_data=True)

# ========================
# Some Benchmark Results #
# ========================

# fused-attention-batch4-head48-d64-sparse-local4-vert4-hetero-fwd:
#    SEQ_LEN  Triton-Dense  Flash-Dense  Triton-Sparse
# 0    256.0      0.057184     0.069646       0.052567
# 1    512.0      0.131688     0.187658       0.110212
# 2   1024.0      0.391844     0.524990       0.247875
# 3   2048.0      1.305190     1.456685       0.596506
# 4   4096.0      4.623019     4.968653       1.600277
# 5   8192.0     17.513062    18.332262       4.802458
# 6  16384.0     68.453377    70.337540      16.052908
# 7  32768.0    270.655487   276.020233      57.938946
# fused-attention-batch4-head48-d64-sparse-local4-vert4-hetero-bwd (num_warp=8):
#    SEQ_LEN  Triton-Dense  Flash-Dense  Triton-Sparse
# 0    256.0      0.190120     0.150313       0.181451
# 1    512.0      0.406348     0.391767       0.391177
# 2   1024.0      1.029704     1.182967       0.885741
# 3   2048.0      2.985456     3.843399       2.040469
# 4   4096.0      9.808897    13.073701       5.069609
# 5   8192.0     34.995201    47.863808      13.948782
# 6  16384.0    132.740097   182.579193      42.816513
# 7  32768.0    542.223389   714.820618     147.053574
# fused-attention-inference-batch4-head32-d128-sparse-local4-vert4-hetero:
#    PAST_LEN  Torch-Dense  Flash-Dense  Triton-Sparse
# 0     256.0     0.050949     0.032357       0.107513
# 1     512.0     0.073624     0.050651       0.199086
# 2    1024.0     0.107472     0.080379       0.245445
# 3    2048.0     0.178423     0.129448       0.338259
# 4    4096.0     0.327647     0.223106       0.517048
# 5    8192.0     0.588423     0.411263       0.884606
# 6   16384.0     1.098898     0.798941       1.611809
# 7   32768.0     2.094537     1.594726       3.044160


# 6.7B
# fused-attention-batch4-head32-d128-sparse-local4-vert4-hetero-fwd:
#    SEQ_LEN  Triton-Dense  Flash-Dense  Triton-Sparse
# 0    256.0      0.069208     0.082156       0.065097
# 1    512.0      0.138271     0.201393       0.144467
# 2   1024.0      0.391521     0.624614       0.322382
# 3   2048.0      1.268443     2.406325       0.784367
# 4   4096.0      4.455703     9.139097       2.100856
# 5   8192.0     16.764315    35.289600       6.328320
# 6  16384.0     65.221634   138.401794      21.069057
# 7  32768.0    257.251343   548.085754      76.111870
# fused-attention-batch4-head32-d128-sparse-local4-vert4-hetero-bwd:
#    SEQ_LEN  Triton-Dense  Flash-Dense  Triton-Sparse
# 0    256.0      0.297118     0.266469       0.255255
# 1    512.0      0.672826     0.613685       0.552954
# 2   1024.0      1.718434     1.705066       1.251953
# 3   2048.0      4.936755     5.403875       2.927895
# 4   4096.0     15.911594    18.959362       7.436288
# 5   8192.0     55.357441    70.808578      21.140224
# 6  16384.0    208.188416   273.617920      68.018173
# 7  32768.0    806.037476  1081.453613     218.720261
# fused-attention-inference-batch4-head32-d128-sparse-local4-vert4-hetero:
#    PAST_LEN  Torch-Dense  Flash-Dense  Triton-Sparse
# 0     256.0     0.050151     0.032337       0.107593
# 1     512.0     0.073409     0.051737       0.200200
# 2    1024.0     0.107533     0.082099       0.247067
# 3    2048.0     0.177259     0.128891       0.338510
# 4    4096.0     0.325866     0.223621       0.524842
# 5    8192.0     0.586926     0.408913       0.885490
# 6   16384.0     1.100834     0.793277       1.612271
# 7   32768.0     2.098851     1.595831       3.064544

# fused-attention-batch4-head32-d128-sparse-local4-vert8-hetero-fwd:
#    SEQ_LEN  Triton-Dense  Flash-Dense  Triton-Sparse
# 0    256.0      0.066673     0.082037       0.065085
# 1    512.0      0.137379     0.201880       0.143473
# 2   1024.0      0.390675     0.624234       0.312046
# 3   2048.0      1.267739     2.406950       0.696045
# 4   4096.0      4.445138     9.136333       1.665788
# 5   8192.0     16.768614    35.265533       4.380486
# 6  16384.0     65.235970   138.393600      12.997633
# 7  32768.0    257.317902   550.442993      42.821121
# fused-attention-batch4-head32-d128-sparse-local4-vert8-hetero-bwd:
#    SEQ_LEN  Triton-Dense  Flash-Dense  Triton-Sparse
# 0    256.0      0.296461     0.266581       0.254022
# 1    512.0      0.671427     0.613643       0.551283
# 2   1024.0      1.719918     1.704295       1.229982
# 3   2048.0      4.945305     5.403364       2.721906
# 4   4096.0     15.934293    18.960999       6.259371
# 5   8192.0     55.406593    70.832130      15.676929
# 6  16384.0    208.750595   275.004425      44.837891
# 7  32768.0    808.057861  1080.647705     141.856766
# fused-attention-inference-batch4-head32-d128-sparse-local4-vert8-hetero:
#    PAST_LEN  Torch-Dense  Flash-Dense  Triton-Sparse
# 0     256.0     0.050739     0.032886       0.107837
# 1     512.0     0.073507     0.051996       0.200293
# 2    1024.0     0.106394     0.080679       0.240610
# 3    2048.0     0.177659     0.127660       0.287625
# 4    4096.0     0.326326     0.226971       0.377500
# 5    8192.0     0.586339     0.407367       0.559266
# 6   16384.0     1.102279     0.786221       0.920976
# 7   32768.0     2.097370     1.545090       1.644288


################
##### fp16 #####
################

# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-fwd:
#    SEQ_LEN  Triton-Dense  Flash-Dense  Triton-Sparse
# 0    256.0      0.032518     0.035472       0.029939
# 1    512.0      0.054266     0.087841       0.054320
# 2   1024.0      0.133447     0.263090       0.102045
# 3   2048.0      0.384615     1.023293       0.201763
# 4   4096.0      1.300890     4.023936       0.449555
# 5   8192.0      4.774144    15.816704       1.150854
# 6  16384.0     18.220032    62.771198       3.356001
# 7  32768.0     71.405571   250.273788      10.976142
# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-bwd:
#    SEQ_LEN  Triton-Dense  Flash-Dense  Triton-Sparse
# 0    256.0      0.083342     0.069742       0.079496
# 1    512.0      0.159894     0.170995       0.151705
# 2   1024.0      0.386071     0.522407       0.331443
# 3   2048.0      1.067715     1.737333       0.715248
# 4   4096.0      3.382731     6.219520       1.597457
# 5   8192.0     11.857793    23.560448       3.879035
# 6  16384.0     44.422142    91.251709      10.626843
# 7  32768.0    175.011841   359.473145      32.340992


################
##### bf16 #####
################

# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-fwd:
#    SEQ_LEN  Triton-Dense  Flash-Dense  Triton-Sparse
# 0    256.0      0.037636     0.035902       0.031512
# 1    512.0      0.058591     0.087229       0.058125
# 2   1024.0      0.143337     0.263919       0.108443
# 3   2048.0      0.414458     1.025985       0.214114
# 4   4096.0      1.390841     4.020010       0.480550
# 5   8192.0      5.067938    15.808171       1.230874
# 6  16384.0     19.442280    62.765057       3.597274
# 7  32768.0     75.501572   250.443771      11.768959
# fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-bwd:
#    SEQ_LEN  Triton-Dense  Flash-Dense  Triton-Sparse
# 0    256.0      0.084404     0.070663       0.082613
# 1    512.0      0.161510     0.172882       0.157661
# 2   1024.0      0.388954     0.526047       0.339855
# 3   2048.0      1.075814     1.736057       0.732420
# 4   4096.0      3.401622     6.221376       1.636039
# 5   8192.0     11.915136    23.483391       3.968725
# 6  16384.0     44.660225    91.302910      10.857130
# 7  32768.0    175.038467   359.048187      32.778240