Add more build variants
- build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/__init__.py +17 -0
- build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so +3 -0
- build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so +3 -0
- build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/_ops.py +9 -0
- build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/flash_attn_interface.py +828 -0
- build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/__init__.py +17 -0
- build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so +3 -0
- build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so +3 -0
- build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/_ops.py +9 -0
- build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py +828 -0
- build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/__init__.py +17 -0
- build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so +3 -0
- build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so +3 -0
- build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/_ops.py +9 -0
- build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/flash_attn_interface.py +828 -0
- build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/__init__.py +17 -0
- build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so +3 -0
- build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so +3 -0
- build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/_ops.py +9 -0
- build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py +828 -0
- build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/__init__.py +17 -0
- build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so +3 -0
- build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/_ops.py +9 -0
- build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/flash_attn_interface.py +828 -0
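The directory names above encode the PyTorch minor version, the C++ ABI flavor, the CUDA toolkit version, and the platform (e.g. torch26-cxx11-cu124-x86_64-linux). The sketch below is only an illustration of that naming pattern read off the paths in this commit; the helper name and the selection logic are assumptions, not part of this change.

import platform
import torch

def guess_variant_dir() -> str:
    # Hypothetical helper: rebuild a directory name such as
    # "torch26-cxx11-cu124-x86_64-linux" from the running environment.
    torch_ver = "".join(torch.__version__.split("+")[0].split(".")[:2])  # "2.6.0+cu124" -> "26"
    cxx_abi = "cxx11" if torch._C._GLIBCXX_USE_CXX11_ABI else "cxx98"
    cuda_ver = "cu" + torch.version.cuda.replace(".", "")                # "12.4" -> "cu124"
    return f"torch{torch_ver}-{cxx_abi}-{cuda_ver}-{platform.machine()}-{platform.system().lower()}"

print(guess_variant_dir())  # e.g. torch26-cxx11-cu124-x86_64-linux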
build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/__init__.py
ADDED
@@ -0,0 +1,17 @@
+from .flash_attn_interface import (
+    flash_attn_combine,
+    flash_attn_func,
+    flash_attn_qkvpacked_func,
+    flash_attn_varlen_func,
+    flash_attn_with_kvcache,
+    get_scheduler_metadata,
+)
+
+__all__ = [
+    "flash_attn_combine",
+    "flash_attn_func",
+    "flash_attn_qkvpacked_func",
+    "flash_attn_varlen_func",
+    "flash_attn_with_kvcache",
+    "get_scheduler_metadata",
+]
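As a quick sanity check of the re-exported API, a minimal usage sketch is shown below. It assumes this build directory is importable as flash_attn3, a Hopper-class CUDA GPU, and bf16 inputs; shapes follow the (batch_size, seqlen, nheads, headdim) convention documented in flash_attn_interface.py.

import torch
from flash_attn3 import flash_attn_func  # assumes this build variant is on the Python path

batch, seqlen, nheads, headdim = 2, 1024, 8, 128
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)

# flash_attn_func returns (out, softmax_lse); out has the same shape as q.
out, lse = flash_attn_func(q, k, v, causal=True)
print(out.shape, lse.shape)  # (2, 1024, 8, 128) and (2, 8, 1024)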
build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:21b44e8e5e447a8b8ee051d347f0e32a3446a750f79d0bd1755e553f2119aa3b
+size 838459656
build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:12d4ff964085fd02252777a2008f5ca47c90ea6a93da590e2fc5065dd5330207
+size 838459656
build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/_ops.py
ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _flash_attn3_557701f
+ops = torch.ops._flash_attn3_557701f
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_flash_attn3_557701f::{op_name}"
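_ops.py pins the Python-visible name of the compiled extension to its build hash, and flash_attn_interface.py calls the kernels through this ops namespace (ops.fwd, ops.bwd, ...). The snippet below only exercises what is defined in this file; treat it as a sketch, since ops.fwd is resolvable only after the matching .so has been loaded.

from flash_attn3._ops import ops, add_op_namespace_prefix

print(add_op_namespace_prefix("fwd"))  # -> "_flash_attn3_557701f::fwd"
print(ops)                             # torch.ops namespace backed by the bundled .so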
build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/flash_attn_interface.py
ADDED
@@ -0,0 +1,828 @@
+# Copyright (c) 2023, Tri Dao.
+
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+
+from ._ops import ops as flash_attn_3_cuda
+
+def maybe_contiguous(x):
+    return x.contiguous() if x is not None and x.stride(-1) != 1 else x
+
+
+def _flash_attn_forward(
+    q,
+    k,
+    v,
+    k_new,
+    v_new,
+    qv,
+    out,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    cu_seqlens_k_new,
+    seqused_q,
+    seqused_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    page_table,
+    kv_batch_idx,
+    leftpad_k,
+    rotary_cos,
+    rotary_sin,
+    seqlens_rotary,
+    q_descale,
+    k_descale,
+    v_descale,
+    softmax_scale,
+    causal,
+    window_size=(-1, -1),
+    attention_chunk=0,
+    softcap=0.0,
+    rotary_interleaved=True,
+    scheduler_metadata=None,
+    num_splits=1,
+    pack_gqa=None,
+    sm_margin=0):
+    q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
+    v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
+    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
+        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
+    ]
+    seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
+    page_table, kv_batch_idx, leftpad_k = [
+        maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
+    ]
+    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
+    seqlens_rotary = maybe_contiguous(seqlens_rotary)
+    out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
+        q,
+        k,
+        v,
+        k_new,
+        v_new,
+        qv,
+        out,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        cu_seqlens_k_new,
+        seqused_q,
+        seqused_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        page_table,
+        kv_batch_idx,
+        leftpad_k,
+        rotary_cos,
+        rotary_sin,
+        seqlens_rotary,
+        q_descale,
+        k_descale,
+        v_descale,
+        softmax_scale,
+        causal,
+        window_size[0],
+        window_size[1],
+        attention_chunk,
+        softcap,
+        rotary_interleaved,
+        scheduler_metadata,
+        num_splits,
+        pack_gqa,
+        sm_margin,
+    )
+    return out, softmax_lse, *rest
+
+
+def _flash_attn_backward(
+    dout,
+    q,
+    k,
+    v,
+    out,
+    softmax_lse,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    sequed_q,
+    sequed_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    dq,
+    dk,
+    dv,
+    softmax_scale,
+    causal,
+    window_size=(-1, -1),
+    softcap=0.0,
+    deterministic=False,
+    sm_margin=0,
+):
+    # dq, dk, dv are allocated by us so they should already be contiguous
+    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
+    dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
+        dout,
+        q,
+        k,
+        v,
+        out,
+        softmax_lse,
+        dq,
+        dk,
+        dv,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        sequed_q,
+        sequed_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        softmax_scale,
+        causal,
+        window_size[0],
+        window_size[1],
+        softcap,
+        deterministic,
+        sm_margin,
+    )
+    return dq, dk, dv, softmax_d
+
+
+class FlashAttnQKVPackedFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        qkv,
+        softmax_scale,
+        causal,
+        q_descale=None, k_descale=None, v_descale=None,
+        window_size=(-1, -1),
+        attention_chunk=0,
+        softcap=0.0,
+        deterministic=False,
+        num_heads_q=None,
+        sm_margin=0,
+    ):
+        if softmax_scale is None:
+            softmax_scale = qkv.shape[-1] ** (-0.5)
+        if qkv.dim() == 5:
+            assert qkv.shape[-3] == 3
+            q, k, v = qkv.unbind(dim=-3)
+        else:
+            assert qkv.dim() == 4
+            assert num_heads_q is not None
+            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
+            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
+            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
+        out, softmax_lse, *rest = _flash_attn_forward(
+            q,
+            k,
+            v,
+            None, None,  # k_new, v_new
+            None,  # qv
+            None,  # out
+            None, None, None,  # cu_seqlens_q/k/k_new
+            None, None,  # seqused_q/k
+            None, None,  # max_seqlen_q/k
+            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
+            None, None, None,  # rotary_cos/sin, seqlens_rotary
+            q_descale, k_descale, v_descale,
+            softmax_scale,
+            causal=causal,
+            window_size=window_size,
+            attention_chunk=attention_chunk,
+            softcap=softcap,
+            sm_margin=sm_margin,
+        )
+        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
+        ctx.save_for_backward(q, k, v, out, softmax_lse)
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        ctx.window_size = window_size
+        ctx.attention_chunk = attention_chunk
+        ctx.softcap = softcap
+        ctx.deterministic = deterministic
+        ctx.ndim = qkv.dim()
+        ctx.sm_margin = sm_margin
+        # return out, softmax_lse
+        return out
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        q, k, v, out, softmax_lse = ctx.saved_tensors
+        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
+        if ctx.ndim == 5:
+            qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
+            dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
+            dq, dk, dv = dqkv.unbind(dim=-3)
+        else:
+            num_heads_q = q.shape[2]
+            num_heads_k = k.shape[2]
+            qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
+            dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
+            dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
+        _flash_attn_backward(
+            dout,
+            q,
+            k,
+            v,
+            out,
+            softmax_lse,
+            None, None,  # cu_seqlens_q, cu_seqlens_k,
+            None, None,  # sequed_q, sequed_k,
+            None, None,  # max_seqlen_q, max_seqlen_k,
+            dq,
+            dk,
+            dv,
+            ctx.softmax_scale,
+            ctx.causal,
+            ctx.window_size,
+            ctx.softcap,
+            ctx.deterministic,
+            ctx.sm_margin,
+        )
+        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
+        return dqkv, None, None, None, None, None, None, None, None, None, None, None
+
+
+class FlashAttnFunc(torch.autograd.Function):
+
+    @staticmethod
+    def forward(
+        ctx,
+        q,
+        k,
+        v,
+        softmax_scale,
+        causal,
+        qv=None,
+        q_descale=None, k_descale=None, v_descale=None,
+        window_size=(-1, -1),
+        attention_chunk=0,
+        softcap=0.0,
+        num_splits=1,
+        pack_gqa=None,
+        deterministic=False,
+        sm_margin=0,
+    ):
+        if softmax_scale is None:
+            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
+        # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
+        out, softmax_lse, *rest = _flash_attn_forward(
+            q,
+            k,
+            v,
+            None, None,  # k_new, v_new
+            qv,  # qv
+            None,  # out
+            None, None, None,  # cu_seqlens_q/k/k_new
+            None, None,  # seqused_q/k
+            None, None,  # max_seqlen_q/k
+            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
+            None, None, None,  # rotary_cos/sin, seqlens_rotary
+            q_descale, k_descale, v_descale,
+            softmax_scale,
+            causal=causal,
+            window_size=window_size,
+            attention_chunk=attention_chunk,
+            softcap=softcap,
+            num_splits=num_splits,
+            pack_gqa=pack_gqa,
+            sm_margin=sm_margin,
+        )
+        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
+        ctx.save_for_backward(q, k, v, out, softmax_lse)
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        ctx.window_size = window_size
+        ctx.attention_chunk = attention_chunk
+        ctx.softcap = softcap
+        ctx.deterministic = deterministic
+        ctx.sm_margin = sm_margin
+        return out, softmax_lse
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        q, k, v, out, softmax_lse = ctx.saved_tensors
+        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
+        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
+        _flash_attn_backward(
+            dout,
+            q,
+            k,
+            v,
+            out,
+            softmax_lse,
+            None, None,  # cu_seqlens_q, cu_seqlens_k,
+            None, None,  # sequed_q, sequed_k,
+            None, None,  # max_seqlen_q, max_seqlen_k,
+            dq,
+            dk,
+            dv,
+            ctx.softmax_scale,
+            ctx.causal,
+            ctx.window_size,
+            ctx.softcap,
+            ctx.deterministic,
+            ctx.sm_margin,
+        )
+        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
+        dk = dk[..., : k.shape[-1]]
+        dv = dv[..., : v.shape[-1]]
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+
+
+class FlashAttnVarlenFunc(torch.autograd.Function):
+
+    @staticmethod
+    def forward(
+        ctx,
+        q,
+        k,
+        v,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        seqused_q,
+        seqused_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        softmax_scale,
+        causal,
+        qv=None,
+        q_descale=None, k_descale=None, v_descale=None,
+        window_size=(-1, -1),
+        attention_chunk=0,
+        softcap=0.0,
+        num_splits=1,
+        pack_gqa=None,
+        deterministic=False,
+        sm_margin=0,
+    ):
+        if softmax_scale is None:
+            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
+        # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
+        out, softmax_lse, *rest = _flash_attn_forward(
+            q,
+            k,
+            v,
+            None, None,  # k_new, v_new
+            qv,  # qv
+            None,  # out
+            cu_seqlens_q,
+            cu_seqlens_k,
+            None,  # cu_seqlens_k_new
+            seqused_q,
+            seqused_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
+            None, None, None,  # rotary_cos/sin, seqlens_rotary
+            q_descale, k_descale, v_descale,
+            softmax_scale,
+            causal=causal,
+            window_size=window_size,
+            attention_chunk=attention_chunk,
+            softcap=softcap,
+            num_splits=num_splits,
+            pack_gqa=pack_gqa,
+            sm_margin=sm_margin,
+        )
+        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
+        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
+        ctx.max_seqlen_q = max_seqlen_q
+        ctx.max_seqlen_k = max_seqlen_k
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        ctx.window_size = window_size
+        ctx.attention_chunk = attention_chunk
+        ctx.softcap = softcap
+        ctx.deterministic = deterministic
+        ctx.sm_margin = sm_margin
+        return out, softmax_lse
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
+        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
+        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
+        _flash_attn_backward(
+            dout,
+            q,
+            k,
+            v,
+            out,
+            softmax_lse,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            seqused_q,
+            seqused_k,
+            ctx.max_seqlen_q,
+            ctx.max_seqlen_k,
+            dq,
+            dk,
+            dv,
+            ctx.softmax_scale,
+            ctx.causal,
+            ctx.window_size,
+            ctx.softcap,
+            ctx.deterministic,
+            ctx.sm_margin,
+        )
+        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
+        dk = dk[..., : k.shape[-1]]
+        dv = dv[..., : v.shape[-1]]
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+
+
+def flash_attn_qkvpacked_func(
+    qkv,
+    softmax_scale=None,
+    causal=False,
+    q_descale=None, k_descale=None, v_descale=None,
+    window_size=(-1, -1),
+    attention_chunk=0,
+    softcap=0.0,
+    deterministic=False,
+    num_heads_q=None,
+    sm_margin=0,
+):
+    """dropout_p should be set to 0.0 during evaluation
+    If Q, K, V are already stacked into 1 tensor, this function will be faster than
+    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
+    of the gradients of Q, K, V.
+    For multi-query and grouped-query attention (MQA/GQA), please see
+    flash_attn_kvpacked_func and flash_attn_func.
+
+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
+
+    Arguments:
+        qkv: (batch_size, seqlen, 3, nheads, headdim)
+        dropout_p: float. Dropout probability.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
+            the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
+        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
+            testing only. The returned probabilities are not guaranteed to be correct
+            (they might not have the right scaling).
+    Return:
+        out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
+            The output of softmax (possibly with different scaling). It also encodes the dropout
+            pattern (negative means that location was dropped, nonnegative means it was kept).
+    """
+    return FlashAttnQKVPackedFunc.apply(
+        qkv,
+        softmax_scale,
+        causal,
+        q_descale, k_descale, v_descale,
+        window_size,
+        attention_chunk,
+        softcap,
+        deterministic,
+        num_heads_q,
+        sm_margin,
+    )
+
+
+def flash_attn_func(
+    q,
+    k,
+    v,
+    softmax_scale=None,
+    causal=False,
+    qv=None,
+    q_descale=None, k_descale=None, v_descale=None,
+    window_size=(-1, -1),
+    attention_chunk=0,
+    softcap=0.0,
+    num_splits=1,
+    pack_gqa=None,
+    deterministic=False,
+    sm_margin=0,
+):
+    """dropout_p should be set to 0.0 during evaluation
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
+    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
+
+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If the row of the mask is all zero, the output will be zero.
+
+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
+
+    Arguments:
+        q: (batch_size, seqlen, nheads, headdim)
+        k: (batch_size, seqlen, nheads_k, headdim)
+        v: (batch_size, seqlen, nheads_k, headdim)
+        dropout_p: float. Dropout probability.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
+        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
+            testing only. The returned probabilities are not guaranteed to be correct
+            (they might not have the right scaling).
+    Return:
+        out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+    """
+    return FlashAttnFunc.apply(
+        q,
+        k,
+        v,
+        softmax_scale,
+        causal,
+        qv,
+        q_descale, k_descale, v_descale,
+        window_size,
+        attention_chunk,
+        softcap,
+        num_splits,
+        pack_gqa,
+        deterministic,
+        sm_margin,
+    )
+
+
+def flash_attn_varlen_func(
+    q,
+    k,
+    v,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    seqused_q=None,
+    seqused_k=None,
+    softmax_scale=None,
+    causal=False,
+    qv=None,
+    q_descale=None, k_descale=None, v_descale=None,
+    window_size=(-1, -1),
+    attention_chunk=0,
+    softcap=0.0,
+    num_splits=1,
+    pack_gqa=None,
+    deterministic=False,
+    sm_margin=0,
+):
+    return FlashAttnVarlenFunc.apply(
+        q,
+        k,
+        v,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        seqused_q,
+        seqused_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        softmax_scale,
+        causal,
+        qv,
+        q_descale, k_descale, v_descale,
+        window_size,
+        attention_chunk,
+        softcap,
+        num_splits,
+        pack_gqa,
+        deterministic,
+        sm_margin,
+    )
+
+
+def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
+    return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
+
+
+def flash_attn_with_kvcache(
+    q,
+    k_cache,
+    v_cache,
+    k=None,
+    v=None,
+    qv=None,
+    rotary_cos=None,
+    rotary_sin=None,
+    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
+    cache_batch_idx: Optional[torch.Tensor] = None,
+    cache_leftpad: Optional[torch.Tensor] = None,
+    page_table: Optional[torch.Tensor] = None,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k_new: Optional[torch.Tensor] = None,
+    max_seqlen_q: Optional[int] = None,
+    rotary_seqlens: Optional[torch.Tensor] = None,
+    q_descale: Optional[torch.Tensor] = None,
+    k_descale: Optional[torch.Tensor] = None,
+    v_descale: Optional[torch.Tensor] = None,
+    softmax_scale=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    attention_chunk=0,
+    softcap=0.0,  # 0.0 means deactivated
+    rotary_interleaved=True,
+    scheduler_metadata=None,
+    num_splits=0,  # Can be tuned for speed
+    pack_gqa=None,  # Can be tuned for speed
+    sm_margin=0,  # Can be tuned if some SMs are used for communication
+    return_softmax_lse=False,
+):
+    """
+    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
+    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
+    the previous step, and update them with the new keys/values from the current step, and do
+    attention with the updated cache, all in 1 kernel.
+
+    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
+    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
+    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
+
+    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
+    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
+    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
+    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
+    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
+    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
+
+    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
+
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
+    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
+
+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If the row of the mask is all zero, the output will be zero.
+
+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
+
+    Note: Does not support backward pass.
+
+    Arguments:
+        q: (batch_size, seqlen, nheads, headdim)
+        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
+            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
+            page_block_size must be a multiple of 256.
+        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
+            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
+        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
+            k with k_cache, starting at the indices specified by cache_seqlens.
+        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
+        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
+        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
+            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
+        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
+        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
+            KV cache.
+        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
+            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
+            If the indices are not distinct, and k and v are provided, the values updated in the cache
+            might come from any of the duplicate indices.
+        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
+        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
+        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
+            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
+            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
+            (i.e. GPT-NeoX style).
+        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
+            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
+            to automatically determine the number of splits.
+            Don't change this unless you know what you are doing.
+        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
+
+    Return:
+        out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+    """
+    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
+    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
+    if softmax_scale is None:
+        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
+    if cache_seqlens is not None and isinstance(cache_seqlens, int):
+        cache_seqlens = torch.full(
+            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
+        )
+        cache_seqlens = maybe_contiguous(cache_seqlens)
+    out, softmax_lse, *rest = _flash_attn_forward(
+        q,
+        k_cache,
+        v_cache,
+        k,
+        v,
+        qv,
+        None,  # out
+        cu_seqlens_q,
+        None,  # cu_seqlens_k
+        cu_seqlens_k_new,
+        None,  # seqused_q
+        cache_seqlens,
+        max_seqlen_q,
+        None,  # max_seqlen_k
+        page_table,
+        cache_batch_idx,
+        cache_leftpad,
+        rotary_cos,
+        rotary_sin,
+        rotary_seqlens,
+        q_descale, k_descale, v_descale,
+        softmax_scale,
+        causal=causal,
+        window_size=window_size,
+        attention_chunk=attention_chunk,
+        softcap=softcap,
+        rotary_interleaved=rotary_interleaved,
+        scheduler_metadata=scheduler_metadata,
+        num_splits=num_splits,
+        pack_gqa=pack_gqa,
+        sm_margin=sm_margin,
+    )
+    # return (out, softmax_lse) if return_softmax_lse else out
+    return (out, softmax_lse, *rest) if return_softmax_lse else out
+
+
+def get_scheduler_metadata(
+    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
+    cache_seqlens: torch.Tensor,
+    qkv_dtype=torch.bfloat16,
+    headdim_v=None,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k_new: Optional[torch.Tensor] = None,
+    cache_leftpad: Optional[torch.Tensor] = None,
+    page_size: Optional[int] = None,
+    max_seqlen_k_new=0,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    attention_chunk=0,
+    has_softcap=False,
+    num_splits=0,  # Can be tuned for speed
+    pack_gqa=None,  # Can be tuned for speed
+    sm_margin=0,  # Can be tuned if some SMs are used for communication
+):
+    cache_seqlens = maybe_contiguous(cache_seqlens)
+    if headdim_v is None:
+        headdim_v = headdim
+    scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
+        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
+        qkv_dtype,
+        cache_seqlens,
+        cu_seqlens_q,
+        None,  # cu_seqlens_k
+        cu_seqlens_k_new,
+        None,  # seqused_q
+        cache_leftpad,
+        page_size,
+        max_seqlen_k_new,
+        causal,
+        window_size[0], window_size[1],
+        attention_chunk,
+        has_softcap,
+        num_splits,
+        pack_gqa,
+        sm_margin,
+    )
+    return scheduler_metadata
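The flash_attn_with_kvcache docstring above describes updating the KV cache in place and attending over it in a single kernel. The sketch below exercises that path for one decode step; the shapes (bf16, GQA with 8 query heads over 2 KV heads, a cache pre-allocated to max_seqlen) are assumed for illustration and are not taken from this repository's tests.

import torch
from flash_attn3 import flash_attn_with_kvcache

batch, max_seqlen, nheads, nheads_k, headdim = 2, 4096, 8, 2, 128
dev, dt = "cuda", torch.bfloat16

# Pre-allocated KV cache; cache_seqlens tracks how much of it is valid per sequence.
k_cache = torch.zeros(batch, max_seqlen, nheads_k, headdim, device=dev, dtype=dt)
v_cache = torch.zeros(batch, max_seqlen, nheads_k, headdim, device=dev, dtype=dt)
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device=dev)

# One decode step: a single new query token plus its new key/value.
q = torch.randn(batch, 1, nheads, headdim, device=dev, dtype=dt)
k_new = torch.randn(batch, 1, nheads_k, headdim, device=dev, dtype=dt)
v_new = torch.randn(batch, 1, nheads_k, headdim, device=dev, dtype=dt)

# k_cache/v_cache are updated in place at position cache_seqlens, then attention
# runs over the updated cache, all in one kernel call.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new,
    cache_seqlens=cache_seqlens, causal=True,
)
print(out.shape)  # (2, 1, 8, 128); the caller would then advance cache_seqlens by 1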
build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/__init__.py
ADDED
@@ -0,0 +1,17 @@
+from .flash_attn_interface import (
+    flash_attn_combine,
+    flash_attn_func,
+    flash_attn_qkvpacked_func,
+    flash_attn_varlen_func,
+    flash_attn_with_kvcache,
+    get_scheduler_metadata,
+)
+
+__all__ = [
+    "flash_attn_combine",
+    "flash_attn_func",
+    "flash_attn_qkvpacked_func",
+    "flash_attn_varlen_func",
+    "flash_attn_with_kvcache",
+    "get_scheduler_metadata",
+]
build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:21b44e8e5e447a8b8ee051d347f0e32a3446a750f79d0bd1755e553f2119aa3b
+size 838459656
build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:12d4ff964085fd02252777a2008f5ca47c90ea6a93da590e2fc5065dd5330207
+size 838459656
build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/_ops.py
ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _flash_attn3_557701f
+ops = torch.ops._flash_attn3_557701f
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_flash_attn3_557701f::{op_name}"
build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py
ADDED
@@ -0,0 +1,828 @@
1 |
+
# Copyright (c) 2023, Tri Dao.
|
2 |
+
|
3 |
+
from typing import Optional, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from ._ops import ops as flash_attn_3_cuda
|
9 |
+
|
10 |
+
def maybe_contiguous(x):
|
11 |
+
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
|
12 |
+
|
13 |
+
|
14 |
+
def _flash_attn_forward(
|
15 |
+
q,
|
16 |
+
k,
|
17 |
+
v,
|
18 |
+
k_new,
|
19 |
+
v_new,
|
20 |
+
qv,
|
21 |
+
out,
|
22 |
+
cu_seqlens_q,
|
23 |
+
cu_seqlens_k,
|
24 |
+
cu_seqlens_k_new,
|
25 |
+
seqused_q,
|
26 |
+
seqused_k,
|
27 |
+
max_seqlen_q,
|
28 |
+
max_seqlen_k,
|
29 |
+
page_table,
|
30 |
+
kv_batch_idx,
|
31 |
+
leftpad_k,
|
32 |
+
rotary_cos,
|
33 |
+
rotary_sin,
|
34 |
+
seqlens_rotary,
|
35 |
+
q_descale,
|
36 |
+
k_descale,
|
37 |
+
v_descale,
|
38 |
+
softmax_scale,
|
39 |
+
causal,
|
40 |
+
window_size=(-1, -1),
|
41 |
+
attention_chunk=0,
|
42 |
+
softcap=0.0,
|
43 |
+
rotary_interleaved=True,
|
44 |
+
scheduler_metadata=None,
|
45 |
+
num_splits=1,
|
46 |
+
pack_gqa=None,
|
47 |
+
sm_margin=0):
|
48 |
+
q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
|
49 |
+
v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
|
50 |
+
cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
|
51 |
+
maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
|
52 |
+
]
|
53 |
+
seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
|
54 |
+
page_table, kv_batch_idx, leftpad_k = [
|
55 |
+
maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
|
56 |
+
]
|
57 |
+
rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
|
58 |
+
seqlens_rotary = maybe_contiguous(seqlens_rotary)
|
59 |
+
out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
|
60 |
+
q,
|
61 |
+
k,
|
62 |
+
v,
|
63 |
+
k_new,
|
64 |
+
v_new,
|
65 |
+
qv,
|
66 |
+
out,
|
67 |
+
cu_seqlens_q,
|
68 |
+
cu_seqlens_k,
|
69 |
+
cu_seqlens_k_new,
|
70 |
+
seqused_q,
|
71 |
+
seqused_k,
|
72 |
+
max_seqlen_q,
|
73 |
+
max_seqlen_k,
|
74 |
+
page_table,
|
75 |
+
kv_batch_idx,
|
76 |
+
leftpad_k,
|
77 |
+
rotary_cos,
|
78 |
+
rotary_sin,
|
79 |
+
seqlens_rotary,
|
80 |
+
q_descale,
|
81 |
+
k_descale,
|
82 |
+
v_descale,
|
83 |
+
softmax_scale,
|
84 |
+
causal,
|
85 |
+
window_size[0],
|
86 |
+
window_size[1],
|
87 |
+
attention_chunk,
|
88 |
+
softcap,
|
89 |
+
rotary_interleaved,
|
90 |
+
scheduler_metadata,
|
91 |
+
num_splits,
|
92 |
+
pack_gqa,
|
93 |
+
sm_margin,
|
94 |
+
)
|
95 |
+
return out, softmax_lse, *rest
|
96 |
+
|
97 |
+
|
98 |
+
def _flash_attn_backward(
|
99 |
+
dout,
|
100 |
+
q,
|
101 |
+
k,
|
102 |
+
v,
|
103 |
+
out,
|
104 |
+
softmax_lse,
|
105 |
+
cu_seqlens_q,
|
106 |
+
cu_seqlens_k,
|
107 |
+
sequed_q,
|
108 |
+
sequed_k,
|
109 |
+
max_seqlen_q,
|
110 |
+
max_seqlen_k,
|
111 |
+
dq,
|
112 |
+
dk,
|
113 |
+
dv,
|
114 |
+
softmax_scale,
|
115 |
+
causal,
|
116 |
+
window_size=(-1, -1),
|
117 |
+
softcap=0.0,
|
118 |
+
deterministic=False,
|
119 |
+
sm_margin=0,
|
120 |
+
):
|
121 |
+
# dq, dk, dv are allocated by us so they should already be contiguous
|
122 |
+
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
|
123 |
+
dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
|
124 |
+
dout,
|
125 |
+
q,
|
126 |
+
k,
|
127 |
+
v,
|
128 |
+
out,
|
129 |
+
softmax_lse,
|
130 |
+
dq,
|
131 |
+
dk,
|
132 |
+
dv,
|
133 |
+
cu_seqlens_q,
|
134 |
+
cu_seqlens_k,
|
135 |
+
sequed_q,
|
136 |
+
sequed_k,
|
137 |
+
max_seqlen_q,
|
138 |
+
max_seqlen_k,
|
139 |
+
softmax_scale,
|
140 |
+
causal,
|
141 |
+
window_size[0],
|
142 |
+
window_size[1],
|
143 |
+
softcap,
|
144 |
+
deterministic,
|
145 |
+
sm_margin,
|
146 |
+
)
|
147 |
+
return dq, dk, dv, softmax_d
|
148 |
+
|
149 |
+
|
150 |
+
class FlashAttnQKVPackedFunc(torch.autograd.Function):
|
151 |
+
@staticmethod
|
152 |
+
def forward(
|
153 |
+
ctx,
|
154 |
+
qkv,
|
155 |
+
softmax_scale,
|
156 |
+
causal,
|
157 |
+
q_descale=None, k_descale=None, v_descale=None,
|
158 |
+
window_size=(-1, -1),
|
159 |
+
attention_chunk=0,
|
160 |
+
softcap=0.0,
|
161 |
+
deterministic=False,
|
162 |
+
num_heads_q=None,
|
163 |
+
sm_margin=0,
|
164 |
+
):
|
165 |
+
if softmax_scale is None:
|
166 |
+
softmax_scale = qkv.shape[-1] ** (-0.5)
|
167 |
+
if qkv.dim() == 5:
|
168 |
+
assert qkv.shape[-3] == 3
|
169 |
+
q, k, v = qkv.unbind(dim=-3)
|
170 |
+
else:
|
171 |
+
assert qkv.dim() == 4
|
172 |
+
assert num_heads_q is not None
|
173 |
+
num_heads_k = (qkv.shape[2] - num_heads_q) // 2
|
174 |
+
assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
|
175 |
+
q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
|
176 |
+
out, softmax_lse, *rest = _flash_attn_forward(
|
177 |
+
q,
|
178 |
+
k,
|
179 |
+
v,
|
180 |
+
None, None, # k_new, v_new
|
181 |
+
None, # qv
|
182 |
+
None, # out
|
183 |
+
None, None, None, # cu_seqlens_q/k/k_new
|
184 |
+
None, None, # seqused_q/k
|
185 |
+
None, None, # max_seqlen_q/k
|
186 |
+
None, None, None, # page_table, kv_batch_idx, leftpad_k,
|
187 |
+
None, None, None, # rotary_cos/sin, seqlens_rotary
|
188 |
+
q_descale, k_descale, v_descale,
|
189 |
+
softmax_scale,
|
190 |
+
causal=causal,
|
191 |
+
window_size=window_size,
|
192 |
+
attention_chunk=attention_chunk,
|
193 |
+
softcap=softcap,
|
194 |
+
sm_margin=sm_margin,
|
195 |
+
)
|
196 |
+
# ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
|
197 |
+
ctx.save_for_backward(q, k, v, out, softmax_lse)
|
198 |
+
ctx.softmax_scale = softmax_scale
|
199 |
+
ctx.causal = causal
|
200 |
+
ctx.window_size = window_size
|
201 |
+
ctx.attention_chunk = attention_chunk
|
202 |
+
ctx.softcap = softcap
|
203 |
+
ctx.deterministic = deterministic
|
204 |
+
ctx.ndim = qkv.dim()
|
205 |
+
ctx.sm_margin = sm_margin
|
206 |
+
# return out, softmax_lse
|
207 |
+
return out
|
208 |
+
|
209 |
+
@staticmethod
|
210 |
+
def backward(ctx, dout, *args):
|
211 |
+
q, k, v, out, softmax_lse = ctx.saved_tensors
|
212 |
+
assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
|
213 |
+
if ctx.ndim == 5:
|
214 |
+
qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
|
215 |
+
dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
|
216 |
+
dq, dk, dv = dqkv.unbind(dim=-3)
|
217 |
+
else:
|
218 |
+
num_heads_q = q.shape[2]
|
219 |
+
num_heads_k = k.shape[2]
|
220 |
+
qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
|
221 |
+
dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
|
222 |
+
dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
|
223 |
+
_flash_attn_backward(
|
224 |
+
dout,
|
225 |
+
q,
|
226 |
+
k,
|
227 |
+
v,
|
228 |
+
out,
|
229 |
+
softmax_lse,
|
230 |
+
None, None, # cu_seqlens_q, cu_seqlens_k,
|
231 |
+
None, None, # sequed_q, sequed_k,
|
232 |
+
None, None, # max_seqlen_q, max_seqlen_k,
|
233 |
+
dq,
|
234 |
+
dk,
|
235 |
+
dv,
|
236 |
+
ctx.softmax_scale,
|
237 |
+
ctx.causal,
|
238 |
+
ctx.window_size,
|
239 |
+
ctx.softcap,
|
240 |
+
ctx.deterministic,
|
241 |
+
ctx.sm_margin,
|
242 |
+
)
|
243 |
+
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
|
244 |
+
return dqkv, None, None, None, None, None, None, None, None, None, None, None
|
245 |
+
|
246 |
+
|
247 |
+
class FlashAttnFunc(torch.autograd.Function):
|
248 |
+
|
249 |
+
@staticmethod
|
250 |
+
def forward(
|
251 |
+
ctx,
|
252 |
+
q,
|
253 |
+
k,
|
254 |
+
v,
|
255 |
+
softmax_scale,
|
256 |
+
causal,
|
257 |
+
qv=None,
|
258 |
+
q_descale=None, k_descale=None, v_descale=None,
|
259 |
+
window_size=(-1, -1),
|
260 |
+
attention_chunk=0,
|
261 |
+
softcap=0.0,
|
262 |
+
num_splits=1,
|
263 |
+
pack_gqa=None,
|
264 |
+
deterministic=False,
|
265 |
+
sm_margin=0,
|
266 |
+
):
|
267 |
+
if softmax_scale is None:
|
268 |
+
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
|
269 |
+
# out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
|
270 |
+
out, softmax_lse, *rest = _flash_attn_forward(
|
271 |
+
q,
|
272 |
+
k,
|
273 |
+
v,
|
274 |
+
None, None, # k_new, v_new
|
275 |
+
qv, # qv
|
276 |
+
None, # out
|
277 |
+
None, None, None, # cu_seqlens_q/k/k_new
|
278 |
+
None, None, # seqused_q/k
|
279 |
+
None, None, # max_seqlen_q/k
|
280 |
+
None, None, None, # page_table, kv_batch_idx, leftpad_k,
|
281 |
+
None, None, None, # rotary_cos/sin, seqlens_rotary
|
282 |
+
q_descale, k_descale, v_descale,
|
283 |
+
softmax_scale,
|
284 |
+
causal=causal,
|
285 |
+
window_size=window_size,
|
286 |
+
attention_chunk=attention_chunk,
|
287 |
+
softcap=softcap,
|
288 |
+
num_splits=num_splits,
|
289 |
+
pack_gqa=pack_gqa,
|
290 |
+
sm_margin=sm_margin,
|
291 |
+
)
|
292 |
+
# ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
|
293 |
+
ctx.save_for_backward(q, k, v, out, softmax_lse)
|
294 |
+
ctx.softmax_scale = softmax_scale
|
295 |
+
ctx.causal = causal
|
296 |
+
ctx.window_size = window_size
|
297 |
+
ctx.attention_chunk = attention_chunk
|
298 |
+
ctx.softcap = softcap
|
299 |
+
ctx.deterministic = deterministic
|
300 |
+
ctx.sm_margin = sm_margin
|
301 |
+
return out, softmax_lse
|
302 |
+
|
303 |
+
@staticmethod
|
304 |
+
def backward(ctx, dout, *args):
|
305 |
+
q, k, v, out, softmax_lse = ctx.saved_tensors
|
306 |
+
assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
|
307 |
+
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
308 |
+
_flash_attn_backward(
|
309 |
+
dout,
|
310 |
+
q,
|
311 |
+
k,
|
312 |
+
v,
|
313 |
+
out,
|
314 |
+
softmax_lse,
|
315 |
+
None, None, # cu_seqlens_q, cu_seqlens_k,
|
316 |
+
None, None, # sequed_q, sequed_k,
|
317 |
+
None, None, # max_seqlen_q, max_seqlen_k,
|
318 |
+
dq,
|
            dk,
            dv,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None


class FlashAttnVarlenFunc(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
    ):
        if softmax_scale is None:
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            cu_seqlens_q,
            cu_seqlens_k,
            None,  # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return out, softmax_lse

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            dq,
            dk,
            dv,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None


def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
            the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
    )

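# Usage sketch for flash_attn_qkvpacked_func (illustrative only, not part of the added file;
# shapes, dtype, the `flash_attn3` import name, and a GPU supported by this FA3 build are assumptions).
import torch
from flash_attn3 import flash_attn_qkvpacked_func

# batch 2, seqlen 1024, 8 heads, head dim 128, packed as (batch, seqlen, 3, nheads, headdim)
qkv = torch.randn(2, 1024, 3, 8, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
out = flash_attn_qkvpacked_func(qkv, causal=True)  # (2, 1024, 8, 128); this wrapper returns only `out`
out.sum().backward()  # qkv.grad comes back in the same packed (..., 3, nheads, headdim) layout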
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
    )

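# Usage sketch for flash_attn_func with GQA (illustrative only, not part of the added file;
# shapes and dtype are assumptions). Note that in this build the wrapper returns (out, softmax_lse).
import torch
from flash_attn3 import flash_attn_func

q = torch.randn(1, 4096, 32, 128, device="cuda", dtype=torch.bfloat16)  # 32 query heads
k = torch.randn(1, 4096, 8, 128, device="cuda", dtype=torch.bfloat16)   # 8 KV heads (32 % 8 == 0)
v = torch.randn(1, 4096, 8, 128, device="cuda", dtype=torch.bfloat16)
out, lse = flash_attn_func(q, k, v, causal=True)                         # full causal attention
out_local, _ = flash_attn_func(q, k, v, causal=True, window_size=(256, 0))  # sliding-window variant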
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
):
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
    )

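# Usage sketch for flash_attn_varlen_func (illustrative only, not part of the added file; shapes and
# dtype are assumptions). Sequences are packed into one (total_tokens, nheads, headdim) tensor and
# delimited by cumulative sequence lengths.
import torch
from flash_attn3 import flash_attn_varlen_func

seqlens = torch.tensor([5, 7], dtype=torch.int32, device="cuda")
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # [0, 5, 12]
q = torch.randn(12, 8, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(12, 8, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(12, 8, 128, device="cuda", dtype=torch.bfloat16)
out, lse = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=7, max_seqlen_k=7,
    causal=True,
)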
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        cache_seqlens = torch.full(
            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size=window_size,
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    # return (out, softmax_lse) if return_softmax_lse else out
    return (out, softmax_lse, *rest) if return_softmax_lse else out


def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    cache_seqlens = maybe_contiguous(cache_seqlens)
    if headdim_v is None:
        headdim_v = headdim
    scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
    return scheduler_metadata
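# Decode-step sketch for flash_attn_with_kvcache (illustrative only, not part of the added files;
# shapes and dtype are assumptions). The pre-allocated cache is updated in place at each step.
import torch
from flash_attn3 import flash_attn_with_kvcache

k_cache = torch.zeros(4, 8192, 8, 128, device="cuda", dtype=torch.bfloat16)
v_cache = torch.zeros(4, 8192, 8, 128, device="cuda", dtype=torch.bfloat16)
cache_seqlens = torch.full((4,), 100, dtype=torch.int32, device="cuda")  # tokens already cached

q = torch.randn(4, 1, 32, 128, device="cuda", dtype=torch.bfloat16)      # one new query token
k_new = torch.randn(4, 1, 8, 128, device="cuda", dtype=torch.bfloat16)
v_new = torch.randn(4, 1, 8, 128, device="cuda", dtype=torch.bfloat16)
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new,
    cache_seqlens=cache_seqlens, causal=True,
)  # k_cache / v_cache now hold the appended step; pass return_softmax_lse=True to also get the LSE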
build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/__init__.py
ADDED
@@ -0,0 +1,17 @@
from .flash_attn_interface import (
    flash_attn_combine,
    flash_attn_func,
    flash_attn_qkvpacked_func,
    flash_attn_varlen_func,
    flash_attn_with_kvcache,
    get_scheduler_metadata,
)

__all__ = [
    "flash_attn_combine",
    "flash_attn_func",
    "flash_attn_qkvpacked_func",
    "flash_attn_varlen_func",
    "flash_attn_with_kvcache",
    "get_scheduler_metadata",
]
build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9627e08ec8778d2a409a2a0477572edb3e03eaca2b45e7b4810ee0a9126d6547
size 838456048
build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:07fe025ba95671f6ff957991f74c66063bfb10ab6737641c88f88116c9f83718
size 838456048
build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/_ops.py
ADDED
@@ -0,0 +1,9 @@
import torch
from . import _flash_attn3_557701f
ops = torch.ops._flash_attn3_557701f

def add_op_namespace_prefix(op_name: str):
    """
    Prefix op by namespace.
    """
    return f"_flash_attn3_557701f::{op_name}"
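# Sketch of how this generated _ops shim is consumed (illustrative only, not part of the added file).
from flash_attn3._ops import ops, add_op_namespace_prefix

add_op_namespace_prefix("fwd")  # -> "_flash_attn3_557701f::fwd", the fully qualified op name
# flash_attn_interface.py binds `ops` as `flash_attn_3_cuda`, so ops.fwd, ops.bwd, ops.fwd_combine and
# ops.get_scheduler_metadata are the torch.ops entry points backed by the bundled extension.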
build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/flash_attn_interface.py
ADDED
@@ -0,0 +1,828 @@
build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/__init__.py
ADDED
@@ -0,0 +1,17 @@
from .flash_attn_interface import (
    flash_attn_combine,
    flash_attn_func,
    flash_attn_qkvpacked_func,
    flash_attn_varlen_func,
    flash_attn_with_kvcache,
    get_scheduler_metadata,
)

__all__ = [
    "flash_attn_combine",
    "flash_attn_func",
    "flash_attn_qkvpacked_func",
    "flash_attn_varlen_func",
    "flash_attn_with_kvcache",
    "get_scheduler_metadata",
]
build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9627e08ec8778d2a409a2a0477572edb3e03eaca2b45e7b4810ee0a9126d6547
size 838456048
build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:07fe025ba95671f6ff957991f74c66063bfb10ab6737641c88f88116c9f83718
size 838456048
build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/_ops.py
ADDED
@@ -0,0 +1,9 @@
import torch
from . import _flash_attn3_557701f
ops = torch.ops._flash_attn3_557701f

def add_op_namespace_prefix(op_name: str):
    """
    Prefix op by namespace.
    """
    return f"_flash_attn3_557701f::{op_name}"
build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py
ADDED
@@ -0,0 +1,828 @@
1 |
+
# Copyright (c) 2023, Tri Dao.
|
2 |
+
|
3 |
+
from typing import Optional, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from ._ops import ops as flash_attn_3_cuda
|
9 |
+
|
10 |
+
def maybe_contiguous(x):
|
11 |
+
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
|
12 |
+
|
13 |
+
|
14 |
+
def _flash_attn_forward(
|
15 |
+
q,
|
16 |
+
k,
|
17 |
+
v,
|
18 |
+
k_new,
|
19 |
+
v_new,
|
20 |
+
qv,
|
21 |
+
out,
|
22 |
+
cu_seqlens_q,
|
23 |
+
cu_seqlens_k,
|
24 |
+
cu_seqlens_k_new,
|
25 |
+
seqused_q,
|
26 |
+
seqused_k,
|
27 |
+
max_seqlen_q,
|
28 |
+
max_seqlen_k,
|
29 |
+
page_table,
|
30 |
+
kv_batch_idx,
|
31 |
+
leftpad_k,
|
32 |
+
rotary_cos,
|
33 |
+
rotary_sin,
|
34 |
+
seqlens_rotary,
|
35 |
+
q_descale,
|
36 |
+
k_descale,
|
37 |
+
v_descale,
|
38 |
+
softmax_scale,
|
39 |
+
causal,
|
40 |
+
window_size=(-1, -1),
|
41 |
+
attention_chunk=0,
|
42 |
+
softcap=0.0,
|
43 |
+
rotary_interleaved=True,
|
44 |
+
scheduler_metadata=None,
|
45 |
+
num_splits=1,
|
46 |
+
pack_gqa=None,
|
47 |
+
sm_margin=0):
|
48 |
+
q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
|
49 |
+
v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
|
50 |
+
cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
|
51 |
+
maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
|
52 |
+
]
|
53 |
+
seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
|
54 |
+
page_table, kv_batch_idx, leftpad_k = [
|
55 |
+
maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
|
56 |
+
]
|
57 |
+
rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
|
58 |
+
seqlens_rotary = maybe_contiguous(seqlens_rotary)
|
59 |
+
out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
|
60 |
+
q,
|
61 |
+
k,
|
62 |
+
v,
|
63 |
+
k_new,
|
64 |
+
v_new,
|
65 |
+
qv,
|
66 |
+
out,
|
67 |
+
cu_seqlens_q,
|
68 |
+
cu_seqlens_k,
|
69 |
+
cu_seqlens_k_new,
|
70 |
+
seqused_q,
|
71 |
+
seqused_k,
|
72 |
+
max_seqlen_q,
|
73 |
+
max_seqlen_k,
|
74 |
+
page_table,
|
75 |
+
kv_batch_idx,
|
76 |
+
leftpad_k,
|
77 |
+
rotary_cos,
|
78 |
+
rotary_sin,
|
79 |
+
seqlens_rotary,
|
80 |
+
q_descale,
|
81 |
+
k_descale,
|
82 |
+
v_descale,
|
83 |
+
softmax_scale,
|
84 |
+
causal,
|
85 |
+
window_size[0],
|
86 |
+
window_size[1],
|
87 |
+
attention_chunk,
|
88 |
+
softcap,
|
89 |
+
rotary_interleaved,
|
90 |
+
scheduler_metadata,
|
91 |
+
num_splits,
|
92 |
+
pack_gqa,
|
93 |
+
sm_margin,
|
94 |
+
)
|
95 |
+
return out, softmax_lse, *rest
|
96 |
+
|
97 |
+
|
98 |
+
def _flash_attn_backward(
|
99 |
+
dout,
|
100 |
+
q,
|
101 |
+
k,
|
102 |
+
v,
|
103 |
+
out,
|
104 |
+
softmax_lse,
|
105 |
+
cu_seqlens_q,
|
106 |
+
cu_seqlens_k,
|
107 |
+
sequed_q,
|
108 |
+
sequed_k,
|
109 |
+
max_seqlen_q,
|
110 |
+
max_seqlen_k,
|
111 |
+
dq,
|
112 |
+
dk,
|
113 |
+
dv,
|
114 |
+
softmax_scale,
|
115 |
+
causal,
|
116 |
+
window_size=(-1, -1),
|
117 |
+
softcap=0.0,
|
118 |
+
deterministic=False,
|
119 |
+
sm_margin=0,
|
120 |
+
):
|
121 |
+
# dq, dk, dv are allocated by us so they should already be contiguous
|
122 |
+
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
|
123 |
+
dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
|
124 |
+
dout,
|
125 |
+
q,
|
126 |
+
k,
|
127 |
+
v,
|
128 |
+
out,
|
129 |
+
softmax_lse,
|
130 |
+
dq,
|
131 |
+
dk,
|
132 |
+
dv,
|
133 |
+
cu_seqlens_q,
|
134 |
+
cu_seqlens_k,
|
135 |
+
sequed_q,
|
136 |
+
sequed_k,
|
137 |
+
max_seqlen_q,
|
138 |
+
max_seqlen_k,
|
139 |
+
softmax_scale,
|
140 |
+
causal,
|
141 |
+
window_size[0],
|
142 |
+
window_size[1],
|
143 |
+
softcap,
|
144 |
+
deterministic,
|
145 |
+
sm_margin,
|
146 |
+
)
|
147 |
+
return dq, dk, dv, softmax_d
|
148 |
+
|
149 |
+
|
150 |
+
class FlashAttnQKVPackedFunc(torch.autograd.Function):
|
151 |
+
@staticmethod
|
152 |
+
def forward(
|
153 |
+
ctx,
|
154 |
+
qkv,
|
155 |
+
softmax_scale,
|
156 |
+
causal,
|
157 |
+
q_descale=None, k_descale=None, v_descale=None,
|
158 |
+
window_size=(-1, -1),
|
159 |
+
attention_chunk=0,
|
160 |
+
softcap=0.0,
|
161 |
+
deterministic=False,
|
162 |
+
num_heads_q=None,
|
163 |
+
sm_margin=0,
|
164 |
+
):
|
165 |
+
if softmax_scale is None:
|
166 |
+
softmax_scale = qkv.shape[-1] ** (-0.5)
|
167 |
+
if qkv.dim() == 5:
|
168 |
+
assert qkv.shape[-3] == 3
|
169 |
+
q, k, v = qkv.unbind(dim=-3)
|
170 |
+
else:
|
171 |
+
assert qkv.dim() == 4
|
172 |
+
assert num_heads_q is not None
|
173 |
+
num_heads_k = (qkv.shape[2] - num_heads_q) // 2
|
174 |
+
assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
|
175 |
+
q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
|
176 |
+
out, softmax_lse, *rest = _flash_attn_forward(
|
177 |
+
q,
|
178 |
+
k,
|
179 |
+
v,
|
180 |
+
None, None, # k_new, v_new
|
181 |
+
None, # qv
|
182 |
+
None, # out
|
183 |
+
None, None, None, # cu_seqlens_q/k/k_new
|
184 |
+
None, None, # seqused_q/k
|
185 |
+
None, None, # max_seqlen_q/k
|
186 |
+
None, None, None, # page_table, kv_batch_idx, leftpad_k,
|
187 |
+
None, None, None, # rotary_cos/sin, seqlens_rotary
|
188 |
+
q_descale, k_descale, v_descale,
|
189 |
+
softmax_scale,
|
190 |
+
causal=causal,
|
191 |
+
window_size=window_size,
|
192 |
+
attention_chunk=attention_chunk,
|
193 |
+
softcap=softcap,
|
194 |
+
sm_margin=sm_margin,
|
195 |
+
)
|
196 |
+
# ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
|
197 |
+
ctx.save_for_backward(q, k, v, out, softmax_lse)
|
198 |
+
ctx.softmax_scale = softmax_scale
|
199 |
+
ctx.causal = causal
|
200 |
+
ctx.window_size = window_size
|
201 |
+
ctx.attention_chunk = attention_chunk
|
202 |
+
ctx.softcap = softcap
|
203 |
+
ctx.deterministic = deterministic
|
204 |
+
ctx.ndim = qkv.dim()
|
205 |
+
ctx.sm_margin = sm_margin
|
206 |
+
# return out, softmax_lse
|
207 |
+
return out
|
208 |
+
|
209 |
+
@staticmethod
|
210 |
+
def backward(ctx, dout, *args):
|
211 |
+
q, k, v, out, softmax_lse = ctx.saved_tensors
|
212 |
+
assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
|
213 |
+
if ctx.ndim == 5:
|
214 |
+
qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
|
215 |
+
dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
|
216 |
+
dq, dk, dv = dqkv.unbind(dim=-3)
|
217 |
+
else:
|
218 |
+
num_heads_q = q.shape[2]
|
219 |
+
num_heads_k = k.shape[2]
|
220 |
+
qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
|
221 |
+
dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
|
222 |
+
dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
|
223 |
+
_flash_attn_backward(
|
224 |
+
dout,
|
225 |
+
q,
|
226 |
+
k,
|
227 |
+
v,
|
228 |
+
out,
|
229 |
+
softmax_lse,
|
230 |
+
None, None, # cu_seqlens_q, cu_seqlens_k,
|
231 |
+
None, None, # sequed_q, sequed_k,
|
232 |
+
None, None, # max_seqlen_q, max_seqlen_k,
|
233 |
+
dq,
|
234 |
+
dk,
|
235 |
+
dv,
|
236 |
+
ctx.softmax_scale,
|
237 |
+
ctx.causal,
|
238 |
+
ctx.window_size,
|
239 |
+
ctx.softcap,
|
240 |
+
ctx.deterministic,
|
241 |
+
ctx.sm_margin,
|
242 |
+
)
|
243 |
+
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
|
244 |
+
return dqkv, None, None, None, None, None, None, None, None, None, None, None
|
245 |
+
|
246 |
+
|
247 |
+
class FlashAttnFunc(torch.autograd.Function):
|
248 |
+
|
249 |
+
@staticmethod
|
250 |
+
def forward(
|
251 |
+
ctx,
|
252 |
+
q,
|
253 |
+
k,
|
254 |
+
v,
|
255 |
+
softmax_scale,
|
256 |
+
causal,
|
257 |
+
qv=None,
|
258 |
+
q_descale=None, k_descale=None, v_descale=None,
|
259 |
+
window_size=(-1, -1),
|
260 |
+
attention_chunk=0,
|
261 |
+
softcap=0.0,
|
262 |
+
num_splits=1,
|
263 |
+
pack_gqa=None,
|
264 |
+
deterministic=False,
|
265 |
+
sm_margin=0,
|
266 |
+
):
|
267 |
+
if softmax_scale is None:
|
268 |
+
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
|
269 |
+
# out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
|
270 |
+
out, softmax_lse, *rest = _flash_attn_forward(
|
271 |
+
q,
|
272 |
+
k,
|
273 |
+
v,
|
274 |
+
None, None, # k_new, v_new
|
275 |
+
qv, # qv
|
276 |
+
None, # out
|
277 |
+
None, None, None, # cu_seqlens_q/k/k_new
|
278 |
+
None, None, # seqused_q/k
|
279 |
+
None, None, # max_seqlen_q/k
|
280 |
+
None, None, None, # page_table, kv_batch_idx, leftpad_k,
|
281 |
+
None, None, None, # rotary_cos/sin, seqlens_rotary
|
282 |
+
q_descale, k_descale, v_descale,
|
283 |
+
softmax_scale,
|
284 |
+
causal=causal,
|
285 |
+
window_size=window_size,
|
286 |
+
attention_chunk=attention_chunk,
|
287 |
+
softcap=softcap,
|
288 |
+
num_splits=num_splits,
|
289 |
+
pack_gqa=pack_gqa,
|
290 |
+
sm_margin=sm_margin,
|
291 |
+
)
|
292 |
+
# ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
|
293 |
+
ctx.save_for_backward(q, k, v, out, softmax_lse)
|
294 |
+
ctx.softmax_scale = softmax_scale
|
295 |
+
ctx.causal = causal
|
296 |
+
ctx.window_size = window_size
|
297 |
+
ctx.attention_chunk = attention_chunk
|
298 |
+
ctx.softcap = softcap
|
299 |
+
ctx.deterministic = deterministic
|
300 |
+
ctx.sm_margin = sm_margin
|
301 |
+
return out, softmax_lse
|
302 |
+
|
303 |
+
@staticmethod
|
304 |
+
def backward(ctx, dout, *args):
|
305 |
+
q, k, v, out, softmax_lse = ctx.saved_tensors
|
306 |
+
assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
|
307 |
+
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
308 |
+
_flash_attn_backward(
|
309 |
+
dout,
|
310 |
+
q,
|
311 |
+
k,
|
312 |
+
v,
|
313 |
+
out,
|
314 |
+
softmax_lse,
|
315 |
+
None, None, # cu_seqlens_q, cu_seqlens_k,
|
316 |
+
None, None, # sequed_q, sequed_k,
|
317 |
+
None, None, # max_seqlen_q, max_seqlen_k,
|
318 |
+
dq,
|
319 |
+
dk,
|
320 |
+
dv,
|
321 |
+
ctx.softmax_scale,
|
322 |
+
ctx.causal,
|
323 |
+
ctx.window_size,
|
324 |
+
ctx.softcap,
|
325 |
+
ctx.deterministic,
|
326 |
+
ctx.sm_margin,
|
327 |
+
)
|
328 |
+
dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
|
329 |
+
dk = dk[..., : k.shape[-1]]
|
330 |
+
dv = dv[..., : v.shape[-1]]
|
331 |
+
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
|
332 |
+
|
333 |
+
|
334 |
+
class FlashAttnVarlenFunc(torch.autograd.Function):
|
335 |
+
|
336 |
+
@staticmethod
|
337 |
+
def forward(
|
338 |
+
ctx,
|
339 |
+
q,
|
340 |
+
k,
|
341 |
+
v,
|
342 |
+
cu_seqlens_q,
|
343 |
+
cu_seqlens_k,
|
344 |
+
seqused_q,
|
345 |
+
seqused_k,
|
346 |
+
max_seqlen_q,
|
347 |
+
max_seqlen_k,
|
348 |
+
softmax_scale,
|
349 |
+
causal,
|
350 |
+
qv=None,
|
351 |
+
q_descale=None, k_descale=None, v_descale=None,
|
352 |
+
window_size=(-1, -1),
|
353 |
+
attention_chunk=0,
|
354 |
+
softcap=0.0,
|
355 |
+
num_splits=1,
|
356 |
+
pack_gqa=None,
|
357 |
+
deterministic=False,
|
358 |
+
sm_margin=0,
|
359 |
+
):
|
360 |
+
if softmax_scale is None:
|
361 |
+
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
|
362 |
+
# out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
|
363 |
+
out, softmax_lse, *rest = _flash_attn_forward(
|
364 |
+
q,
|
365 |
+
k,
|
366 |
+
v,
|
367 |
+
None, None, # k_new, v_new
|
368 |
+
qv, # qv
|
369 |
+
None, # out
|
370 |
+
cu_seqlens_q,
|
371 |
+
cu_seqlens_k,
|
372 |
+
None, # cu_seqlens_k_new
|
373 |
+
seqused_q,
|
374 |
+
seqused_k,
|
375 |
+
max_seqlen_q,
|
376 |
+
max_seqlen_k,
|
377 |
+
None, None, None, # page_table, kv_batch_idx, leftpad_k,
|
378 |
+
None, None, None, # rotary_cos/sin, seqlens_rotary
|
379 |
+
q_descale, k_descale, v_descale,
|
380 |
+
softmax_scale,
|
381 |
+
causal=causal,
|
382 |
+
window_size=window_size,
|
383 |
+
attention_chunk=attention_chunk,
|
384 |
+
softcap=softcap,
|
385 |
+
num_splits=num_splits,
|
386 |
+
pack_gqa=pack_gqa,
|
387 |
+
sm_margin=sm_margin,
|
388 |
+
)
|
389 |
+
# ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
|
390 |
+
ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
|
391 |
+
ctx.max_seqlen_q = max_seqlen_q
|
392 |
+
ctx.max_seqlen_k = max_seqlen_k
|
393 |
+
ctx.softmax_scale = softmax_scale
|
394 |
+
ctx.causal = causal
|
395 |
+
ctx.window_size = window_size
|
396 |
+
ctx.attention_chunk = attention_chunk
|
397 |
+
ctx.softcap = softcap
|
398 |
+
ctx.deterministic = deterministic
|
399 |
+
ctx.sm_margin = sm_margin
|
400 |
+
return out, softmax_lse
|
401 |
+
|
402 |
+
@staticmethod
|
403 |
+
def backward(ctx, dout, *args):
|
404 |
+
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
|
405 |
+
assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
|
406 |
+
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
407 |
+
_flash_attn_backward(
|
408 |
+
dout,
|
409 |
+
q,
|
410 |
+
k,
|
411 |
+
v,
|
412 |
+
out,
|
413 |
+
softmax_lse,
|
414 |
+
cu_seqlens_q,
|
415 |
+
cu_seqlens_k,
|
416 |
+
seqused_q,
|
417 |
+
seqused_k,
|
418 |
+
ctx.max_seqlen_q,
|
419 |
+
ctx.max_seqlen_k,
|
420 |
+
dq,
|
421 |
+
dk,
|
422 |
+
dv,
|
423 |
+
ctx.softmax_scale,
|
424 |
+
ctx.causal,
|
425 |
+
ctx.window_size,
|
426 |
+
ctx.softcap,
|
427 |
+
ctx.deterministic,
|
428 |
+
ctx.sm_margin,
|
429 |
+
)
|
430 |
+
dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
|
431 |
+
dk = dk[..., : k.shape[-1]]
|
432 |
+
dv = dv[..., : v.shape[-1]]
|
433 |
+
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
|
434 |
+
|
435 |
+
|
436 |
+
def flash_attn_qkvpacked_func(
|
437 |
+
qkv,
|
438 |
+
softmax_scale=None,
|
439 |
+
causal=False,
|
440 |
+
q_descale=None, k_descale=None, v_descale=None,
|
441 |
+
window_size=(-1, -1),
|
442 |
+
attention_chunk=0,
|
443 |
+
softcap=0.0,
|
444 |
+
deterministic=False,
|
445 |
+
num_heads_q=None,
|
446 |
+
sm_margin=0,
|
447 |
+
):
|
448 |
+
"""dropout_p should be set to 0.0 during evaluation
|
449 |
+
If Q, K, V are already stacked into 1 tensor, this function will be faster than
|
450 |
+
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
|
451 |
+
of the gradients of Q, K, V.
|
452 |
+
For multi-query and grouped-query attention (MQA/GQA), please see
|
453 |
+
flash_attn_kvpacked_func and flash_attn_func.
|
454 |
+
|
455 |
+
If window_size != (-1, -1), implements sliding window local attention. Query at position i
|
456 |
+
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
|
457 |
+
|
458 |
+
Arguments:
|
459 |
+
qkv: (batch_size, seqlen, 3, nheads, headdim)
|
460 |
+
dropout_p: float. Dropout probability.
|
461 |
+
softmax_scale: float. The scaling of QK^T before applying softmax.
|
462 |
+
Default to 1 / sqrt(headdim).
|
463 |
+
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
464 |
+
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
|
465 |
+
softcap: float. Anything > 0 activates softcapping attention.
|
466 |
+
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
|
467 |
+
the attention score of query i and key j.
|
468 |
+
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
|
469 |
+
which is slightly slower and uses more memory. The forward pass is always deterministic.
|
470 |
+
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
471 |
+
testing only. The returned probabilities are not guaranteed to be correct
|
472 |
+
(they might not have the right scaling).
|
473 |
+
Return:
|
474 |
+
out: (batch_size, seqlen, nheads, headdim).
|
475 |
+
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
|
476 |
+
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
|
477 |
+
normalization factor).
|
478 |
+
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
|
479 |
+
The output of softmax (possibly with different scaling). It also encodes the dropout
|
480 |
+
pattern (negative means that location was dropped, nonnegative means it was kept).
|
481 |
+
"""
|
482 |
+
return FlashAttnQKVPackedFunc.apply(
|
483 |
+
qkv,
|
484 |
+
softmax_scale,
|
485 |
+
causal,
|
486 |
+
q_descale, k_descale, v_descale,
|
487 |
+
window_size,
|
488 |
+
attention_chunk,
|
489 |
+
softcap,
|
490 |
+
deterministic,
|
491 |
+
num_heads_q,
|
492 |
+
sm_margin,
|
493 |
+
)
|
494 |
+
|
495 |
+
|
496 |
+
def flash_attn_func(
|
497 |
+
q,
|
498 |
+
k,
|
499 |
+
v,
|
500 |
+
softmax_scale=None,
|
501 |
+
causal=False,
|
502 |
+
qv=None,
|
503 |
+
q_descale=None, k_descale=None, v_descale=None,
|
504 |
+
window_size=(-1, -1),
|
505 |
+
attention_chunk=0,
|
506 |
+
softcap=0.0,
|
507 |
+
num_splits=1,
|
508 |
+
pack_gqa=None,
|
509 |
+
deterministic=False,
|
510 |
+
sm_margin=0,
|
511 |
+
):
|
512 |
+
"""dropout_p should be set to 0.0 during evaluation
|
513 |
+
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
|
514 |
+
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
|
515 |
+
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
|
516 |
+
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
|
517 |
+
|
518 |
+
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
|
519 |
+
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
|
520 |
+
1 1 1 1 0
|
521 |
+
1 1 1 1 1
|
522 |
+
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
|
523 |
+
0 0
|
524 |
+
0 0
|
525 |
+
0 0
|
526 |
+
1 0
|
527 |
+
1 1
|
528 |
+
If the row of the mask is all zero, the output will be zero.
|
529 |
+
|
530 |
+
If window_size != (-1, -1), implements sliding window local attention. Query at position i
|
531 |
+
will only attend to keys between
|
532 |
+
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
|
533 |
+
|
534 |
+
Arguments:
|
535 |
+
q: (batch_size, seqlen, nheads, headdim)
|
536 |
+
k: (batch_size, seqlen, nheads_k, headdim)
|
537 |
+
v: (batch_size, seqlen, nheads_k, headdim)
|
538 |
+
dropout_p: float. Dropout probability.
|
539 |
+
softmax_scale: float. The scaling of QK^T before applying softmax.
|
540 |
+
Default to 1 / sqrt(headdim).
|
541 |
+
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
542 |
+
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
|
543 |
+
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
|
544 |
+
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
|
545 |
+
is added to the attention score of query i and key j.
|
546 |
+
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
|
547 |
+
which is slightly slower and uses more memory. The forward pass is always deterministic.
|
548 |
+
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
549 |
+
testing only. The returned probabilities are not guaranteed to be correct
|
550 |
+
(they might not have the right scaling).
|
551 |
+
Return:
|
552 |
+
out: (batch_size, seqlen, nheads, headdim).
|
553 |
+
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
|
554 |
+
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
|
555 |
+
normalization factor).
|
556 |
+
"""
|
557 |
+
return FlashAttnFunc.apply(
|
558 |
+
q,
|
559 |
+
k,
|
560 |
+
v,
|
561 |
+
softmax_scale,
|
562 |
+
causal,
|
563 |
+
qv,
|
564 |
+
q_descale, k_descale, v_descale,
|
565 |
+
window_size,
|
566 |
+
attention_chunk,
|
567 |
+
softcap,
|
568 |
+
num_splits,
|
569 |
+
pack_gqa,
|
570 |
+
deterministic,
|
571 |
+
sm_margin,
|
572 |
+
)
|
573 |
+
|
574 |
+
|
575 |
+
def flash_attn_varlen_func(
|
576 |
+
q,
|
577 |
+
k,
|
578 |
+
v,
|
579 |
+
cu_seqlens_q,
|
580 |
+
cu_seqlens_k,
|
581 |
+
max_seqlen_q,
|
582 |
+
max_seqlen_k,
|
583 |
+
seqused_q=None,
|
584 |
+
seqused_k=None,
|
585 |
+
softmax_scale=None,
|
586 |
+
causal=False,
|
587 |
+
qv=None,
|
588 |
+
q_descale=None, k_descale=None, v_descale=None,
|
589 |
+
window_size=(-1, -1),
|
590 |
+
attention_chunk=0,
|
591 |
+
softcap=0.0,
|
592 |
+
num_splits=1,
|
593 |
+
pack_gqa=None,
|
594 |
+
deterministic=False,
|
595 |
+
sm_margin=0,
|
596 |
+
):
|
597 |
+
return FlashAttnVarlenFunc.apply(
|
598 |
+
q,
|
599 |
+
k,
|
600 |
+
v,
|
601 |
+
cu_seqlens_q,
|
602 |
+
cu_seqlens_k,
|
603 |
+
seqused_q,
|
604 |
+
seqused_k,
|
605 |
+
max_seqlen_q,
|
606 |
+
max_seqlen_k,
|
607 |
+
softmax_scale,
|
608 |
+
causal,
|
609 |
+
qv,
|
610 |
+
q_descale, k_descale, v_descale,
|
611 |
+
window_size,
|
612 |
+
attention_chunk,
|
613 |
+
softcap,
|
614 |
+
num_splits,
|
615 |
+
pack_gqa,
|
616 |
+
deterministic,
|
617 |
+
sm_margin,
|
618 |
+
)
|
619 |
+
|
620 |
+
|
621 |
+
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
|
622 |
+
return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
|
623 |
+
|
624 |
+
|
625 |
+
def flash_attn_with_kvcache(
|
626 |
+
q,
|
627 |
+
k_cache,
|
628 |
+
v_cache,
|
629 |
+
k=None,
|
630 |
+
v=None,
|
631 |
+
qv=None,
|
632 |
+
rotary_cos=None,
|
633 |
+
rotary_sin=None,
|
634 |
+
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
|
635 |
+
cache_batch_idx: Optional[torch.Tensor] = None,
|
636 |
+
cache_leftpad: Optional[torch.Tensor] = None,
|
637 |
+
page_table: Optional[torch.Tensor] = None,
|
638 |
+
cu_seqlens_q: Optional[torch.Tensor] = None,
|
639 |
+
cu_seqlens_k_new: Optional[torch.Tensor] = None,
|
640 |
+
max_seqlen_q: Optional[int] = None,
|
641 |
+
rotary_seqlens: Optional[torch.Tensor] = None,
|
642 |
+
q_descale: Optional[torch.Tensor] = None,
|
643 |
+
k_descale: Optional[torch.Tensor] = None,
|
644 |
+
v_descale: Optional[torch.Tensor] = None,
|
645 |
+
softmax_scale=None,
|
646 |
+
causal=False,
|
647 |
+
window_size=(-1, -1), # -1 means infinite context window
|
648 |
+
attention_chunk=0,
|
649 |
+
softcap=0.0, # 0.0 means deactivated
|
650 |
+
rotary_interleaved=True,
|
651 |
+
scheduler_metadata=None,
|
652 |
+
num_splits=0, # Can be tuned for speed
|
653 |
+
pack_gqa=None, # Can be tuned for speed
|
654 |
+
sm_margin=0, # Can be tuned if some SMs are used for communication
|
655 |
+
return_softmax_lse=False,
|
656 |
+
):
|
657 |
+
"""
|
658 |
+
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
|
659 |
+
k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
|
660 |
+
the previous step, and update them with the new keys/values from the current step, and do
|
661 |
+
attention with the updated cache, all in 1 kernel.
|
662 |
+
|
663 |
+
If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
|
664 |
+
For example, the KV cache could be pre-allocated with the max sequence length, and you can use
|
665 |
+
cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
|
666 |
+
|
667 |
+
Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
|
668 |
+
rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
|
669 |
+
If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
|
670 |
+
and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
|
671 |
+
If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
|
672 |
+
indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
|
673 |
+
|
674 |
+
See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
|
675 |
+
|
676 |
+
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
|
677 |
+
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
|
678 |
+
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
|
679 |
+
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
|
680 |
+
|
681 |
+
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
|
682 |
+
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
|
683 |
+
1 1 1 1 0
|
684 |
+
1 1 1 1 1
|
685 |
+
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
|
686 |
+
0 0
|
687 |
+
0 0
|
688 |
+
0 0
|
689 |
+
1 0
|
690 |
+
1 1
|
691 |
+
If the row of the mask is all zero, the output will be zero.
|
692 |
+
|
693 |
+
If window_size != (-1, -1), implements sliding window local attention. Query at position i
|
694 |
+
will only attend to keys between
|
695 |
+
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
|
696 |
+
|
697 |
+
Note: Does not support backward pass.
|
698 |
+
|
699 |
+
Arguments:
|
700 |
+
q: (batch_size, seqlen, nheads, headdim)
|
701 |
+
k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
|
702 |
+
or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
|
703 |
+
page_block_size must be a multiple of 256.
|
704 |
+
v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
|
705 |
+
or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
|
706 |
+
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
|
707 |
+
k with k_cache, starting at the indices specified by cache_seqlens.
|
708 |
+
v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
|
709 |
+
qv [optional]: (batch_size, seqlen, nheads, headdim_v)
|
710 |
+
rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
|
711 |
+
to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
|
712 |
+
rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
|
713 |
+
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
|
714 |
+
KV cache.
|
715 |
+
cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
|
716 |
+
If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
|
717 |
+
If the indices are not distinct, and k and v are provided, the values updated in the cache
|
718 |
+
might come from any of the duplicate indices.
|
719 |
+
cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
|
720 |
+
page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
|
721 |
+
softmax_scale: float. The scaling of QK^T before applying softmax.
|
722 |
+
Default to 1 / sqrt(headdim).
|
723 |
+
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
724 |
+
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
|
725 |
+
softcap: float. Anything > 0 activates softcapping attention.
|
726 |
+
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
|
727 |
+
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
|
728 |
+
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
|
729 |
+
(i.e. GPT-NeoX style).
|
730 |
+
num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
|
731 |
+
If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
|
732 |
+
to automatically determine the number of splits.
|
733 |
+
Don't change this unless you know what you are doing.
|
734 |
+
return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
|
735 |
+
|
736 |
+
Return:
|
737 |
+
out: (batch_size, seqlen, nheads, headdim).
|
738 |
+
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
|
739 |
+
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
|
740 |
+
normalization factor).
|
741 |
+
"""
|
742 |
+
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
|
743 |
+
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
|
744 |
+
if softmax_scale is None:
|
745 |
+
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
|
746 |
+
if cache_seqlens is not None and isinstance(cache_seqlens, int):
|
747 |
+
cache_seqlens = torch.full(
|
748 |
+
(k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
|
749 |
+
)
|
750 |
+
cache_seqlens = maybe_contiguous(cache_seqlens)
|
751 |
+
out, softmax_lse, *rest = _flash_attn_forward(
|
752 |
+
q,
|
753 |
+
k_cache,
|
754 |
+
v_cache,
|
755 |
+
k,
|
756 |
+
v,
|
757 |
+
qv,
|
758 |
+
None, # out
|
759 |
+
cu_seqlens_q,
|
760 |
+
None, # cu_seqlens_k
|
761 |
+
cu_seqlens_k_new,
|
762 |
+
None, # seqused_q
|
763 |
+
cache_seqlens,
|
764 |
+
max_seqlen_q,
|
765 |
+
None, # max_seqlen_k
|
766 |
+
page_table,
|
767 |
+
cache_batch_idx,
|
768 |
+
cache_leftpad,
|
769 |
+
rotary_cos,
|
770 |
+
rotary_sin,
|
771 |
+
rotary_seqlens,
|
772 |
+
q_descale, k_descale, v_descale,
|
773 |
+
softmax_scale,
|
774 |
+
causal=causal,
|
775 |
+
window_size=window_size,
|
776 |
+
attention_chunk=attention_chunk,
|
777 |
+
softcap=softcap,
|
778 |
+
rotary_interleaved=rotary_interleaved,
|
779 |
+
scheduler_metadata=scheduler_metadata,
|
780 |
+
num_splits=num_splits,
|
781 |
+
pack_gqa=pack_gqa,
|
782 |
+
sm_margin=sm_margin,
|
783 |
+
)
|
784 |
+
# return (out, softmax_lse) if return_softmax_lse else out
|
785 |
+
return (out, softmax_lse, *rest) if return_softmax_lse else out
|
786 |
+
|
787 |
+
|
788 |
+
def get_scheduler_metadata(
|
789 |
+
batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
|
790 |
+
cache_seqlens: torch.Tensor,
|
791 |
+
qkv_dtype=torch.bfloat16,
|
792 |
+
headdim_v=None,
|
793 |
+
cu_seqlens_q: Optional[torch.Tensor] = None,
|
794 |
+
cu_seqlens_k_new: Optional[torch.Tensor] = None,
|
795 |
+
cache_leftpad: Optional[torch.Tensor] = None,
|
796 |
+
page_size: Optional[int] = None,
|
797 |
+
max_seqlen_k_new=0,
|
798 |
+
causal=False,
|
799 |
+
window_size=(-1, -1), # -1 means infinite context window
|
800 |
+
attention_chunk=0,
|
801 |
+
has_softcap=False,
|
802 |
+
num_splits=0, # Can be tuned for speed
|
803 |
+
pack_gqa=None, # Can be tuned for speed
|
804 |
+
sm_margin=0, # Can be tuned if some SMs are used for communication
|
805 |
+
):
|
806 |
+
cache_seqlens = maybe_contiguous(cache_seqlens)
|
807 |
+
if headdim_v is None:
|
808 |
+
headdim_v = headdim
|
809 |
+
scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
|
810 |
+
batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
|
811 |
+
qkv_dtype,
|
812 |
+
cache_seqlens,
|
813 |
+
cu_seqlens_q,
|
814 |
+
None, # cu_seqlens_k
|
815 |
+
cu_seqlens_k_new,
|
816 |
+
None, # seqused_q
|
817 |
+
cache_leftpad,
|
818 |
+
page_size,
|
819 |
+
max_seqlen_k_new,
|
820 |
+
causal,
|
821 |
+
window_size[0], window_size[1],
|
822 |
+
attention_chunk,
|
823 |
+
has_softcap,
|
824 |
+
num_splits,
|
825 |
+
pack_gqa,
|
826 |
+
sm_margin,
|
827 |
+
)
|
828 |
+
return scheduler_metadata
|
build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/__init__.py
ADDED
@@ -0,0 +1,17 @@
from .flash_attn_interface import (
    flash_attn_combine,
    flash_attn_func,
    flash_attn_qkvpacked_func,
    flash_attn_varlen_func,
    flash_attn_with_kvcache,
    get_scheduler_metadata,
)

__all__ = [
    "flash_attn_combine",
    "flash_attn_func",
    "flash_attn_qkvpacked_func",
    "flash_attn_varlen_func",
    "flash_attn_with_kvcache",
    "get_scheduler_metadata",
]
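A variable-length usage sketch for the re-exported `flash_attn_varlen_func`; the packed `(total_tokens, nheads, headdim)` layout and int32 `cu_seqlens` follow the usual flash-attn convention and are assumptions here, not something this diff documents:

```python
import torch
from flash_attn3 import flash_attn_varlen_func

nheads, headdim = 16, 128
seqlens = [128, 256, 64]                 # three sequences packed into one batch
total = sum(seqlens)                     # 448 tokens in the packed layout
cu_seqlens = torch.tensor([0, 128, 384, 448], dtype=torch.int32, device="cuda")

q = torch.randn(total, nheads, headdim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(total, nheads, headdim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(total, nheads, headdim, dtype=torch.bfloat16, device="cuda")

out, softmax_lse = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
    causal=True,
)
```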
build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c0302224ac29ba4773d926d4cb16c01c45a374c6dd61286aae1f423f2bf495ea
size 838459544
build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/_ops.py
ADDED
@@ -0,0 +1,9 @@
import torch
from . import _flash_attn3_2e75662
ops = torch.ops._flash_attn3_2e75662

def add_op_namespace_prefix(op_name: str):
    """
    Prefix op by namespace.
    """
    return f"_flash_attn3_2e75662::{op_name}"
build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/flash_attn_interface.py
ADDED
@@ -0,0 +1,828 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, Tri Dao.
|
2 |
+
|
3 |
+
from typing import Optional, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from ._ops import ops as flash_attn_3_cuda
|
9 |
+
|
10 |
+
def maybe_contiguous(x):
|
11 |
+
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
|
12 |
+
|
13 |
+
|
14 |
+
def _flash_attn_forward(
|
15 |
+
q,
|
16 |
+
k,
|
17 |
+
v,
|
18 |
+
k_new,
|
19 |
+
v_new,
|
20 |
+
qv,
|
21 |
+
out,
|
22 |
+
cu_seqlens_q,
|
23 |
+
cu_seqlens_k,
|
24 |
+
cu_seqlens_k_new,
|
25 |
+
seqused_q,
|
26 |
+
seqused_k,
|
27 |
+
max_seqlen_q,
|
28 |
+
max_seqlen_k,
|
29 |
+
page_table,
|
30 |
+
kv_batch_idx,
|
31 |
+
leftpad_k,
|
32 |
+
rotary_cos,
|
33 |
+
rotary_sin,
|
34 |
+
seqlens_rotary,
|
35 |
+
q_descale,
|
36 |
+
k_descale,
|
37 |
+
v_descale,
|
38 |
+
softmax_scale,
|
39 |
+
causal,
|
40 |
+
window_size=(-1, -1),
|
41 |
+
attention_chunk=0,
|
42 |
+
softcap=0.0,
|
43 |
+
rotary_interleaved=True,
|
44 |
+
scheduler_metadata=None,
|
45 |
+
num_splits=1,
|
46 |
+
pack_gqa=None,
|
47 |
+
sm_margin=0):
|
48 |
+
q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
|
49 |
+
v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
|
50 |
+
cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
|
51 |
+
maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
|
52 |
+
]
|
53 |
+
seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
|
54 |
+
page_table, kv_batch_idx, leftpad_k = [
|
55 |
+
maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
|
56 |
+
]
|
57 |
+
rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
|
58 |
+
seqlens_rotary = maybe_contiguous(seqlens_rotary)
|
59 |
+
out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
|
60 |
+
q,
|
61 |
+
k,
|
62 |
+
v,
|
63 |
+
k_new,
|
64 |
+
v_new,
|
65 |
+
qv,
|
66 |
+
out,
|
67 |
+
cu_seqlens_q,
|
68 |
+
cu_seqlens_k,
|
69 |
+
cu_seqlens_k_new,
|
70 |
+
seqused_q,
|
71 |
+
seqused_k,
|
72 |
+
max_seqlen_q,
|
73 |
+
max_seqlen_k,
|
74 |
+
page_table,
|
75 |
+
kv_batch_idx,
|
76 |
+
leftpad_k,
|
77 |
+
rotary_cos,
|
78 |
+
rotary_sin,
|
79 |
+
seqlens_rotary,
|
80 |
+
q_descale,
|
81 |
+
k_descale,
|
82 |
+
v_descale,
|
83 |
+
softmax_scale,
|
84 |
+
causal,
|
85 |
+
window_size[0],
|
86 |
+
window_size[1],
|
87 |
+
attention_chunk,
|
88 |
+
softcap,
|
89 |
+
rotary_interleaved,
|
90 |
+
scheduler_metadata,
|
91 |
+
num_splits,
|
92 |
+
pack_gqa,
|
93 |
+
sm_margin,
|
94 |
+
)
|
95 |
+
return out, softmax_lse, *rest
|
96 |
+
|
97 |
+
|
98 |
+
def _flash_attn_backward(
|
99 |
+
dout,
|
100 |
+
q,
|
101 |
+
k,
|
102 |
+
v,
|
103 |
+
out,
|
104 |
+
softmax_lse,
|
105 |
+
cu_seqlens_q,
|
106 |
+
cu_seqlens_k,
|
107 |
+
sequed_q,
|
108 |
+
sequed_k,
|
109 |
+
max_seqlen_q,
|
110 |
+
max_seqlen_k,
|
111 |
+
dq,
|
112 |
+
dk,
|
113 |
+
dv,
|
114 |
+
softmax_scale,
|
115 |
+
causal,
|
116 |
+
window_size=(-1, -1),
|
117 |
+
softcap=0.0,
|
118 |
+
deterministic=False,
|
119 |
+
sm_margin=0,
|
120 |
+
):
|
121 |
+
# dq, dk, dv are allocated by us so they should already be contiguous
|
122 |
+
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
|
123 |
+
dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
|
124 |
+
dout,
|
125 |
+
q,
|
126 |
+
k,
|
127 |
+
v,
|
128 |
+
out,
|
129 |
+
softmax_lse,
|
130 |
+
dq,
|
131 |
+
dk,
|
132 |
+
dv,
|
133 |
+
cu_seqlens_q,
|
134 |
+
cu_seqlens_k,
|
135 |
+
sequed_q,
|
136 |
+
sequed_k,
|
137 |
+
max_seqlen_q,
|
138 |
+
max_seqlen_k,
|
139 |
+
softmax_scale,
|
140 |
+
causal,
|
141 |
+
window_size[0],
|
142 |
+
window_size[1],
|
143 |
+
softcap,
|
144 |
+
deterministic,
|
145 |
+
sm_margin,
|
146 |
+
)
|
147 |
+
return dq, dk, dv, softmax_d
|
148 |
+
|
149 |
+
|
150 |
+
class FlashAttnQKVPackedFunc(torch.autograd.Function):
|
151 |
+
@staticmethod
|
152 |
+
def forward(
|
153 |
+
ctx,
|
154 |
+
qkv,
|
155 |
+
softmax_scale,
|
156 |
+
causal,
|
157 |
+
q_descale=None, k_descale=None, v_descale=None,
|
158 |
+
window_size=(-1, -1),
|
159 |
+
attention_chunk=0,
|
160 |
+
softcap=0.0,
|
161 |
+
deterministic=False,
|
162 |
+
num_heads_q=None,
|
163 |
+
sm_margin=0,
|
164 |
+
):
|
165 |
+
if softmax_scale is None:
|
166 |
+
softmax_scale = qkv.shape[-1] ** (-0.5)
|
167 |
+
if qkv.dim() == 5:
|
168 |
+
assert qkv.shape[-3] == 3
|
169 |
+
q, k, v = qkv.unbind(dim=-3)
|
170 |
+
else:
|
171 |
+
assert qkv.dim() == 4
|
172 |
+
assert num_heads_q is not None
|
173 |
+
num_heads_k = (qkv.shape[2] - num_heads_q) // 2
|
174 |
+
assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
|
175 |
+
q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
|
176 |
+
out, softmax_lse, *rest = _flash_attn_forward(
|
177 |
+
q,
|
178 |
+
k,
|
179 |
+
v,
|
180 |
+
None, None, # k_new, v_new
|
181 |
+
None, # qv
|
182 |
+
None, # out
|
183 |
+
None, None, None, # cu_seqlens_q/k/k_new
|
184 |
+
None, None, # seqused_q/k
|
185 |
+
None, None, # max_seqlen_q/k
|
186 |
+
None, None, None, # page_table, kv_batch_idx, leftpad_k,
|
187 |
+
None, None, None, # rotary_cos/sin, seqlens_rotary
|
188 |
+
q_descale, k_descale, v_descale,
|
189 |
+
softmax_scale,
|
190 |
+
causal=causal,
|
191 |
+
window_size=window_size,
|
192 |
+
attention_chunk=attention_chunk,
|
193 |
+
softcap=softcap,
|
194 |
+
sm_margin=sm_margin,
|
195 |
+
)
|
196 |
+
# ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
|
197 |
+
ctx.save_for_backward(q, k, v, out, softmax_lse)
|
198 |
+
ctx.softmax_scale = softmax_scale
|
199 |
+
ctx.causal = causal
|
200 |
+
ctx.window_size = window_size
|
201 |
+
ctx.attention_chunk = attention_chunk
|
202 |
+
ctx.softcap = softcap
|
203 |
+
ctx.deterministic = deterministic
|
204 |
+
ctx.ndim = qkv.dim()
|
205 |
+
ctx.sm_margin = sm_margin
|
206 |
+
# return out, softmax_lse
|
207 |
+
return out
|
208 |
+
|
209 |
+
@staticmethod
|
210 |
+
def backward(ctx, dout, *args):
|
211 |
+
q, k, v, out, softmax_lse = ctx.saved_tensors
|
212 |
+
assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
|
213 |
+
if ctx.ndim == 5:
|
214 |
+
qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
|
215 |
+
dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
|
216 |
+
dq, dk, dv = dqkv.unbind(dim=-3)
|
217 |
+
else:
|
218 |
+
num_heads_q = q.shape[2]
|
219 |
+
num_heads_k = k.shape[2]
|
220 |
+
qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
|
221 |
+
dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
|
222 |
+
dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
|
223 |
+
_flash_attn_backward(
|
224 |
+
dout,
|
225 |
+
q,
|
226 |
+
k,
|
227 |
+
v,
|
228 |
+
out,
|
229 |
+
softmax_lse,
|
230 |
+
None, None, # cu_seqlens_q, cu_seqlens_k,
|
231 |
+
None, None, # sequed_q, sequed_k,
|
232 |
+
None, None, # max_seqlen_q, max_seqlen_k,
|
233 |
+
dq,
|
234 |
+
dk,
|
235 |
+
dv,
|
236 |
+
ctx.softmax_scale,
|
237 |
+
ctx.causal,
|
238 |
+
ctx.window_size,
|
239 |
+
ctx.softcap,
|
240 |
+
ctx.deterministic,
|
241 |
+
ctx.sm_margin,
|
242 |
+
)
|
243 |
+
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
|
244 |
+
return dqkv, None, None, None, None, None, None, None, None, None, None, None
|
245 |
+
|
246 |
+
|
247 |
+
class FlashAttnFunc(torch.autograd.Function):
|
248 |
+
|
249 |
+
@staticmethod
|
250 |
+
def forward(
|
251 |
+
ctx,
|
252 |
+
q,
|
253 |
+
k,
|
254 |
+
v,
|
255 |
+
softmax_scale,
|
256 |
+
causal,
|
257 |
+
qv=None,
|
258 |
+
q_descale=None, k_descale=None, v_descale=None,
|
259 |
+
window_size=(-1, -1),
|
260 |
+
attention_chunk=0,
|
261 |
+
softcap=0.0,
|
262 |
+
num_splits=1,
|
263 |
+
pack_gqa=None,
|
264 |
+
deterministic=False,
|
265 |
+
sm_margin=0,
|
266 |
+
):
|
267 |
+
if softmax_scale is None:
|
268 |
+
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
|
269 |
+
# out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
|
270 |
+
out, softmax_lse, *rest = _flash_attn_forward(
|
271 |
+
q,
|
272 |
+
k,
|
273 |
+
v,
|
274 |
+
None, None, # k_new, v_new
|
275 |
+
qv, # qv
|
276 |
+
None, # out
|
277 |
+
None, None, None, # cu_seqlens_q/k/k_new
|
278 |
+
None, None, # seqused_q/k
|
279 |
+
None, None, # max_seqlen_q/k
|
280 |
+
None, None, None, # page_table, kv_batch_idx, leftpad_k,
|
281 |
+
None, None, None, # rotary_cos/sin, seqlens_rotary
|
282 |
+
q_descale, k_descale, v_descale,
|
283 |
+
softmax_scale,
|
284 |
+
causal=causal,
|
285 |
+
window_size=window_size,
|
286 |
+
attention_chunk=attention_chunk,
|
287 |
+
softcap=softcap,
|
288 |
+
num_splits=num_splits,
|
289 |
+
pack_gqa=pack_gqa,
|
290 |
+
sm_margin=sm_margin,
|
291 |
+
)
|
292 |
+
# ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
|
293 |
+
ctx.save_for_backward(q, k, v, out, softmax_lse)
|
294 |
+
ctx.softmax_scale = softmax_scale
|
295 |
+
ctx.causal = causal
|
296 |
+
ctx.window_size = window_size
|
297 |
+
ctx.attention_chunk = attention_chunk
|
298 |
+
ctx.softcap = softcap
|
299 |
+
ctx.deterministic = deterministic
|
300 |
+
ctx.sm_margin = sm_margin
|
301 |
+
return out, softmax_lse
|
302 |
+
|
303 |
+
@staticmethod
|
304 |
+
def backward(ctx, dout, *args):
|
305 |
+
q, k, v, out, softmax_lse = ctx.saved_tensors
|
306 |
+
assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
|
307 |
+
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
308 |
+
_flash_attn_backward(
|
309 |
+
dout,
|
310 |
+
q,
|
311 |
+
k,
|
312 |
+
v,
|
313 |
+
out,
|
314 |
+
softmax_lse,
|
315 |
+
None, None, # cu_seqlens_q, cu_seqlens_k,
|
316 |
+
None, None, # sequed_q, sequed_k,
|
317 |
+
None, None, # max_seqlen_q, max_seqlen_k,
|
318 |
+
dq,
|
319 |
+
dk,
|
320 |
+
dv,
|
321 |
+
ctx.softmax_scale,
|
322 |
+
ctx.causal,
|
323 |
+
ctx.window_size,
|
324 |
+
ctx.softcap,
|
325 |
+
ctx.deterministic,
|
326 |
+
ctx.sm_margin,
|
327 |
+
)
|
328 |
+
dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
|
329 |
+
dk = dk[..., : k.shape[-1]]
|
330 |
+
dv = dv[..., : v.shape[-1]]
|
331 |
+
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
|
332 |
+
|
333 |
+
|
334 |
+
class FlashAttnVarlenFunc(torch.autograd.Function):
|
335 |
+
|
336 |
+
@staticmethod
|
337 |
+
def forward(
|
338 |
+
ctx,
|
339 |
+
q,
|
340 |
+
k,
|
341 |
+
v,
|
342 |
+
cu_seqlens_q,
|
343 |
+
cu_seqlens_k,
|
344 |
+
seqused_q,
|
345 |
+
seqused_k,
|
346 |
+
max_seqlen_q,
|
347 |
+
max_seqlen_k,
|
348 |
+
softmax_scale,
|
349 |
+
causal,
|
350 |
+
qv=None,
|
351 |
+
q_descale=None, k_descale=None, v_descale=None,
|
352 |
+
window_size=(-1, -1),
|
353 |
+
attention_chunk=0,
|
354 |
+
softcap=0.0,
|
355 |
+
num_splits=1,
|
356 |
+
pack_gqa=None,
|
357 |
+
deterministic=False,
|
358 |
+
sm_margin=0,
|
359 |
+
):
|
360 |
+
if softmax_scale is None:
|
361 |
+
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
|
362 |
+
# out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
|
363 |
+
out, softmax_lse, *rest = _flash_attn_forward(
|
364 |
+
q,
|
365 |
+
k,
|
366 |
+
v,
|
367 |
+
None, None, # k_new, v_new
|
368 |
+
qv, # qv
|
369 |
+
None, # out
|
370 |
+
cu_seqlens_q,
|
371 |
+
cu_seqlens_k,
|
372 |
+
None, # cu_seqlens_k_new
|
373 |
+
seqused_q,
|
374 |
+
seqused_k,
|
375 |
+
max_seqlen_q,
|
376 |
+
max_seqlen_k,
|
377 |
+
None, None, None, # page_table, kv_batch_idx, leftpad_k,
|
378 |
+
None, None, None, # rotary_cos/sin, seqlens_rotary
|
379 |
+
q_descale, k_descale, v_descale,
|
380 |
+
softmax_scale,
|
381 |
+
causal=causal,
|
382 |
+
window_size=window_size,
|
383 |
+
attention_chunk=attention_chunk,
|
384 |
+
softcap=softcap,
|
385 |
+
num_splits=num_splits,
|
386 |
+
pack_gqa=pack_gqa,
|
387 |
+
sm_margin=sm_margin,
|
388 |
+
)
|
389 |
+
# ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
|
390 |
+
ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
|
391 |
+
ctx.max_seqlen_q = max_seqlen_q
|
392 |
+
ctx.max_seqlen_k = max_seqlen_k
|
393 |
+
ctx.softmax_scale = softmax_scale
|
394 |
+
ctx.causal = causal
|
395 |
+
ctx.window_size = window_size
|
396 |
+
ctx.attention_chunk = attention_chunk
|
397 |
+
ctx.softcap = softcap
|
398 |
+
ctx.deterministic = deterministic
|
399 |
+
ctx.sm_margin = sm_margin
|
400 |
+
return out, softmax_lse
|
401 |
+
|
402 |
+
@staticmethod
|
403 |
+
def backward(ctx, dout, *args):
|
404 |
+
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
|
405 |
+
assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
|
406 |
+
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
407 |
+
_flash_attn_backward(
|
408 |
+
dout,
|
409 |
+
q,
|
410 |
+
k,
|
411 |
+
v,
|
412 |
+
out,
|
413 |
+
softmax_lse,
|
414 |
+
cu_seqlens_q,
|
415 |
+
cu_seqlens_k,
|
416 |
+
seqused_q,
|
417 |
+
seqused_k,
|
418 |
+
ctx.max_seqlen_q,
|
419 |
+
ctx.max_seqlen_k,
|
420 |
+
dq,
|
421 |
+
dk,
|
422 |
+
dv,
|
423 |
+
ctx.softmax_scale,
|
424 |
+
ctx.causal,
|
425 |
+
ctx.window_size,
|
426 |
+
ctx.softcap,
|
427 |
+
ctx.deterministic,
|
428 |
+
ctx.sm_margin,
|
429 |
+
)
|
430 |
+
dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
|
431 |
+
dk = dk[..., : k.shape[-1]]
|
432 |
+
dv = dv[..., : v.shape[-1]]
|
433 |
+
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
|
434 |
+
|
435 |
+
|
436 |
+
def flash_attn_qkvpacked_func(
|
437 |
+
qkv,
|
438 |
+
softmax_scale=None,
|
439 |
+
causal=False,
|
440 |
+
q_descale=None, k_descale=None, v_descale=None,
|
441 |
+
window_size=(-1, -1),
|
442 |
+
attention_chunk=0,
|
443 |
+
softcap=0.0,
|
444 |
+
deterministic=False,
|
445 |
+
num_heads_q=None,
|
446 |
+
sm_margin=0,
|
447 |
+
):
|
448 |
+
"""dropout_p should be set to 0.0 during evaluation
|
449 |
+
If Q, K, V are already stacked into 1 tensor, this function will be faster than
|
450 |
+
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
|
451 |
+
of the gradients of Q, K, V.
|
452 |
+
For multi-query and grouped-query attention (MQA/GQA), please see
|
453 |
+
flash_attn_kvpacked_func and flash_attn_func.
|
454 |
+
|
455 |
+
If window_size != (-1, -1), implements sliding window local attention. Query at position i
|
456 |
+
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
|
457 |
+
|
458 |
+
Arguments:
|
459 |
+
qkv: (batch_size, seqlen, 3, nheads, headdim)
|
460 |
+
dropout_p: float. Dropout probability.
|
461 |
+
softmax_scale: float. The scaling of QK^T before applying softmax.
|
462 |
+
Default to 1 / sqrt(headdim).
|
463 |
+
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
464 |
+
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
|
465 |
+
softcap: float. Anything > 0 activates softcapping attention.
|
466 |
+
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
|
467 |
+
the attention score of query i and key j.
|
468 |
+
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
|
469 |
+
which is slightly slower and uses more memory. The forward pass is always deterministic.
|
470 |
+
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
471 |
+
testing only. The returned probabilities are not guaranteed to be correct
|
472 |
+
(they might not have the right scaling).
|
473 |
+
Return:
|
474 |
+
out: (batch_size, seqlen, nheads, headdim).
|
475 |
+
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
|
476 |
+
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
|
477 |
+
normalization factor).
|
478 |
+
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
|
479 |
+
The output of softmax (possibly with different scaling). It also encodes the dropout
|
480 |
+
pattern (negative means that location was dropped, nonnegative means it was kept).
|
481 |
+
"""
|
482 |
+
return FlashAttnQKVPackedFunc.apply(
|
483 |
+
qkv,
|
484 |
+
softmax_scale,
|
485 |
+
causal,
|
486 |
+
q_descale, k_descale, v_descale,
|
487 |
+
window_size,
|
488 |
+
attention_chunk,
|
489 |
+
softcap,
|
490 |
+
deterministic,
|
491 |
+
num_heads_q,
|
492 |
+
sm_margin,
|
493 |
+
)
|
494 |
+
|
495 |
+
|
496 |
+
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
    )

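# --- Added usage sketch (editor's illustration, not part of the upstream file). A GQA call
# --- matching the docstring above: 6 query heads sharing 2 KV heads, a bottom-right aligned
# --- causal mask, and a 128-token sliding window to the left. Assumes an SM90+ GPU.
def _example_flash_attn_gqa():
    batch_size, seqlen, headdim = 2, 256, 128
    nheads_q, nheads_kv = 6, 2  # nheads_q must be divisible by nheads_kv

    def mk(nheads):
        return torch.randn(batch_size, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)

    q, k, v = mk(nheads_q), mk(nheads_kv), mk(nheads_kv)
    out = flash_attn_func(q, k, v, causal=True, window_size=(128, 0))
    if isinstance(out, tuple):  # some builds also return softmax_lse
        out = out[0]
    return out  # (batch_size, seqlen, nheads_q, headdim)
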
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
):
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
    )

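# --- Added usage sketch (editor's illustration, not part of the upstream file). Packs two
# --- variable-length sequences into a single "total tokens" dimension and drives the varlen
# --- entry point with int32 cumulative sequence lengths; the (total_tokens, nheads, headdim)
# --- layout follows the usual flash-attn varlen convention. Assumes an SM90+ GPU.
def _example_flash_attn_varlen():
    seqlens = [37, 91]  # arbitrary per-sequence lengths
    total, nheads, headdim = sum(seqlens), 8, 64

    def mk():
        return torch.randn(total, nheads, headdim, device="cuda", dtype=torch.bfloat16)

    q, k, v = mk(), mk(), mk()
    # cu_seqlens marks the sequence boundaries: [0, len_0, len_0 + len_1, ...]
    cu_seqlens = torch.tensor([0, seqlens[0], total], device="cuda", dtype=torch.int32)
    out = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seqlens),
        max_seqlen_k=max(seqlens),
        causal=True,
    )
    if isinstance(out, tuple):  # some builds also return softmax_lse
        out = out[0]
    return out  # (total, nheads, headdim)
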
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)

def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        cache_seqlens = torch.full(
            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size=window_size,
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    # return (out, softmax_lse) if return_softmax_lse else out
    return (out, softmax_lse, *rest) if return_softmax_lse else out

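# --- Added usage sketch (editor's illustration, not part of the upstream file). One decode
# --- step against a pre-allocated KV cache, as described in the docstring above: the new
# --- key/value token is written into k_cache / v_cache in place at the positions given by
# --- cache_seqlens. Assumes an SM90+ GPU; shapes are arbitrary examples.
def _example_decode_with_kvcache():
    batch_size, max_seqlen, nheads_q, nheads_kv, headdim = 2, 1024, 8, 2, 64
    device, dtype = "cuda", torch.bfloat16
    k_cache = torch.zeros(batch_size, max_seqlen, nheads_kv, headdim, device=device, dtype=dtype)
    v_cache = torch.zeros(batch_size, max_seqlen, nheads_kv, headdim, device=device, dtype=dtype)
    # Pretend 17 tokens per sequence are already cached.
    cache_seqlens = torch.full((batch_size,), 17, device=device, dtype=torch.int32)
    # One new query/key/value token per sequence for this decode step.
    q = torch.randn(batch_size, 1, nheads_q, headdim, device=device, dtype=dtype)
    k_new = torch.randn(batch_size, 1, nheads_kv, headdim, device=device, dtype=dtype)
    v_new = torch.randn(batch_size, 1, nheads_kv, headdim, device=device, dtype=dtype)
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache,
        k=k_new, v=v_new,
        cache_seqlens=cache_seqlens,
        causal=True,
    )
    return out  # (batch_size, 1, nheads_q, headdim); caches updated in place
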
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    cache_seqlens = maybe_contiguous(cache_seqlens)
    if headdim_v is None:
        headdim_v = headdim
    scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
    return scheduler_metadata
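

# --- Added usage sketch (editor's illustration, not part of the upstream file). Precomputes
# --- scheduler metadata for a decode workload and hands it to flash_attn_with_kvcache via its
# --- scheduler_metadata argument; treat this as a usage outline under the assumption that the
# --- metadata is built with the same shapes and options as the subsequent attention call.
def _example_precomputed_scheduler_metadata(q, k_cache, v_cache, cache_seqlens):
    batch_size, seqlen_q, nheads_q, headdim = q.shape
    metadata = get_scheduler_metadata(
        # max_seqlen_k: the cache capacity is used here as an upper bound.
        batch_size, seqlen_q, k_cache.shape[1], nheads_q, k_cache.shape[2], headdim,
        cache_seqlens=cache_seqlens,
        qkv_dtype=q.dtype,
        causal=True,
    )
    return flash_attn_with_kvcache(
        q, k_cache, v_cache,
        cache_seqlens=cache_seqlens,
        causal=True,
        scheduler_metadata=metadata,
    )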