import math

import torch
import torch.nn.functional as F


try:
    import flash_attn
    from flash_attn.flash_attn_interface import (
        _flash_attn_forward,
        flash_attn_func,
        flash_attn_varlen_func,
    )
except ImportError:
    flash_attn = None
    flash_attn_varlen_func = None
    _flash_attn_forward = None
    flash_attn_func = None

MEMORY_LAYOUT = {
    # "flash" mode:
    # pre-process:  input is [batch_size, seq_len, num_heads, head_dim]
    # post-process: shape is left unchanged
    "flash": (
        lambda x: x,  # keep shape unchanged
        lambda x: x,  # keep shape unchanged
    ),
    # "torch"/"vanilla" modes:
    # pre-process:  swap the sequence and head dimensions [B,S,A,D] -> [B,A,S,D]
    # post-process: swap back to the original layout [B,A,S,D] -> [B,S,A,D]
    "torch": (
        lambda x: x.transpose(1, 2),  # (B,S,A,D) -> (B,A,S,D)
        lambda x: x.transpose(1, 2),  # (B,A,S,D) -> (B,S,A,D)
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}
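
# Illustrative example of how the layout helpers round-trip a tensor for the
# "torch"/"vanilla" modes (the shapes below are arbitrary, chosen for demonstration):
#
#   pre, post = MEMORY_LAYOUT["torch"]
#   x = torch.randn(2, 16, 8, 64)   # [B=2, S=16, A=8, D=64]
#   pre(x).shape                    # torch.Size([2, 8, 16, 64])  -> [B,A,S,D]
#   post(pre(x)).shape              # torch.Size([2, 16, 8, 64])  -> back to [B,S,A,D]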


def attention(
    q,
    k,
    v,
    mode="torch",
    drop_rate=0,
    attn_mask=None,
    causal=False,
):
    """
    执行QKV自注意力计算

    Args:
        q (torch.Tensor): 查询张量,形状 [batch_size, seq_len, num_heads, head_dim]
        k (torch.Tensor): 键张量,形状 [batch_size, seq_len_kv, num_heads, head_dim]
        v (torch.Tensor): 值张量,形状 [batch_size, seq_len_kv, num_heads, head_dim]
        mode (str): 注意力模式,可选 'flash', 'torch', 'vanilla'
        drop_rate (float): 注意力矩阵的dropout概率
        attn_mask (torch.Tensor): 注意力掩码,形状根据模式不同而变化
        causal (bool): 是否使用因果注意力(仅关注前面位置)

    Returns:
        torch.Tensor: 注意力输出,形状 [batch_size, seq_len, num_heads * head_dim]
    """
    # Look up the pre-/post-processing functions for the chosen mode
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]

    # Apply the pre-attention layout transform
    q = pre_attn_layout(q)  # resulting shape depends on the mode
    k = pre_attn_layout(k)
    v = pre_attn_layout(v)

    if mode == "torch":
        # Use PyTorch's native scaled_dot_product_attention
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
        )
    elif mode == "flash":
        assert flash_attn_func is not None, "flash_attn_func is not available"
        assert attn_mask is None, "attn_mask is not supported in flash mode"
        x: torch.Tensor = flash_attn_func(
            q, k, v, dropout_p=drop_rate, causal=causal, softmax_scale=None
        )  # type: ignore
    elif mode == "vanilla":
        # 手动实现注意力机制
        scale_factor = 1 / math.sqrt(q.size(-1))  # 缩放因子 1/sqrt(d_k)
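        # This branch mirrors the reference formulation of scaled dot-product
        # attention: softmax(Q @ K^T / sqrt(d_k) + bias) @ V, where `bias`
        # encodes the causal and/or user-supplied masks applied below.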

        b, a, s, _ = q.shape  # batch size, heads, query length, head dim
        s1 = k.size(2)  # key/value sequence length

        # Initialize the attention bias
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)

        # Apply the causal mask
        if causal:
            assert attn_mask is None, "causal masking and an explicit attn_mask cannot be combined"
            # Build a lower-triangular causal mask over the [s, s1] score matrix
            temp_mask = torch.ones(b, a, s, s1, dtype=torch.bool, device=q.device).tril(
                diagonal=0
            )
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias.to(q.dtype)

        # Apply a user-supplied attention mask
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask  # allows additive biases such as ALiBi

        # Compute the attention score matrix
        attn = (q @ k.transpose(-2, -1)) * scale_factor  # [B,A,S,S1]
        attn += attn_bias

        # Softmax and dropout
        attn = attn.softmax(dim=-1)
        attn = torch.dropout(attn, p=drop_rate, train=True)

        # Weighted sum over the values
        x = attn @ v  # [B,A,S,D]
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    # Apply the post-attention layout transform
    x = post_attn_layout(x)  # restore the original dimension order

    # Merge the head dimension
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)  # [B,S,A*D]
    return out
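

# Minimal usage sketch (illustrative only): the tensor sizes below are arbitrary
# assumptions. The "flash" mode additionally requires a CUDA device, half-precision
# inputs, and the flash-attn package; the "torch" mode runs on CPU as shown here.
if __name__ == "__main__":
    B, S, A, D = 2, 16, 8, 64
    q = torch.randn(B, S, A, D)
    k = torch.randn(B, S, A, D)
    v = torch.randn(B, S, A, D)

    out = attention(q, k, v, mode="torch", causal=True)
    print(out.shape)  # torch.Size([2, 16, 512]) == [B, S, A * D]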