danieldk HF Staff commited on
Commit
ee60723
·
1 Parent(s): 28fbd26

Build for Torch 2.8

Browse files
Files changed (21) hide show
  1. build/torch28-cxx11-cu126-x86_64-linux/vllm_flash_attn3/__init__.py +17 -0
  2. build/torch28-cxx11-cu126-x86_64-linux/vllm_flash_attn3/__pycache__/__init__.cpython-313.pyc +0 -0
  3. build/torch28-cxx11-cu126-x86_64-linux/vllm_flash_attn3/__pycache__/_ops.cpython-313.pyc +0 -0
  4. build/torch28-cxx11-cu126-x86_64-linux/vllm_flash_attn3/__pycache__/flash_attn_interface.cpython-313.pyc +0 -0
  5. build/torch28-cxx11-cu126-x86_64-linux/vllm_flash_attn3/_ops.py +9 -0
  6. build/torch28-cxx11-cu126-x86_64-linux/vllm_flash_attn3/_vllm_flash_attn3_28fbd26_dirty.abi3.so +3 -0
  7. build/torch28-cxx11-cu126-x86_64-linux/vllm_flash_attn3/flash_attn_interface.py +815 -0
  8. build/torch28-cxx11-cu128-x86_64-linux/vllm_flash_attn3/__init__.py +17 -0
  9. build/torch28-cxx11-cu128-x86_64-linux/vllm_flash_attn3/__pycache__/__init__.cpython-313.pyc +0 -0
  10. build/torch28-cxx11-cu128-x86_64-linux/vllm_flash_attn3/__pycache__/_ops.cpython-313.pyc +0 -0
  11. build/torch28-cxx11-cu128-x86_64-linux/vllm_flash_attn3/__pycache__/flash_attn_interface.cpython-313.pyc +0 -0
  12. build/torch28-cxx11-cu128-x86_64-linux/vllm_flash_attn3/_ops.py +9 -0
  13. build/torch28-cxx11-cu128-x86_64-linux/vllm_flash_attn3/_vllm_flash_attn3_28fbd26_dirty.abi3.so +3 -0
  14. build/torch28-cxx11-cu128-x86_64-linux/vllm_flash_attn3/flash_attn_interface.py +815 -0
  15. build/torch28-cxx11-cu129-x86_64-linux/vllm_flash_attn3/__init__.py +17 -0
  16. build/torch28-cxx11-cu129-x86_64-linux/vllm_flash_attn3/__pycache__/__init__.cpython-313.pyc +0 -0
  17. build/torch28-cxx11-cu129-x86_64-linux/vllm_flash_attn3/__pycache__/_ops.cpython-313.pyc +0 -0
  18. build/torch28-cxx11-cu129-x86_64-linux/vllm_flash_attn3/__pycache__/flash_attn_interface.cpython-313.pyc +0 -0
  19. build/torch28-cxx11-cu129-x86_64-linux/vllm_flash_attn3/_ops.py +9 -0
  20. build/torch28-cxx11-cu129-x86_64-linux/vllm_flash_attn3/_vllm_flash_attn3_28fbd26_dirty.abi3.so +3 -0
  21. build/torch28-cxx11-cu129-x86_64-linux/vllm_flash_attn3/flash_attn_interface.py +815 -0
build/torch28-cxx11-cu126-x86_64-linux/vllm_flash_attn3/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .flash_attn_interface import (
2
+ flash_attn_combine,
3
+ flash_attn_func,
4
+ flash_attn_qkvpacked_func,
5
+ flash_attn_varlen_func,
6
+ flash_attn_with_kvcache,
7
+ get_scheduler_metadata,
8
+ )
9
+
10
+ __all__ = [
11
+ "flash_attn_combine",
12
+ "flash_attn_func",
13
+ "flash_attn_qkvpacked_func",
14
+ "flash_attn_varlen_func",
15
+ "flash_attn_with_kvcache",
16
+ "get_scheduler_metadata",
17
+ ]
build/torch28-cxx11-cu126-x86_64-linux/vllm_flash_attn3/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (442 Bytes). View file
 
build/torch28-cxx11-cu126-x86_64-linux/vllm_flash_attn3/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (556 Bytes). View file
 
build/torch28-cxx11-cu126-x86_64-linux/vllm_flash_attn3/__pycache__/flash_attn_interface.cpython-313.pyc ADDED
Binary file (25.6 kB). View file
 
build/torch28-cxx11-cu126-x86_64-linux/vllm_flash_attn3/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _vllm_flash_attn3_28fbd26_dirty
3
+ ops = torch.ops._vllm_flash_attn3_28fbd26_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_vllm_flash_attn3_28fbd26_dirty::{op_name}"
build/torch28-cxx11-cu126-x86_64-linux/vllm_flash_attn3/_vllm_flash_attn3_28fbd26_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29c8dc9754cba50e7eeacfbff19710fedb4152119c57a0c5afa00b036480fe6f
3
+ size 915245760
build/torch28-cxx11-cu126-x86_64-linux/vllm_flash_attn3/flash_attn_interface.py ADDED
@@ -0,0 +1,815 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from ._ops import ops
11
+
12
+ # isort: on
13
+
14
+
15
+ def maybe_contiguous(x):
16
+ return x.contiguous() if x is not None and x.stride(-1) != 1 else x
17
+
18
+
19
+ def _flash_attn_forward(
20
+ q,
21
+ k,
22
+ v,
23
+ k_new,
24
+ v_new,
25
+ qv,
26
+ out,
27
+ cu_seqlens_q,
28
+ cu_seqlens_k,
29
+ cu_seqlens_k_new,
30
+ seqused_q,
31
+ seqused_k,
32
+ max_seqlen_q,
33
+ max_seqlen_k,
34
+ page_table,
35
+ kv_batch_idx,
36
+ leftpad_k,
37
+ rotary_cos,
38
+ rotary_sin,
39
+ seqlens_rotary,
40
+ q_descale,
41
+ k_descale,
42
+ v_descale,
43
+ softmax_scale,
44
+ causal,
45
+ window_size=(-1, -1),
46
+ softcap=0.0,
47
+ rotary_interleaved=True,
48
+ scheduler_metadata=None,
49
+ num_splits=1,
50
+ pack_gqa=None,
51
+ sm_margin=0,
52
+ s_aux=None):
53
+ q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
54
+ v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
55
+ cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
56
+ maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
57
+ ]
58
+ seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
59
+ page_table, kv_batch_idx, leftpad_k = [
60
+ maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
61
+ ]
62
+ rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
63
+ seqlens_rotary = maybe_contiguous(seqlens_rotary)
64
+ out, softmax_lse, *rest = ops.fwd(
65
+ q,
66
+ k,
67
+ v,
68
+ k_new,
69
+ v_new,
70
+ qv,
71
+ out,
72
+ cu_seqlens_q,
73
+ cu_seqlens_k,
74
+ cu_seqlens_k_new,
75
+ seqused_q,
76
+ seqused_k,
77
+ max_seqlen_q,
78
+ max_seqlen_k,
79
+ page_table,
80
+ kv_batch_idx,
81
+ leftpad_k,
82
+ rotary_cos,
83
+ rotary_sin,
84
+ seqlens_rotary,
85
+ q_descale,
86
+ k_descale,
87
+ v_descale,
88
+ softmax_scale,
89
+ causal,
90
+ window_size[0],
91
+ window_size[1],
92
+ softcap,
93
+ rotary_interleaved,
94
+ scheduler_metadata,
95
+ num_splits,
96
+ pack_gqa,
97
+ sm_margin,
98
+ s_aux
99
+ )
100
+ return out, softmax_lse, *rest
101
+
102
+
103
+ def _flash_attn_backward(
104
+ dout,
105
+ q,
106
+ k,
107
+ v,
108
+ out,
109
+ softmax_lse,
110
+ cu_seqlens_q,
111
+ cu_seqlens_k,
112
+ sequed_q,
113
+ sequed_k,
114
+ max_seqlen_q,
115
+ max_seqlen_k,
116
+ dq,
117
+ dk,
118
+ dv,
119
+ softmax_scale,
120
+ causal,
121
+ window_size=(-1, -1),
122
+ softcap=0.0,
123
+ deterministic=False,
124
+ sm_margin=0,
125
+ ):
126
+ # dq, dk, dv are allocated by us so they should already be contiguous
127
+ dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
128
+ dq, dk, dv, softmax_d, *rest = ops.bwd(
129
+ dout,
130
+ q,
131
+ k,
132
+ v,
133
+ out,
134
+ softmax_lse,
135
+ dq,
136
+ dk,
137
+ dv,
138
+ cu_seqlens_q,
139
+ cu_seqlens_k,
140
+ sequed_q,
141
+ sequed_k,
142
+ max_seqlen_q,
143
+ max_seqlen_k,
144
+ softmax_scale,
145
+ causal,
146
+ window_size[0],
147
+ window_size[1],
148
+ softcap,
149
+ deterministic,
150
+ sm_margin,
151
+ )
152
+ return dq, dk, dv, softmax_d
153
+
154
+
155
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
156
+ @staticmethod
157
+ def forward(
158
+ ctx,
159
+ qkv,
160
+ softmax_scale,
161
+ causal,
162
+ q_descale=None, k_descale=None, v_descale=None,
163
+ window_size=(-1, -1),
164
+ softcap=0.0,
165
+ deterministic=False,
166
+ num_heads_q=None,
167
+ ):
168
+ if softmax_scale is None:
169
+ softmax_scale = qkv.shape[-1] ** (-0.5)
170
+ if qkv.dim() == 5:
171
+ assert qkv.shape[-3] == 3
172
+ q, k, v = qkv.unbind(dim=-3)
173
+ else:
174
+ assert qkv.dim() == 4
175
+ assert num_heads_q is not None
176
+ num_heads_k = (qkv.shape[2] - num_heads_q) // 2
177
+ assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
178
+ q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
179
+ out, softmax_lse, *rest = _flash_attn_forward(
180
+ q,
181
+ k,
182
+ v,
183
+ None, None, # k_new, v_new
184
+ None, # qv
185
+ None, # out
186
+ None, None, None, # cu_seqlens_q/k/k_new
187
+ None, None, # seqused_q/k
188
+ None, None, # max_seqlen_q/k
189
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
190
+ None, None, None, # rotary_cos/sin, seqlens_rotary
191
+ q_descale, k_descale, v_descale,
192
+ softmax_scale,
193
+ causal=causal,
194
+ window_size=window_size,
195
+ softcap=softcap,
196
+ )
197
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
198
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
199
+ ctx.softmax_scale = softmax_scale
200
+ ctx.causal = causal
201
+ ctx.window_size = window_size
202
+ ctx.softcap = softcap
203
+ ctx.deterministic = deterministic
204
+ ctx.ndim = qkv.dim()
205
+ # return out, softmax_lse
206
+ return out
207
+
208
+ @staticmethod
209
+ def backward(ctx, dout, *args):
210
+ q, k, v, out, softmax_lse = ctx.saved_tensors
211
+ if ctx.ndim == 5:
212
+ qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
213
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
214
+ dq, dk, dv = dqkv.unbind(dim=-3)
215
+ else:
216
+ num_heads_q = q.shape[2]
217
+ num_heads_k = k.shape[2]
218
+ qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
219
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
220
+ dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
221
+ _flash_attn_backward(
222
+ dout,
223
+ q,
224
+ k,
225
+ v,
226
+ out,
227
+ softmax_lse,
228
+ None, None, # cu_seqlens_q, cu_seqlens_k,
229
+ None, None, # sequed_q, sequed_k,
230
+ None, None, # max_seqlen_q, max_seqlen_k,
231
+ dq,
232
+ dk,
233
+ dv,
234
+ ctx.softmax_scale,
235
+ ctx.causal,
236
+ ctx.window_size,
237
+ ctx.softcap,
238
+ ctx.deterministic,
239
+ )
240
+ dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
241
+ return dqkv, None, None, None, None, None, None, None, None, None, None
242
+
243
+
244
+ class FlashAttnFunc(torch.autograd.Function):
245
+
246
+ @staticmethod
247
+ def forward(
248
+ ctx,
249
+ q,
250
+ k,
251
+ v,
252
+ softmax_scale,
253
+ causal,
254
+ qv=None,
255
+ q_descale=None, k_descale=None, v_descale=None,
256
+ window_size=(-1, -1),
257
+ softcap=0.0,
258
+ num_splits=1,
259
+ pack_gqa=None,
260
+ deterministic=False,
261
+ sm_margin=0,
262
+ s_aux=None,
263
+ ):
264
+ if softmax_scale is None:
265
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
266
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
267
+ out, softmax_lse, *rest = _flash_attn_forward(
268
+ q,
269
+ k,
270
+ v,
271
+ None, None, # k_new, v_new
272
+ qv, # qv
273
+ None, # out
274
+ None, None, None, # cu_seqlens_q/k/k_new
275
+ None, None, # seqused_q/k
276
+ None, None, # max_seqlen_q/k
277
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
278
+ None, None, None, # rotary_cos/sin, seqlens_rotary
279
+ q_descale, k_descale, v_descale,
280
+ softmax_scale,
281
+ causal=causal,
282
+ window_size=window_size,
283
+ softcap=softcap,
284
+ num_splits=num_splits,
285
+ pack_gqa=pack_gqa,
286
+ sm_margin=sm_margin,
287
+ s_aux=s_aux,
288
+ )
289
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
290
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
291
+ ctx.softmax_scale = softmax_scale
292
+ ctx.causal = causal
293
+ ctx.window_size = window_size
294
+ ctx.softcap = softcap
295
+ ctx.deterministic = deterministic
296
+ ctx.sm_margin = sm_margin
297
+ return out, softmax_lse
298
+
299
+ @staticmethod
300
+ def backward(ctx, dout, *args):
301
+ q, k, v, out, softmax_lse = ctx.saved_tensors
302
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
303
+ _flash_attn_backward(
304
+ dout,
305
+ q,
306
+ k,
307
+ v,
308
+ out,
309
+ softmax_lse,
310
+ None, None, # cu_seqlens_q, cu_seqlens_k,
311
+ None, None, # sequed_q, sequed_k,
312
+ None, None, # max_seqlen_q, max_seqlen_k,
313
+ dq,
314
+ dk,
315
+ dv,
316
+ ctx.softmax_scale,
317
+ ctx.causal,
318
+ ctx.window_size,
319
+ ctx.softcap,
320
+ ctx.deterministic,
321
+ ctx.sm_margin,
322
+ )
323
+ dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
324
+ dk = dk[..., : dout.shape[-1]]
325
+ dv = dv[..., : dout.shape[-1]]
326
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
327
+
328
+
329
+ class FlashAttnVarlenFunc(torch.autograd.Function):
330
+
331
+ @staticmethod
332
+ def forward(
333
+ ctx,
334
+ q,
335
+ k,
336
+ v,
337
+ cu_seqlens_q,
338
+ cu_seqlens_k,
339
+ seqused_q,
340
+ seqused_k,
341
+ max_seqlen_q,
342
+ max_seqlen_k,
343
+ softmax_scale,
344
+ causal,
345
+ qv=None,
346
+ q_descale=None, k_descale=None, v_descale=None,
347
+ window_size=(-1, -1),
348
+ softcap=0.0,
349
+ num_splits=1,
350
+ pack_gqa=None,
351
+ deterministic=False,
352
+ sm_margin=0,
353
+ s_aux=None,
354
+ ):
355
+ if softmax_scale is None:
356
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
357
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
358
+ out, softmax_lse, *rest = _flash_attn_forward(
359
+ q,
360
+ k,
361
+ v,
362
+ None, None, # k_new, v_new
363
+ qv, # qv
364
+ None, # out
365
+ cu_seqlens_q,
366
+ cu_seqlens_k,
367
+ None, # cu_seqlens_k_new
368
+ seqused_q,
369
+ seqused_k,
370
+ max_seqlen_q,
371
+ max_seqlen_k,
372
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
373
+ None, None, None, # rotary_cos/sin, seqlens_rotary
374
+ q_descale, k_descale, v_descale,
375
+ softmax_scale,
376
+ causal=causal,
377
+ window_size=window_size,
378
+ softcap=softcap,
379
+ num_splits=num_splits,
380
+ pack_gqa=pack_gqa,
381
+ sm_margin=sm_margin,
382
+ s_aux=s_aux,
383
+ )
384
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
385
+ ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
386
+ ctx.max_seqlen_q = max_seqlen_q
387
+ ctx.max_seqlen_k = max_seqlen_k
388
+ ctx.softmax_scale = softmax_scale
389
+ ctx.causal = causal
390
+ ctx.window_size = window_size
391
+ ctx.softcap = softcap
392
+ ctx.deterministic = deterministic
393
+ ctx.sm_margin = sm_margin
394
+ return out, softmax_lse
395
+
396
+ @staticmethod
397
+ def backward(ctx, dout, *args):
398
+ q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
399
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
400
+ _flash_attn_backward(
401
+ dout,
402
+ q,
403
+ k,
404
+ v,
405
+ out,
406
+ softmax_lse,
407
+ cu_seqlens_q,
408
+ cu_seqlens_k,
409
+ seqused_q,
410
+ seqused_k,
411
+ ctx.max_seqlen_q,
412
+ ctx.max_seqlen_k,
413
+ dq,
414
+ dk,
415
+ dv,
416
+ ctx.softmax_scale,
417
+ ctx.causal,
418
+ ctx.window_size,
419
+ ctx.softcap,
420
+ ctx.deterministic,
421
+ ctx.sm_margin,
422
+ )
423
+ dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
424
+ dk = dk[..., : dout.shape[-1]]
425
+ dv = dv[..., : dout.shape[-1]]
426
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
427
+
428
+
429
+ def flash_attn_qkvpacked_func(
430
+ qkv,
431
+ softmax_scale=None,
432
+ causal=False,
433
+ q_descale=None, k_descale=None, v_descale=None,
434
+ window_size=(-1, -1),
435
+ softcap=0.0,
436
+ deterministic=False,
437
+ num_heads_q=None,
438
+ ):
439
+ """dropout_p should be set to 0.0 during evaluation
440
+ If Q, K, V are already stacked into 1 tensor, this function will be faster than
441
+ calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
442
+ of the gradients of Q, K, V.
443
+ For multi-query and grouped-query attention (MQA/GQA), please see
444
+ flash_attn_kvpacked_func and flash_attn_func.
445
+
446
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
447
+ will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
448
+
449
+ Arguments:
450
+ qkv: (batch_size, seqlen, 3, nheads, headdim)
451
+ dropout_p: float. Dropout probability.
452
+ softmax_scale: float. The scaling of QK^T before applying softmax.
453
+ Default to 1 / sqrt(headdim).
454
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
455
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
456
+ softcap: float. Anything > 0 activates softcapping attention.
457
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
458
+ the attention score of query i and key j.
459
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
460
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
461
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
462
+ testing only. The returned probabilities are not guaranteed to be correct
463
+ (they might not have the right scaling).
464
+ Return:
465
+ out: (batch_size, seqlen, nheads, headdim).
466
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
467
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
468
+ normalization factor).
469
+ S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
470
+ The output of softmax (possibly with different scaling). It also encodes the dropout
471
+ pattern (negative means that location was dropped, nonnegative means it was kept).
472
+ """
473
+ return FlashAttnQKVPackedFunc.apply(
474
+ qkv,
475
+ softmax_scale,
476
+ causal,
477
+ q_descale, k_descale, v_descale,
478
+ window_size,
479
+ softcap,
480
+ deterministic,
481
+ num_heads_q,
482
+ )
483
+
484
+
485
+ def flash_attn_func(
486
+ q,
487
+ k,
488
+ v,
489
+ softmax_scale=None,
490
+ causal=False,
491
+ qv=None,
492
+ q_descale=None, k_descale=None, v_descale=None,
493
+ window_size=(-1, -1),
494
+ softcap=0.0,
495
+ num_splits=1,
496
+ pack_gqa=None,
497
+ deterministic=False,
498
+ sm_margin=0,
499
+ s_aux=None,
500
+ ):
501
+ """dropout_p should be set to 0.0 during evaluation
502
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
503
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
504
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
505
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
506
+
507
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
508
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
509
+ 1 1 1 1 0
510
+ 1 1 1 1 1
511
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
512
+ 0 0
513
+ 0 0
514
+ 0 0
515
+ 1 0
516
+ 1 1
517
+ If the row of the mask is all zero, the output will be zero.
518
+
519
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
520
+ will only attend to keys between
521
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
522
+
523
+ Arguments:
524
+ q: (batch_size, seqlen, nheads, headdim)
525
+ k: (batch_size, seqlen, nheads_k, headdim)
526
+ v: (batch_size, seqlen, nheads_k, headdim)
527
+ dropout_p: float. Dropout probability.
528
+ softmax_scale: float. The scaling of QK^T before applying softmax.
529
+ Default to 1 / sqrt(headdim).
530
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
531
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
532
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
533
+ (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
534
+ is added to the attention score of query i and key j.
535
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
536
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
537
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
538
+ testing only. The returned probabilities are not guaranteed to be correct
539
+ (they might not have the right scaling).
540
+ Return:
541
+ out: (batch_size, seqlen, nheads, headdim).
542
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
543
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
544
+ normalization factor).
545
+ """
546
+ return FlashAttnFunc.apply(
547
+ q,
548
+ k,
549
+ v,
550
+ softmax_scale,
551
+ causal,
552
+ qv,
553
+ q_descale, k_descale, v_descale,
554
+ window_size,
555
+ softcap,
556
+ num_splits,
557
+ pack_gqa,
558
+ deterministic,
559
+ sm_margin,
560
+ s_aux,
561
+ )
562
+
563
+
564
+ def flash_attn_varlen_func(
565
+ q,
566
+ k,
567
+ v,
568
+ cu_seqlens_q,
569
+ cu_seqlens_k,
570
+ max_seqlen_q,
571
+ max_seqlen_k,
572
+ seqused_q=None,
573
+ seqused_k=None,
574
+ softmax_scale=None,
575
+ causal=False,
576
+ qv=None,
577
+ q_descale=None, k_descale=None, v_descale=None,
578
+ window_size=(-1, -1),
579
+ softcap=0.0,
580
+ num_splits=1,
581
+ pack_gqa=None,
582
+ deterministic=False,
583
+ sm_margin=0,
584
+ s_aux=None,
585
+ ):
586
+ return FlashAttnVarlenFunc.apply(
587
+ q,
588
+ k,
589
+ v,
590
+ cu_seqlens_q,
591
+ cu_seqlens_k,
592
+ seqused_q,
593
+ seqused_k,
594
+ max_seqlen_q,
595
+ max_seqlen_k,
596
+ softmax_scale,
597
+ causal,
598
+ qv,
599
+ q_descale, k_descale, v_descale,
600
+ window_size,
601
+ softcap,
602
+ num_splits,
603
+ pack_gqa,
604
+ deterministic,
605
+ sm_margin,
606
+ s_aux,
607
+ )
608
+
609
+
610
+ def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
611
+ return ops.fwd_combine(out_partial, lse_partial, out, out_dtype)
612
+
613
+
614
+ def flash_attn_with_kvcache(
615
+ q,
616
+ k_cache,
617
+ v_cache,
618
+ k=None,
619
+ v=None,
620
+ qv=None,
621
+ rotary_cos=None,
622
+ rotary_sin=None,
623
+ cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
624
+ cache_batch_idx: Optional[torch.Tensor] = None,
625
+ cache_leftpad: Optional[torch.Tensor] = None,
626
+ page_table: Optional[torch.Tensor] = None,
627
+ cu_seqlens_q: Optional[torch.Tensor] = None,
628
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
629
+ max_seqlen_q: Optional[int] = None,
630
+ rotary_seqlens: Optional[torch.Tensor] = None,
631
+ q_descale: Optional[torch.Tensor] = None,
632
+ k_descale: Optional[torch.Tensor] = None,
633
+ v_descale: Optional[torch.Tensor] = None,
634
+ softmax_scale=None,
635
+ causal=False,
636
+ window_size=(-1, -1), # -1 means infinite context window
637
+ softcap=0.0, # 0.0 means deactivated
638
+ rotary_interleaved=True,
639
+ scheduler_metadata=None,
640
+ num_splits=0, # Can be tuned for speed
641
+ pack_gqa=None, # Can be tuned for speed
642
+ sm_margin=0, # Can be tuned if some SMs are used for communication
643
+ return_softmax_lse=False,
644
+ s_aux=None,
645
+ ):
646
+ """
647
+ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
648
+ k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
649
+ the previous step, and update them with the new keys/values from the current step, and do
650
+ attention with the updated cache, all in 1 kernel.
651
+
652
+ If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
653
+ For example, the KV cache could be pre-allocated with the max sequence length, and you can use
654
+ cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
655
+
656
+ Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
657
+ rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
658
+ If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
659
+ and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
660
+ If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
661
+ indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
662
+
663
+ See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
664
+
665
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
666
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
667
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
668
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
669
+
670
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
671
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
672
+ 1 1 1 1 0
673
+ 1 1 1 1 1
674
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
675
+ 0 0
676
+ 0 0
677
+ 0 0
678
+ 1 0
679
+ 1 1
680
+ If the row of the mask is all zero, the output will be zero.
681
+
682
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
683
+ will only attend to keys between
684
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
685
+
686
+ Note: Does not support backward pass.
687
+
688
+ Arguments:
689
+ q: (batch_size, seqlen, nheads, headdim)
690
+ k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
691
+ or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
692
+ page_block_size must be a multiple of 256.
693
+ v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
694
+ or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
695
+ k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
696
+ k with k_cache, starting at the indices specified by cache_seqlens.
697
+ v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
698
+ qv [optional]: (batch_size, seqlen, nheads, headdim_v)
699
+ rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
700
+ to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
701
+ rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
702
+ cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
703
+ KV cache.
704
+ cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
705
+ If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
706
+ If the indices are not distinct, and k and v are provided, the values updated in the cache
707
+ might come from any of the duplicate indices.
708
+ cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
709
+ page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
710
+ softmax_scale: float. The scaling of QK^T before applying softmax.
711
+ Default to 1 / sqrt(headdim).
712
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
713
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
714
+ softcap: float. Anything > 0 activates softcapping attention.
715
+ rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
716
+ If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
717
+ rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
718
+ (i.e. GPT-NeoX style).
719
+ num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
720
+ If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
721
+ to automatically determine the number of splits.
722
+ Don't change this unless you know what you are doing.
723
+ return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
724
+
725
+ Return:
726
+ out: (batch_size, seqlen, nheads, headdim).
727
+ softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
728
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
729
+ normalization factor).
730
+ """
731
+ assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
732
+ assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
733
+ if softmax_scale is None:
734
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
735
+ if cache_seqlens is not None and isinstance(cache_seqlens, int):
736
+ cache_seqlens = torch.full(
737
+ (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
738
+ )
739
+ cache_seqlens = maybe_contiguous(cache_seqlens)
740
+ out, softmax_lse, *rest = _flash_attn_forward(
741
+ q,
742
+ k_cache,
743
+ v_cache,
744
+ k,
745
+ v,
746
+ qv,
747
+ None, # out
748
+ cu_seqlens_q,
749
+ None, # cu_seqlens_k
750
+ cu_seqlens_k_new,
751
+ None, # seqused_q
752
+ cache_seqlens,
753
+ max_seqlen_q,
754
+ None, # max_seqlen_k
755
+ page_table,
756
+ cache_batch_idx,
757
+ cache_leftpad,
758
+ rotary_cos,
759
+ rotary_sin,
760
+ rotary_seqlens,
761
+ q_descale, k_descale, v_descale,
762
+ softmax_scale,
763
+ causal=causal,
764
+ window_size=window_size,
765
+ softcap=softcap,
766
+ rotary_interleaved=rotary_interleaved,
767
+ scheduler_metadata=scheduler_metadata,
768
+ num_splits=num_splits,
769
+ pack_gqa=pack_gqa,
770
+ sm_margin=sm_margin,
771
+ s_aux=s_aux,
772
+ )
773
+ # return (out, softmax_lse) if return_softmax_lse else out
774
+ return (out, softmax_lse, *rest) if return_softmax_lse else out
775
+
776
+
777
+ def get_scheduler_metadata(
778
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
779
+ cache_seqlens: torch.Tensor,
780
+ qkv_dtype=torch.bfloat16,
781
+ headdim_v=None,
782
+ cu_seqlens_q: Optional[torch.Tensor] = None,
783
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
784
+ cache_leftpad: Optional[torch.Tensor] = None,
785
+ page_size: Optional[int] = None,
786
+ max_seqlen_k_new=0,
787
+ causal=False,
788
+ window_size=(-1, -1), # -1 means infinite context window
789
+ has_softcap=False,
790
+ num_splits=0, # Can be tuned for speed
791
+ pack_gqa=None, # Can be tuned for speed
792
+ sm_margin=0, # Can be tuned if some SMs are used for communication
793
+ ):
794
+ cache_seqlens = maybe_contiguous(cache_seqlens)
795
+ if headdim_v is None:
796
+ headdim_v = headdim
797
+ scheduler_metadata = ops.get_scheduler_metadata(
798
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
799
+ qkv_dtype,
800
+ cache_seqlens,
801
+ cu_seqlens_q,
802
+ None, # cu_seqlens_k
803
+ cu_seqlens_k_new,
804
+ None, # seqused_q
805
+ cache_leftpad,
806
+ page_size,
807
+ max_seqlen_k_new,
808
+ causal,
809
+ window_size[0], window_size[1],
810
+ has_softcap,
811
+ num_splits,
812
+ pack_gqa,
813
+ sm_margin,
814
+ )
815
+ return scheduler_metadata
build/torch28-cxx11-cu128-x86_64-linux/vllm_flash_attn3/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .flash_attn_interface import (
2
+ flash_attn_combine,
3
+ flash_attn_func,
4
+ flash_attn_qkvpacked_func,
5
+ flash_attn_varlen_func,
6
+ flash_attn_with_kvcache,
7
+ get_scheduler_metadata,
8
+ )
9
+
10
+ __all__ = [
11
+ "flash_attn_combine",
12
+ "flash_attn_func",
13
+ "flash_attn_qkvpacked_func",
14
+ "flash_attn_varlen_func",
15
+ "flash_attn_with_kvcache",
16
+ "get_scheduler_metadata",
17
+ ]
build/torch28-cxx11-cu128-x86_64-linux/vllm_flash_attn3/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (442 Bytes). View file
 
build/torch28-cxx11-cu128-x86_64-linux/vllm_flash_attn3/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (556 Bytes). View file
 
build/torch28-cxx11-cu128-x86_64-linux/vllm_flash_attn3/__pycache__/flash_attn_interface.cpython-313.pyc ADDED
Binary file (25.6 kB). View file
 
build/torch28-cxx11-cu128-x86_64-linux/vllm_flash_attn3/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _vllm_flash_attn3_28fbd26_dirty
3
+ ops = torch.ops._vllm_flash_attn3_28fbd26_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_vllm_flash_attn3_28fbd26_dirty::{op_name}"
build/torch28-cxx11-cu128-x86_64-linux/vllm_flash_attn3/_vllm_flash_attn3_28fbd26_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29c8dc9754cba50e7eeacfbff19710fedb4152119c57a0c5afa00b036480fe6f
3
+ size 915245760
build/torch28-cxx11-cu128-x86_64-linux/vllm_flash_attn3/flash_attn_interface.py ADDED
@@ -0,0 +1,815 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from ._ops import ops
11
+
12
+ # isort: on
13
+
14
+
15
+ def maybe_contiguous(x):
16
+ return x.contiguous() if x is not None and x.stride(-1) != 1 else x
17
+
18
+
19
+ def _flash_attn_forward(
20
+ q,
21
+ k,
22
+ v,
23
+ k_new,
24
+ v_new,
25
+ qv,
26
+ out,
27
+ cu_seqlens_q,
28
+ cu_seqlens_k,
29
+ cu_seqlens_k_new,
30
+ seqused_q,
31
+ seqused_k,
32
+ max_seqlen_q,
33
+ max_seqlen_k,
34
+ page_table,
35
+ kv_batch_idx,
36
+ leftpad_k,
37
+ rotary_cos,
38
+ rotary_sin,
39
+ seqlens_rotary,
40
+ q_descale,
41
+ k_descale,
42
+ v_descale,
43
+ softmax_scale,
44
+ causal,
45
+ window_size=(-1, -1),
46
+ softcap=0.0,
47
+ rotary_interleaved=True,
48
+ scheduler_metadata=None,
49
+ num_splits=1,
50
+ pack_gqa=None,
51
+ sm_margin=0,
52
+ s_aux=None):
53
+ q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
54
+ v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
55
+ cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
56
+ maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
57
+ ]
58
+ seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
59
+ page_table, kv_batch_idx, leftpad_k = [
60
+ maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
61
+ ]
62
+ rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
63
+ seqlens_rotary = maybe_contiguous(seqlens_rotary)
64
+ out, softmax_lse, *rest = ops.fwd(
65
+ q,
66
+ k,
67
+ v,
68
+ k_new,
69
+ v_new,
70
+ qv,
71
+ out,
72
+ cu_seqlens_q,
73
+ cu_seqlens_k,
74
+ cu_seqlens_k_new,
75
+ seqused_q,
76
+ seqused_k,
77
+ max_seqlen_q,
78
+ max_seqlen_k,
79
+ page_table,
80
+ kv_batch_idx,
81
+ leftpad_k,
82
+ rotary_cos,
83
+ rotary_sin,
84
+ seqlens_rotary,
85
+ q_descale,
86
+ k_descale,
87
+ v_descale,
88
+ softmax_scale,
89
+ causal,
90
+ window_size[0],
91
+ window_size[1],
92
+ softcap,
93
+ rotary_interleaved,
94
+ scheduler_metadata,
95
+ num_splits,
96
+ pack_gqa,
97
+ sm_margin,
98
+ s_aux
99
+ )
100
+ return out, softmax_lse, *rest
101
+
102
+
103
+ def _flash_attn_backward(
104
+ dout,
105
+ q,
106
+ k,
107
+ v,
108
+ out,
109
+ softmax_lse,
110
+ cu_seqlens_q,
111
+ cu_seqlens_k,
112
+ sequed_q,
113
+ sequed_k,
114
+ max_seqlen_q,
115
+ max_seqlen_k,
116
+ dq,
117
+ dk,
118
+ dv,
119
+ softmax_scale,
120
+ causal,
121
+ window_size=(-1, -1),
122
+ softcap=0.0,
123
+ deterministic=False,
124
+ sm_margin=0,
125
+ ):
126
+ # dq, dk, dv are allocated by us so they should already be contiguous
127
+ dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
128
+ dq, dk, dv, softmax_d, *rest = ops.bwd(
129
+ dout,
130
+ q,
131
+ k,
132
+ v,
133
+ out,
134
+ softmax_lse,
135
+ dq,
136
+ dk,
137
+ dv,
138
+ cu_seqlens_q,
139
+ cu_seqlens_k,
140
+ sequed_q,
141
+ sequed_k,
142
+ max_seqlen_q,
143
+ max_seqlen_k,
144
+ softmax_scale,
145
+ causal,
146
+ window_size[0],
147
+ window_size[1],
148
+ softcap,
149
+ deterministic,
150
+ sm_margin,
151
+ )
152
+ return dq, dk, dv, softmax_d
153
+
154
+
155
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
156
+ @staticmethod
157
+ def forward(
158
+ ctx,
159
+ qkv,
160
+ softmax_scale,
161
+ causal,
162
+ q_descale=None, k_descale=None, v_descale=None,
163
+ window_size=(-1, -1),
164
+ softcap=0.0,
165
+ deterministic=False,
166
+ num_heads_q=None,
167
+ ):
168
+ if softmax_scale is None:
169
+ softmax_scale = qkv.shape[-1] ** (-0.5)
170
+ if qkv.dim() == 5:
171
+ assert qkv.shape[-3] == 3
172
+ q, k, v = qkv.unbind(dim=-3)
173
+ else:
174
+ assert qkv.dim() == 4
175
+ assert num_heads_q is not None
176
+ num_heads_k = (qkv.shape[2] - num_heads_q) // 2
177
+ assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
178
+ q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
179
+ out, softmax_lse, *rest = _flash_attn_forward(
180
+ q,
181
+ k,
182
+ v,
183
+ None, None, # k_new, v_new
184
+ None, # qv
185
+ None, # out
186
+ None, None, None, # cu_seqlens_q/k/k_new
187
+ None, None, # seqused_q/k
188
+ None, None, # max_seqlen_q/k
189
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
190
+ None, None, None, # rotary_cos/sin, seqlens_rotary
191
+ q_descale, k_descale, v_descale,
192
+ softmax_scale,
193
+ causal=causal,
194
+ window_size=window_size,
195
+ softcap=softcap,
196
+ )
197
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
198
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
199
+ ctx.softmax_scale = softmax_scale
200
+ ctx.causal = causal
201
+ ctx.window_size = window_size
202
+ ctx.softcap = softcap
203
+ ctx.deterministic = deterministic
204
+ ctx.ndim = qkv.dim()
205
+ # return out, softmax_lse
206
+ return out
207
+
208
+ @staticmethod
209
+ def backward(ctx, dout, *args):
210
+ q, k, v, out, softmax_lse = ctx.saved_tensors
211
+ if ctx.ndim == 5:
212
+ qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
213
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
214
+ dq, dk, dv = dqkv.unbind(dim=-3)
215
+ else:
216
+ num_heads_q = q.shape[2]
217
+ num_heads_k = k.shape[2]
218
+ qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
219
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
220
+ dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
221
+ _flash_attn_backward(
222
+ dout,
223
+ q,
224
+ k,
225
+ v,
226
+ out,
227
+ softmax_lse,
228
+ None, None, # cu_seqlens_q, cu_seqlens_k,
229
+ None, None, # sequed_q, sequed_k,
230
+ None, None, # max_seqlen_q, max_seqlen_k,
231
+ dq,
232
+ dk,
233
+ dv,
234
+ ctx.softmax_scale,
235
+ ctx.causal,
236
+ ctx.window_size,
237
+ ctx.softcap,
238
+ ctx.deterministic,
239
+ )
240
+ dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
241
+ return dqkv, None, None, None, None, None, None, None, None, None, None
242
+
243
+
244
+ class FlashAttnFunc(torch.autograd.Function):
245
+
246
+ @staticmethod
247
+ def forward(
248
+ ctx,
249
+ q,
250
+ k,
251
+ v,
252
+ softmax_scale,
253
+ causal,
254
+ qv=None,
255
+ q_descale=None, k_descale=None, v_descale=None,
256
+ window_size=(-1, -1),
257
+ softcap=0.0,
258
+ num_splits=1,
259
+ pack_gqa=None,
260
+ deterministic=False,
261
+ sm_margin=0,
262
+ s_aux=None,
263
+ ):
264
+ if softmax_scale is None:
265
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
266
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
267
+ out, softmax_lse, *rest = _flash_attn_forward(
268
+ q,
269
+ k,
270
+ v,
271
+ None, None, # k_new, v_new
272
+ qv, # qv
273
+ None, # out
274
+ None, None, None, # cu_seqlens_q/k/k_new
275
+ None, None, # seqused_q/k
276
+ None, None, # max_seqlen_q/k
277
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
278
+ None, None, None, # rotary_cos/sin, seqlens_rotary
279
+ q_descale, k_descale, v_descale,
280
+ softmax_scale,
281
+ causal=causal,
282
+ window_size=window_size,
283
+ softcap=softcap,
284
+ num_splits=num_splits,
285
+ pack_gqa=pack_gqa,
286
+ sm_margin=sm_margin,
287
+ s_aux=s_aux,
288
+ )
289
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
290
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
291
+ ctx.softmax_scale = softmax_scale
292
+ ctx.causal = causal
293
+ ctx.window_size = window_size
294
+ ctx.softcap = softcap
295
+ ctx.deterministic = deterministic
296
+ ctx.sm_margin = sm_margin
297
+ return out, softmax_lse
298
+
299
+ @staticmethod
300
+ def backward(ctx, dout, *args):
301
+ q, k, v, out, softmax_lse = ctx.saved_tensors
302
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
303
+ _flash_attn_backward(
304
+ dout,
305
+ q,
306
+ k,
307
+ v,
308
+ out,
309
+ softmax_lse,
310
+ None, None, # cu_seqlens_q, cu_seqlens_k,
311
+ None, None, # sequed_q, sequed_k,
312
+ None, None, # max_seqlen_q, max_seqlen_k,
313
+ dq,
314
+ dk,
315
+ dv,
316
+ ctx.softmax_scale,
317
+ ctx.causal,
318
+ ctx.window_size,
319
+ ctx.softcap,
320
+ ctx.deterministic,
321
+ ctx.sm_margin,
322
+ )
323
+ dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
324
+ dk = dk[..., : dout.shape[-1]]
325
+ dv = dv[..., : dout.shape[-1]]
326
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
327
+
328
+
329
+ class FlashAttnVarlenFunc(torch.autograd.Function):
330
+
331
+ @staticmethod
332
+ def forward(
333
+ ctx,
334
+ q,
335
+ k,
336
+ v,
337
+ cu_seqlens_q,
338
+ cu_seqlens_k,
339
+ seqused_q,
340
+ seqused_k,
341
+ max_seqlen_q,
342
+ max_seqlen_k,
343
+ softmax_scale,
344
+ causal,
345
+ qv=None,
346
+ q_descale=None, k_descale=None, v_descale=None,
347
+ window_size=(-1, -1),
348
+ softcap=0.0,
349
+ num_splits=1,
350
+ pack_gqa=None,
351
+ deterministic=False,
352
+ sm_margin=0,
353
+ s_aux=None,
354
+ ):
355
+ if softmax_scale is None:
356
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
357
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
358
+ out, softmax_lse, *rest = _flash_attn_forward(
359
+ q,
360
+ k,
361
+ v,
362
+ None, None, # k_new, v_new
363
+ qv, # qv
364
+ None, # out
365
+ cu_seqlens_q,
366
+ cu_seqlens_k,
367
+ None, # cu_seqlens_k_new
368
+ seqused_q,
369
+ seqused_k,
370
+ max_seqlen_q,
371
+ max_seqlen_k,
372
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
373
+ None, None, None, # rotary_cos/sin, seqlens_rotary
374
+ q_descale, k_descale, v_descale,
375
+ softmax_scale,
376
+ causal=causal,
377
+ window_size=window_size,
378
+ softcap=softcap,
379
+ num_splits=num_splits,
380
+ pack_gqa=pack_gqa,
381
+ sm_margin=sm_margin,
382
+ s_aux=s_aux,
383
+ )
384
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
385
+ ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
386
+ ctx.max_seqlen_q = max_seqlen_q
387
+ ctx.max_seqlen_k = max_seqlen_k
388
+ ctx.softmax_scale = softmax_scale
389
+ ctx.causal = causal
390
+ ctx.window_size = window_size
391
+ ctx.softcap = softcap
392
+ ctx.deterministic = deterministic
393
+ ctx.sm_margin = sm_margin
394
+ return out, softmax_lse
395
+
396
+ @staticmethod
397
+ def backward(ctx, dout, *args):
398
+ q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
399
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
400
+ _flash_attn_backward(
401
+ dout,
402
+ q,
403
+ k,
404
+ v,
405
+ out,
406
+ softmax_lse,
407
+ cu_seqlens_q,
408
+ cu_seqlens_k,
409
+ seqused_q,
410
+ seqused_k,
411
+ ctx.max_seqlen_q,
412
+ ctx.max_seqlen_k,
413
+ dq,
414
+ dk,
415
+ dv,
416
+ ctx.softmax_scale,
417
+ ctx.causal,
418
+ ctx.window_size,
419
+ ctx.softcap,
420
+ ctx.deterministic,
421
+ ctx.sm_margin,
422
+ )
423
+ dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
424
+ dk = dk[..., : dout.shape[-1]]
425
+ dv = dv[..., : dout.shape[-1]]
426
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
427
+
428
+
429
+ def flash_attn_qkvpacked_func(
430
+ qkv,
431
+ softmax_scale=None,
432
+ causal=False,
433
+ q_descale=None, k_descale=None, v_descale=None,
434
+ window_size=(-1, -1),
435
+ softcap=0.0,
436
+ deterministic=False,
437
+ num_heads_q=None,
438
+ ):
439
+ """dropout_p should be set to 0.0 during evaluation
440
+ If Q, K, V are already stacked into 1 tensor, this function will be faster than
441
+ calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
442
+ of the gradients of Q, K, V.
443
+ For multi-query and grouped-query attention (MQA/GQA), please see
444
+ flash_attn_kvpacked_func and flash_attn_func.
445
+
446
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
447
+ will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
448
+
449
+ Arguments:
450
+ qkv: (batch_size, seqlen, 3, nheads, headdim)
451
+ dropout_p: float. Dropout probability.
452
+ softmax_scale: float. The scaling of QK^T before applying softmax.
453
+ Default to 1 / sqrt(headdim).
454
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
455
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
456
+ softcap: float. Anything > 0 activates softcapping attention.
457
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
458
+ the attention score of query i and key j.
459
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
460
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
461
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
462
+ testing only. The returned probabilities are not guaranteed to be correct
463
+ (they might not have the right scaling).
464
+ Return:
465
+ out: (batch_size, seqlen, nheads, headdim).
466
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
467
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
468
+ normalization factor).
469
+ S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
470
+ The output of softmax (possibly with different scaling). It also encodes the dropout
471
+ pattern (negative means that location was dropped, nonnegative means it was kept).
472
+ """
473
+ return FlashAttnQKVPackedFunc.apply(
474
+ qkv,
475
+ softmax_scale,
476
+ causal,
477
+ q_descale, k_descale, v_descale,
478
+ window_size,
479
+ softcap,
480
+ deterministic,
481
+ num_heads_q,
482
+ )
483
+
484
+
485
+ def flash_attn_func(
486
+ q,
487
+ k,
488
+ v,
489
+ softmax_scale=None,
490
+ causal=False,
491
+ qv=None,
492
+ q_descale=None, k_descale=None, v_descale=None,
493
+ window_size=(-1, -1),
494
+ softcap=0.0,
495
+ num_splits=1,
496
+ pack_gqa=None,
497
+ deterministic=False,
498
+ sm_margin=0,
499
+ s_aux=None,
500
+ ):
501
+ """dropout_p should be set to 0.0 during evaluation
502
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
503
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
504
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
505
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
506
+
507
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
508
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
509
+ 1 1 1 1 0
510
+ 1 1 1 1 1
511
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
512
+ 0 0
513
+ 0 0
514
+ 0 0
515
+ 1 0
516
+ 1 1
517
+ If the row of the mask is all zero, the output will be zero.
518
+
519
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
520
+ will only attend to keys between
521
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
522
+
523
+ Arguments:
524
+ q: (batch_size, seqlen, nheads, headdim)
525
+ k: (batch_size, seqlen, nheads_k, headdim)
526
+ v: (batch_size, seqlen, nheads_k, headdim)
527
+ dropout_p: float. Dropout probability.
528
+ softmax_scale: float. The scaling of QK^T before applying softmax.
529
+ Default to 1 / sqrt(headdim).
530
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
531
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
532
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
533
+ (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
534
+ is added to the attention score of query i and key j.
535
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
536
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
537
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
538
+ testing only. The returned probabilities are not guaranteed to be correct
539
+ (they might not have the right scaling).
540
+ Return:
541
+ out: (batch_size, seqlen, nheads, headdim).
542
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
543
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
544
+ normalization factor).
545
+ """
546
+ return FlashAttnFunc.apply(
547
+ q,
548
+ k,
549
+ v,
550
+ softmax_scale,
551
+ causal,
552
+ qv,
553
+ q_descale, k_descale, v_descale,
554
+ window_size,
555
+ softcap,
556
+ num_splits,
557
+ pack_gqa,
558
+ deterministic,
559
+ sm_margin,
560
+ s_aux,
561
+ )
562
+
563
+
564
+ def flash_attn_varlen_func(
565
+ q,
566
+ k,
567
+ v,
568
+ cu_seqlens_q,
569
+ cu_seqlens_k,
570
+ max_seqlen_q,
571
+ max_seqlen_k,
572
+ seqused_q=None,
573
+ seqused_k=None,
574
+ softmax_scale=None,
575
+ causal=False,
576
+ qv=None,
577
+ q_descale=None, k_descale=None, v_descale=None,
578
+ window_size=(-1, -1),
579
+ softcap=0.0,
580
+ num_splits=1,
581
+ pack_gqa=None,
582
+ deterministic=False,
583
+ sm_margin=0,
584
+ s_aux=None,
585
+ ):
586
+ return FlashAttnVarlenFunc.apply(
587
+ q,
588
+ k,
589
+ v,
590
+ cu_seqlens_q,
591
+ cu_seqlens_k,
592
+ seqused_q,
593
+ seqused_k,
594
+ max_seqlen_q,
595
+ max_seqlen_k,
596
+ softmax_scale,
597
+ causal,
598
+ qv,
599
+ q_descale, k_descale, v_descale,
600
+ window_size,
601
+ softcap,
602
+ num_splits,
603
+ pack_gqa,
604
+ deterministic,
605
+ sm_margin,
606
+ s_aux,
607
+ )
608
+
609
+
610
+ def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
611
+ return ops.fwd_combine(out_partial, lse_partial, out, out_dtype)
612
+
613
+
614
+ def flash_attn_with_kvcache(
615
+ q,
616
+ k_cache,
617
+ v_cache,
618
+ k=None,
619
+ v=None,
620
+ qv=None,
621
+ rotary_cos=None,
622
+ rotary_sin=None,
623
+ cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
624
+ cache_batch_idx: Optional[torch.Tensor] = None,
625
+ cache_leftpad: Optional[torch.Tensor] = None,
626
+ page_table: Optional[torch.Tensor] = None,
627
+ cu_seqlens_q: Optional[torch.Tensor] = None,
628
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
629
+ max_seqlen_q: Optional[int] = None,
630
+ rotary_seqlens: Optional[torch.Tensor] = None,
631
+ q_descale: Optional[torch.Tensor] = None,
632
+ k_descale: Optional[torch.Tensor] = None,
633
+ v_descale: Optional[torch.Tensor] = None,
634
+ softmax_scale=None,
635
+ causal=False,
636
+ window_size=(-1, -1), # -1 means infinite context window
637
+ softcap=0.0, # 0.0 means deactivated
638
+ rotary_interleaved=True,
639
+ scheduler_metadata=None,
640
+ num_splits=0, # Can be tuned for speed
641
+ pack_gqa=None, # Can be tuned for speed
642
+ sm_margin=0, # Can be tuned if some SMs are used for communication
643
+ return_softmax_lse=False,
644
+ s_aux=None,
645
+ ):
646
+ """
647
+ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
648
+ k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
649
+ the previous step, and update them with the new keys/values from the current step, and do
650
+ attention with the updated cache, all in 1 kernel.
651
+
652
+ If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
653
+ For example, the KV cache could be pre-allocated with the max sequence length, and you can use
654
+ cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
655
+
656
+ Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
657
+ rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
658
+ If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
659
+ and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
660
+ If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
661
+ indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
662
+
663
+ See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
664
+
665
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
666
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
667
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
668
+ 0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
669
+
670
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
671
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
672
+ 1 1 1 1 0
673
+ 1 1 1 1 1
674
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
675
+ 0 0
676
+ 0 0
677
+ 0 0
678
+ 1 0
679
+ 1 1
680
+ If the row of the mask is all zero, the output will be zero.
681
+
682
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
683
+ will only attend to keys between
684
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
685
+
686
+ Note: Does not support backward pass.
687
+
688
+ Arguments:
689
+ q: (batch_size, seqlen, nheads, headdim)
690
+ k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
691
+ or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
692
+ page_block_size must be a multiple of 256.
693
+ v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
694
+ or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
695
+ k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
696
+ k with k_cache, starting at the indices specified by cache_seqlens.
697
+ v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
698
+ qv [optional]: (batch_size, seqlen, nheads, headdim_v)
699
+ rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
700
+ to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
701
+ rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
702
+ cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
703
+ KV cache.
704
+ cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
705
+ If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
706
+ If the indices are not distinct, and k and v are provided, the values updated in the cache
707
+ might come from any of the duplicate indices.
708
+ cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
709
+ page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
710
+ softmax_scale: float. The scaling of QK^T before applying softmax.
711
+ Default to 1 / sqrt(headdim).
712
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
713
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
714
+ softcap: float. Anything > 0 activates softcapping attention.
715
+ rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
716
+ If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
717
+ rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
718
+ (i.e. GPT-NeoX style).
719
+ num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
720
+ If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
721
+ to automatically determine the number of splits.
722
+ Don't change this unless you know what you are doing.
723
+ return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
724
+
725
+ Return:
726
+ out: (batch_size, seqlen, nheads, headdim).
727
+ softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
728
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
729
+ normalization factor).
730
+ """
731
+ assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
732
+ assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
733
+ if softmax_scale is None:
734
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
735
+ if cache_seqlens is not None and isinstance(cache_seqlens, int):
736
+ cache_seqlens = torch.full(
737
+ (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
738
+ )
739
+ cache_seqlens = maybe_contiguous(cache_seqlens)
740
+ out, softmax_lse, *rest = _flash_attn_forward(
741
+ q,
742
+ k_cache,
743
+ v_cache,
744
+ k,
745
+ v,
746
+ qv,
747
+ None, # out
748
+ cu_seqlens_q,
749
+ None, # cu_seqlens_k
750
+ cu_seqlens_k_new,
751
+ None, # seqused_q
752
+ cache_seqlens,
753
+ max_seqlen_q,
754
+ None, # max_seqlen_k
755
+ page_table,
756
+ cache_batch_idx,
757
+ cache_leftpad,
758
+ rotary_cos,
759
+ rotary_sin,
760
+ rotary_seqlens,
761
+ q_descale, k_descale, v_descale,
762
+ softmax_scale,
763
+ causal=causal,
764
+ window_size=window_size,
765
+ softcap=softcap,
766
+ rotary_interleaved=rotary_interleaved,
767
+ scheduler_metadata=scheduler_metadata,
768
+ num_splits=num_splits,
769
+ pack_gqa=pack_gqa,
770
+ sm_margin=sm_margin,
771
+ s_aux=s_aux,
772
+ )
773
+ # return (out, softmax_lse) if return_softmax_lse else out
774
+ return (out, softmax_lse, *rest) if return_softmax_lse else out
775
+
776
+
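A minimal sketch of the incremental-decoding pattern described in the docstring below: a pre-allocated cache, per-sequence current lengths, and one new token per sequence appended and attended over in a single call. Device, dtype, sizes, and lengths are assumptions for illustration only.

import torch

batch, nheads, nheads_k, headdim = 2, 8, 2, 128
max_cache_len = 1024

# Pre-allocated KV cache and current lengths, (batch_size,) int32 as documented below.
k_cache = torch.zeros(batch, max_cache_len, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)
v_cache = torch.zeros(batch, max_cache_len, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)
cache_seqlens = torch.tensor([10, 37], dtype=torch.int32, device="cuda")

# One new decode token per sequence: q and the new k/v have seqlen 1.
q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k_new = torch.randn(batch, 1, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)
v_new = torch.randn(batch, 1, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)

# Appends k_new/v_new into the caches at cache_seqlens and attends over the updated cache.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new,
    cache_seqlens=cache_seqlens, causal=True,
)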
777
+ def get_scheduler_metadata(
778
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
779
+ cache_seqlens: torch.Tensor,
780
+ qkv_dtype=torch.bfloat16,
781
+ headdim_v=None,
782
+ cu_seqlens_q: Optional[torch.Tensor] = None,
783
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
784
+ cache_leftpad: Optional[torch.Tensor] = None,
785
+ page_size: Optional[int] = None,
786
+ max_seqlen_k_new=0,
787
+ causal=False,
788
+ window_size=(-1, -1), # -1 means infinite context window
789
+ has_softcap=False,
790
+ num_splits=0, # Can be tuned for speed
791
+ pack_gqa=None, # Can be tuned for speed
792
+ sm_margin=0, # Can be tuned if some SMs are used for communication
793
+ ):
794
+ cache_seqlens = maybe_contiguous(cache_seqlens)
795
+ if headdim_v is None:
796
+ headdim_v = headdim
797
+ scheduler_metadata = ops.get_scheduler_metadata(
798
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
799
+ qkv_dtype,
800
+ cache_seqlens,
801
+ cu_seqlens_q,
802
+ None, # cu_seqlens_k
803
+ cu_seqlens_k_new,
804
+ None, # seqused_q
805
+ cache_leftpad,
806
+ page_size,
807
+ max_seqlen_k_new,
808
+ causal,
809
+ window_size[0], window_size[1],
810
+ has_softcap,
811
+ num_splits,
812
+ pack_gqa,
813
+ sm_margin,
814
+ )
815
+ return scheduler_metadata
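Continuing the decode sketch above (reusing q, k_cache, v_cache, cache_seqlens, batch, nheads, nheads_k, headdim), this is one possible pairing of get_scheduler_metadata with flash_attn_with_kvcache: precompute the split/scheduling decisions for a fixed shape and configuration, then pass them back in via scheduler_metadata. The choice of max_seqlen_k as the cache capacity and the exact matching contract between the two calls are assumptions here, not spelled out in this file.

# Precompute scheduling metadata once for this shape/configuration ...
metadata = get_scheduler_metadata(
    batch_size=batch, max_seqlen_q=1, max_seqlen_k=1024,
    num_heads_q=nheads, num_heads_kv=nheads_k, headdim=headdim,
    cache_seqlens=cache_seqlens,
    qkv_dtype=torch.bfloat16,
    causal=True,
)

# ... and reuse it on the decode call with the same shapes and options.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    causal=True,
    scheduler_metadata=metadata,
)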
build/torch28-cxx11-cu129-x86_64-linux/vllm_flash_attn3/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ from .flash_attn_interface import (
2
+ flash_attn_combine,
3
+ flash_attn_func,
4
+ flash_attn_qkvpacked_func,
5
+ flash_attn_varlen_func,
6
+ flash_attn_with_kvcache,
7
+ get_scheduler_metadata,
8
+ )
9
+
10
+ __all__ = [
11
+ "flash_attn_combine",
12
+ "flash_attn_func",
13
+ "flash_attn_qkvpacked_func",
14
+ "flash_attn_varlen_func",
15
+ "flash_attn_with_kvcache",
16
+ "get_scheduler_metadata",
17
+ ]
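The __init__.py above re-exports the public entry points. A minimal sketch of a dense GQA call through that surface; the vllm_flash_attn3 import path, device, dtype, and sizes are assumptions that depend on how the built package is installed.

import torch
from vllm_flash_attn3 import flash_attn_func  # assumed import path for the built package

batch, seqlen, nheads, nheads_k, headdim = 2, 256, 8, 2, 128
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seqlen, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, seqlen, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)

# GQA: 8 query heads share 2 KV heads; returns the output and the softmax log-sum-exp.
out, softmax_lse = flash_attn_func(q, k, v, causal=True)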
build/torch28-cxx11-cu129-x86_64-linux/vllm_flash_attn3/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (442 Bytes). View file
 
build/torch28-cxx11-cu129-x86_64-linux/vllm_flash_attn3/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (556 Bytes). View file
 
build/torch28-cxx11-cu129-x86_64-linux/vllm_flash_attn3/__pycache__/flash_attn_interface.cpython-313.pyc ADDED
Binary file (25.6 kB). View file
 
build/torch28-cxx11-cu129-x86_64-linux/vllm_flash_attn3/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _vllm_flash_attn3_28fbd26_dirty
3
+ ops = torch.ops._vllm_flash_attn3_28fbd26_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_vllm_flash_attn3_28fbd26_dirty::{op_name}"
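For reference, a short sketch of what the helper above evaluates to and how the op namespace lines up with torch.ops; the import path is an assumption, the values follow directly from the module shown.

from vllm_flash_attn3._ops import ops, add_op_namespace_prefix  # assumed import path

qualified = add_op_namespace_prefix("fwd")
# qualified == "_vllm_flash_attn3_28fbd26_dirty::fwd"

# ops is torch.ops._vllm_flash_attn3_28fbd26_dirty, so the forward kernel used by
# flash_attn_interface.py is reachable as ops.fwd(...).
fwd_op = ops.fwd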
build/torch28-cxx11-cu129-x86_64-linux/vllm_flash_attn3/_vllm_flash_attn3_28fbd26_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29c8dc9754cba50e7eeacfbff19710fedb4152119c57a0c5afa00b036480fe6f
3
+ size 915245760
build/torch28-cxx11-cu129-x86_64-linux/vllm_flash_attn3/flash_attn_interface.py ADDED
@@ -0,0 +1,815 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from ._ops import ops
11
+
12
+ # isort: on
13
+
14
+
15
+ def maybe_contiguous(x):
16
+ return x.contiguous() if x is not None and x.stride(-1) != 1 else x
17
+
18
+
19
+ def _flash_attn_forward(
20
+ q,
21
+ k,
22
+ v,
23
+ k_new,
24
+ v_new,
25
+ qv,
26
+ out,
27
+ cu_seqlens_q,
28
+ cu_seqlens_k,
29
+ cu_seqlens_k_new,
30
+ seqused_q,
31
+ seqused_k,
32
+ max_seqlen_q,
33
+ max_seqlen_k,
34
+ page_table,
35
+ kv_batch_idx,
36
+ leftpad_k,
37
+ rotary_cos,
38
+ rotary_sin,
39
+ seqlens_rotary,
40
+ q_descale,
41
+ k_descale,
42
+ v_descale,
43
+ softmax_scale,
44
+ causal,
45
+ window_size=(-1, -1),
46
+ softcap=0.0,
47
+ rotary_interleaved=True,
48
+ scheduler_metadata=None,
49
+ num_splits=1,
50
+ pack_gqa=None,
51
+ sm_margin=0,
52
+ s_aux=None):
53
+ q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
54
+ v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
55
+ cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
56
+ maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
57
+ ]
58
+ seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
59
+ page_table, kv_batch_idx, leftpad_k = [
60
+ maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
61
+ ]
62
+ rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
63
+ seqlens_rotary = maybe_contiguous(seqlens_rotary)
64
+ out, softmax_lse, *rest = ops.fwd(
65
+ q,
66
+ k,
67
+ v,
68
+ k_new,
69
+ v_new,
70
+ qv,
71
+ out,
72
+ cu_seqlens_q,
73
+ cu_seqlens_k,
74
+ cu_seqlens_k_new,
75
+ seqused_q,
76
+ seqused_k,
77
+ max_seqlen_q,
78
+ max_seqlen_k,
79
+ page_table,
80
+ kv_batch_idx,
81
+ leftpad_k,
82
+ rotary_cos,
83
+ rotary_sin,
84
+ seqlens_rotary,
85
+ q_descale,
86
+ k_descale,
87
+ v_descale,
88
+ softmax_scale,
89
+ causal,
90
+ window_size[0],
91
+ window_size[1],
92
+ softcap,
93
+ rotary_interleaved,
94
+ scheduler_metadata,
95
+ num_splits,
96
+ pack_gqa,
97
+ sm_margin,
98
+ s_aux
99
+ )
100
+ return out, softmax_lse, *rest
101
+
102
+
103
+ def _flash_attn_backward(
104
+ dout,
105
+ q,
106
+ k,
107
+ v,
108
+ out,
109
+ softmax_lse,
110
+ cu_seqlens_q,
111
+ cu_seqlens_k,
112
+ sequed_q,
113
+ sequed_k,
114
+ max_seqlen_q,
115
+ max_seqlen_k,
116
+ dq,
117
+ dk,
118
+ dv,
119
+ softmax_scale,
120
+ causal,
121
+ window_size=(-1, -1),
122
+ softcap=0.0,
123
+ deterministic=False,
124
+ sm_margin=0,
125
+ ):
126
+ # dq, dk, dv are allocated by us so they should already be contiguous
127
+ dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
128
+ dq, dk, dv, softmax_d, *rest = ops.bwd(
129
+ dout,
130
+ q,
131
+ k,
132
+ v,
133
+ out,
134
+ softmax_lse,
135
+ dq,
136
+ dk,
137
+ dv,
138
+ cu_seqlens_q,
139
+ cu_seqlens_k,
140
+ sequed_q,
141
+ sequed_k,
142
+ max_seqlen_q,
143
+ max_seqlen_k,
144
+ softmax_scale,
145
+ causal,
146
+ window_size[0],
147
+ window_size[1],
148
+ softcap,
149
+ deterministic,
150
+ sm_margin,
151
+ )
152
+ return dq, dk, dv, softmax_d
153
+
154
+
155
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
156
+ @staticmethod
157
+ def forward(
158
+ ctx,
159
+ qkv,
160
+ softmax_scale,
161
+ causal,
162
+ q_descale=None, k_descale=None, v_descale=None,
163
+ window_size=(-1, -1),
164
+ softcap=0.0,
165
+ deterministic=False,
166
+ num_heads_q=None,
167
+ ):
168
+ if softmax_scale is None:
169
+ softmax_scale = qkv.shape[-1] ** (-0.5)
170
+ if qkv.dim() == 5:
171
+ assert qkv.shape[-3] == 3
172
+ q, k, v = qkv.unbind(dim=-3)
173
+ else:
174
+ assert qkv.dim() == 4
175
+ assert num_heads_q is not None
176
+ num_heads_k = (qkv.shape[2] - num_heads_q) // 2
177
+ assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
178
+ q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
179
+ out, softmax_lse, *rest = _flash_attn_forward(
180
+ q,
181
+ k,
182
+ v,
183
+ None, None, # k_new, v_new
184
+ None, # qv
185
+ None, # out
186
+ None, None, None, # cu_seqlens_q/k/k_new
187
+ None, None, # seqused_q/k
188
+ None, None, # max_seqlen_q/k
189
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
190
+ None, None, None, # rotary_cos/sin, seqlens_rotary
191
+ q_descale, k_descale, v_descale,
192
+ softmax_scale,
193
+ causal=causal,
194
+ window_size=window_size,
195
+ softcap=softcap,
196
+ )
197
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
198
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
199
+ ctx.softmax_scale = softmax_scale
200
+ ctx.causal = causal
201
+ ctx.window_size = window_size
202
+ ctx.softcap = softcap
203
+ ctx.deterministic = deterministic
204
+ ctx.ndim = qkv.dim()
205
+ # return out, softmax_lse
206
+ return out
207
+
208
+ @staticmethod
209
+ def backward(ctx, dout, *args):
210
+ q, k, v, out, softmax_lse = ctx.saved_tensors
211
+ if ctx.ndim == 5:
212
+ qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
213
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
214
+ dq, dk, dv = dqkv.unbind(dim=-3)
215
+ else:
216
+ num_heads_q = q.shape[2]
217
+ num_heads_k = k.shape[2]
218
+ qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
219
+ dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
220
+ dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
221
+ _flash_attn_backward(
222
+ dout,
223
+ q,
224
+ k,
225
+ v,
226
+ out,
227
+ softmax_lse,
228
+ None, None, # cu_seqlens_q, cu_seqlens_k,
229
+ None, None, # sequed_q, sequed_k,
230
+ None, None, # max_seqlen_q, max_seqlen_k,
231
+ dq,
232
+ dk,
233
+ dv,
234
+ ctx.softmax_scale,
235
+ ctx.causal,
236
+ ctx.window_size,
237
+ ctx.softcap,
238
+ ctx.deterministic,
239
+ )
240
+ dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
241
+ return dqkv, None, None, None, None, None, None, None, None, None, None
242
+
243
+
244
+ class FlashAttnFunc(torch.autograd.Function):
245
+
246
+ @staticmethod
247
+ def forward(
248
+ ctx,
249
+ q,
250
+ k,
251
+ v,
252
+ softmax_scale,
253
+ causal,
254
+ qv=None,
255
+ q_descale=None, k_descale=None, v_descale=None,
256
+ window_size=(-1, -1),
257
+ softcap=0.0,
258
+ num_splits=1,
259
+ pack_gqa=None,
260
+ deterministic=False,
261
+ sm_margin=0,
262
+ s_aux=None,
263
+ ):
264
+ if softmax_scale is None:
265
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
266
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
267
+ out, softmax_lse, *rest = _flash_attn_forward(
268
+ q,
269
+ k,
270
+ v,
271
+ None, None, # k_new, v_new
272
+ qv, # qv
273
+ None, # out
274
+ None, None, None, # cu_seqlens_q/k/k_new
275
+ None, None, # seqused_q/k
276
+ None, None, # max_seqlen_q/k
277
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
278
+ None, None, None, # rotary_cos/sin, seqlens_rotary
279
+ q_descale, k_descale, v_descale,
280
+ softmax_scale,
281
+ causal=causal,
282
+ window_size=window_size,
283
+ softcap=softcap,
284
+ num_splits=num_splits,
285
+ pack_gqa=pack_gqa,
286
+ sm_margin=sm_margin,
287
+ s_aux=s_aux,
288
+ )
289
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
290
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
291
+ ctx.softmax_scale = softmax_scale
292
+ ctx.causal = causal
293
+ ctx.window_size = window_size
294
+ ctx.softcap = softcap
295
+ ctx.deterministic = deterministic
296
+ ctx.sm_margin = sm_margin
297
+ return out, softmax_lse
298
+
299
+ @staticmethod
300
+ def backward(ctx, dout, *args):
301
+ q, k, v, out, softmax_lse = ctx.saved_tensors
302
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
303
+ _flash_attn_backward(
304
+ dout,
305
+ q,
306
+ k,
307
+ v,
308
+ out,
309
+ softmax_lse,
310
+ None, None, # cu_seqlens_q, cu_seqlens_k,
311
+ None, None, # sequed_q, sequed_k,
312
+ None, None, # max_seqlen_q, max_seqlen_k,
313
+ dq,
314
+ dk,
315
+ dv,
316
+ ctx.softmax_scale,
317
+ ctx.causal,
318
+ ctx.window_size,
319
+ ctx.softcap,
320
+ ctx.deterministic,
321
+ ctx.sm_margin,
322
+ )
323
+ dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
324
+ dk = dk[..., : dout.shape[-1]]
325
+ dv = dv[..., : dout.shape[-1]]
326
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
327
+
328
+
329
+ class FlashAttnVarlenFunc(torch.autograd.Function):
330
+
331
+ @staticmethod
332
+ def forward(
333
+ ctx,
334
+ q,
335
+ k,
336
+ v,
337
+ cu_seqlens_q,
338
+ cu_seqlens_k,
339
+ seqused_q,
340
+ seqused_k,
341
+ max_seqlen_q,
342
+ max_seqlen_k,
343
+ softmax_scale,
344
+ causal,
345
+ qv=None,
346
+ q_descale=None, k_descale=None, v_descale=None,
347
+ window_size=(-1, -1),
348
+ softcap=0.0,
349
+ num_splits=1,
350
+ pack_gqa=None,
351
+ deterministic=False,
352
+ sm_margin=0,
353
+ s_aux=None,
354
+ ):
355
+ if softmax_scale is None:
356
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
357
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
358
+ out, softmax_lse, *rest = _flash_attn_forward(
359
+ q,
360
+ k,
361
+ v,
362
+ None, None, # k_new, v_new
363
+ qv, # qv
364
+ None, # out
365
+ cu_seqlens_q,
366
+ cu_seqlens_k,
367
+ None, # cu_seqlens_k_new
368
+ seqused_q,
369
+ seqused_k,
370
+ max_seqlen_q,
371
+ max_seqlen_k,
372
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
373
+ None, None, None, # rotary_cos/sin, seqlens_rotary
374
+ q_descale, k_descale, v_descale,
375
+ softmax_scale,
376
+ causal=causal,
377
+ window_size=window_size,
378
+ softcap=softcap,
379
+ num_splits=num_splits,
380
+ pack_gqa=pack_gqa,
381
+ sm_margin=sm_margin,
382
+ s_aux=s_aux,
383
+ )
384
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
385
+ ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
386
+ ctx.max_seqlen_q = max_seqlen_q
387
+ ctx.max_seqlen_k = max_seqlen_k
388
+ ctx.softmax_scale = softmax_scale
389
+ ctx.causal = causal
390
+ ctx.window_size = window_size
391
+ ctx.softcap = softcap
392
+ ctx.deterministic = deterministic
393
+ ctx.sm_margin = sm_margin
394
+ return out, softmax_lse
395
+
396
+ @staticmethod
397
+ def backward(ctx, dout, *args):
398
+ q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
399
+ dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
400
+ _flash_attn_backward(
401
+ dout,
402
+ q,
403
+ k,
404
+ v,
405
+ out,
406
+ softmax_lse,
407
+ cu_seqlens_q,
408
+ cu_seqlens_k,
409
+ seqused_q,
410
+ seqused_k,
411
+ ctx.max_seqlen_q,
412
+ ctx.max_seqlen_k,
413
+ dq,
414
+ dk,
415
+ dv,
416
+ ctx.softmax_scale,
417
+ ctx.causal,
418
+ ctx.window_size,
419
+ ctx.softcap,
420
+ ctx.deterministic,
421
+ ctx.sm_margin,
422
+ )
423
+ dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
424
+ dk = dk[..., : dout.shape[-1]]
425
+ dv = dv[..., : dout.shape[-1]]
426
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
427
+
428
+
429
+ def flash_attn_qkvpacked_func(
430
+ qkv,
431
+ softmax_scale=None,
432
+ causal=False,
433
+ q_descale=None, k_descale=None, v_descale=None,
434
+ window_size=(-1, -1),
435
+ softcap=0.0,
436
+ deterministic=False,
437
+ num_heads_q=None,
438
+ ):
439
+ """dropout_p should be set to 0.0 during evaluation
440
+ If Q, K, V are already stacked into 1 tensor, this function will be faster than
441
+ calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
442
+ of the gradients of Q, K, V.
443
+ For multi-query and grouped-query attention (MQA/GQA), please see
444
+ flash_attn_kvpacked_func and flash_attn_func.
445
+
446
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
447
+ will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
448
+
449
+ Arguments:
450
+ qkv: (batch_size, seqlen, 3, nheads, headdim)
451
+ dropout_p: float. Dropout probability.
452
+ softmax_scale: float. The scaling of QK^T before applying softmax.
453
+ Default to 1 / sqrt(headdim).
454
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
455
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
456
+ softcap: float. Anything > 0 activates softcapping attention.
457
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
458
+ the attention score of query i and key j.
459
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
460
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
461
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
462
+ testing only. The returned probabilities are not guaranteed to be correct
463
+ (they might not have the right scaling).
464
+ Return:
465
+ out: (batch_size, seqlen, nheads, headdim).
466
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
467
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
468
+ normalization factor).
469
+ S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
470
+ The output of softmax (possibly with different scaling). It also encodes the dropout
471
+ pattern (negative means that location was dropped, nonnegative means it was kept).
472
+ """
473
+ return FlashAttnQKVPackedFunc.apply(
474
+ qkv,
475
+ softmax_scale,
476
+ causal,
477
+ q_descale, k_descale, v_descale,
478
+ window_size,
479
+ softcap,
480
+ deterministic,
481
+ num_heads_q,
482
+ )
483
+
484
+
485
+ def flash_attn_func(
486
+ q,
487
+ k,
488
+ v,
489
+ softmax_scale=None,
490
+ causal=False,
491
+ qv=None,
492
+ q_descale=None, k_descale=None, v_descale=None,
493
+ window_size=(-1, -1),
494
+ softcap=0.0,
495
+ num_splits=1,
496
+ pack_gqa=None,
497
+ deterministic=False,
498
+ sm_margin=0,
499
+ s_aux=None,
500
+ ):
501
+ """dropout_p should be set to 0.0 during evaluation
502
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
503
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
504
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
505
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
506
+
507
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
508
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
509
+ 1 1 1 1 0
510
+ 1 1 1 1 1
511
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
512
+ 0 0
513
+ 0 0
514
+ 0 0
515
+ 1 0
516
+ 1 1
517
+ If the row of the mask is all zero, the output will be zero.
518
+
519
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
520
+ will only attend to keys between
521
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
522
+
523
+ Arguments:
524
+ q: (batch_size, seqlen, nheads, headdim)
525
+ k: (batch_size, seqlen, nheads_k, headdim)
526
+ v: (batch_size, seqlen, nheads_k, headdim)
527
+ dropout_p: float. Dropout probability.
528
+ softmax_scale: float. The scaling of QK^T before applying softmax.
529
+ Default to 1 / sqrt(headdim).
530
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
531
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
532
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
533
+ (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
534
+ is added to the attention score of query i and key j.
535
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
536
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
537
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
538
+ testing only. The returned probabilities are not guaranteed to be correct
539
+ (they might not have the right scaling).
540
+ Return:
541
+ out: (batch_size, seqlen, nheads, headdim).
542
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
543
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
544
+ normalization factor).
545
+ """
546
+ return FlashAttnFunc.apply(
547
+ q,
548
+ k,
549
+ v,
550
+ softmax_scale,
551
+ causal,
552
+ qv,
553
+ q_descale, k_descale, v_descale,
554
+ window_size,
555
+ softcap,
556
+ num_splits,
557
+ pack_gqa,
558
+ deterministic,
559
+ sm_margin,
560
+ s_aux,
561
+ )
562
+
563
+
564
+ def flash_attn_varlen_func(
565
+ q,
566
+ k,
567
+ v,
568
+ cu_seqlens_q,
569
+ cu_seqlens_k,
570
+ max_seqlen_q,
571
+ max_seqlen_k,
572
+ seqused_q=None,
573
+ seqused_k=None,
574
+ softmax_scale=None,
575
+ causal=False,
576
+ qv=None,
577
+ q_descale=None, k_descale=None, v_descale=None,
578
+ window_size=(-1, -1),
579
+ softcap=0.0,
580
+ num_splits=1,
581
+ pack_gqa=None,
582
+ deterministic=False,
583
+ sm_margin=0,
584
+ s_aux=None,
585
+ ):
586
+ return FlashAttnVarlenFunc.apply(
587
+ q,
588
+ k,
589
+ v,
590
+ cu_seqlens_q,
591
+ cu_seqlens_k,
592
+ seqused_q,
593
+ seqused_k,
594
+ max_seqlen_q,
595
+ max_seqlen_k,
596
+ softmax_scale,
597
+ causal,
598
+ qv,
599
+ q_descale, k_descale, v_descale,
600
+ window_size,
601
+ softcap,
602
+ num_splits,
603
+ pack_gqa,
604
+ deterministic,
605
+ sm_margin,
606
+ s_aux,
607
+ )
608
+
609
+
610
+ def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
611
+ return ops.fwd_combine(out_partial, lse_partial, out, out_dtype)
612
+
613
+
614
+ def flash_attn_with_kvcache(
615
+ q,
616
+ k_cache,
617
+ v_cache,
618
+ k=None,
619
+ v=None,
620
+ qv=None,
621
+ rotary_cos=None,
622
+ rotary_sin=None,
623
+ cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
624
+ cache_batch_idx: Optional[torch.Tensor] = None,
625
+ cache_leftpad: Optional[torch.Tensor] = None,
626
+ page_table: Optional[torch.Tensor] = None,
627
+ cu_seqlens_q: Optional[torch.Tensor] = None,
628
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
629
+ max_seqlen_q: Optional[int] = None,
630
+ rotary_seqlens: Optional[torch.Tensor] = None,
631
+ q_descale: Optional[torch.Tensor] = None,
632
+ k_descale: Optional[torch.Tensor] = None,
633
+ v_descale: Optional[torch.Tensor] = None,
634
+ softmax_scale=None,
635
+ causal=False,
636
+ window_size=(-1, -1), # -1 means infinite context window
637
+ softcap=0.0, # 0.0 means deactivated
638
+ rotary_interleaved=True,
639
+ scheduler_metadata=None,
640
+ num_splits=0, # Can be tuned for speed
641
+ pack_gqa=None, # Can be tuned for speed
642
+ sm_margin=0, # Can be tuned if some SMs are used for communication
643
+ return_softmax_lse=False,
644
+ s_aux=None,
645
+ ):
646
+ """
647
+ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
648
+ k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
649
+ the previous step, and update them with the new keys/values from the current step, and do
650
+ attention with the updated cache, all in 1 kernel.
651
+
652
+ If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
653
+ For example, the KV cache could be pre-allocated with the max sequence length, and you can use
654
+ cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
655
+
656
+ Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
657
+ rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
658
+ If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
659
+ and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
660
+ If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
661
+ indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
662
+
663
+ See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
664
+
665
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
666
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
667
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
668
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
669
+
670
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
671
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
672
+ 1 1 1 1 0
673
+ 1 1 1 1 1
674
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
675
+ 0 0
676
+ 0 0
677
+ 0 0
678
+ 1 0
679
+ 1 1
680
+ If the row of the mask is all zero, the output will be zero.
681
+
682
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
683
+ will only attend to keys between
684
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
685
+
686
+ Note: Does not support backward pass.
687
+
688
+ Arguments:
689
+ q: (batch_size, seqlen, nheads, headdim)
690
+ k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
691
+ or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
692
+ page_block_size must be a multiple of 256.
693
+ v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
694
+ or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
695
+ k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
696
+ k with k_cache, starting at the indices specified by cache_seqlens.
697
+ v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
698
+ qv [optional]: (batch_size, seqlen, nheads, headdim_v)
699
+ rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
700
+ to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
701
+ rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
702
+ cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
703
+ KV cache.
704
+ cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
705
+ If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
706
+ If the indices are not distinct, and k and v are provided, the values updated in the cache
707
+ might come from any of the duplicate indices.
708
+ cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
709
+ page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
710
+ softmax_scale: float. The scaling of QK^T before applying softmax.
711
+ Default to 1 / sqrt(headdim).
712
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
713
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
714
+ softcap: float. Anything > 0 activates softcapping attention.
715
+ rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
716
+ If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
717
+ rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
718
+ (i.e. GPT-NeoX style).
719
+ num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
720
+ If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
721
+ to automatically determine the number of splits.
722
+ Don't change this unless you know what you are doing.
723
+ return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
724
+
725
+ Return:
726
+ out: (batch_size, seqlen, nheads, headdim).
727
+ softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
728
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
729
+ normalization factor).
730
+ """
731
+ assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
732
+ assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
733
+ if softmax_scale is None:
734
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
735
+ if cache_seqlens is not None and isinstance(cache_seqlens, int):
736
+ cache_seqlens = torch.full(
737
+ (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
738
+ )
739
+ cache_seqlens = maybe_contiguous(cache_seqlens)
740
+ out, softmax_lse, *rest = _flash_attn_forward(
741
+ q,
742
+ k_cache,
743
+ v_cache,
744
+ k,
745
+ v,
746
+ qv,
747
+ None, # out
748
+ cu_seqlens_q,
749
+ None, # cu_seqlens_k
750
+ cu_seqlens_k_new,
751
+ None, # seqused_q
752
+ cache_seqlens,
753
+ max_seqlen_q,
754
+ None, # max_seqlen_k
755
+ page_table,
756
+ cache_batch_idx,
757
+ cache_leftpad,
758
+ rotary_cos,
759
+ rotary_sin,
760
+ rotary_seqlens,
761
+ q_descale, k_descale, v_descale,
762
+ softmax_scale,
763
+ causal=causal,
764
+ window_size=window_size,
765
+ softcap=softcap,
766
+ rotary_interleaved=rotary_interleaved,
767
+ scheduler_metadata=scheduler_metadata,
768
+ num_splits=num_splits,
769
+ pack_gqa=pack_gqa,
770
+ sm_margin=sm_margin,
771
+ s_aux=s_aux,
772
+ )
773
+ # return (out, softmax_lse) if return_softmax_lse else out
774
+ return (out, softmax_lse, *rest) if return_softmax_lse else out
775
+
776
+
777
+ def get_scheduler_metadata(
778
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
779
+ cache_seqlens: torch.Tensor,
780
+ qkv_dtype=torch.bfloat16,
781
+ headdim_v=None,
782
+ cu_seqlens_q: Optional[torch.Tensor] = None,
783
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
784
+ cache_leftpad: Optional[torch.Tensor] = None,
785
+ page_size: Optional[int] = None,
786
+ max_seqlen_k_new=0,
787
+ causal=False,
788
+ window_size=(-1, -1), # -1 means infinite context window
789
+ has_softcap=False,
790
+ num_splits=0, # Can be tuned for speed
791
+ pack_gqa=None, # Can be tuned for speed
792
+ sm_margin=0, # Can be tuned if some SMs are used for communication
793
+ ):
794
+ cache_seqlens = maybe_contiguous(cache_seqlens)
795
+ if headdim_v is None:
796
+ headdim_v = headdim
797
+ scheduler_metadata = ops.get_scheduler_metadata(
798
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
799
+ qkv_dtype,
800
+ cache_seqlens,
801
+ cu_seqlens_q,
802
+ None, # cu_seqlens_k
803
+ cu_seqlens_k_new,
804
+ None, # seqused_q
805
+ cache_leftpad,
806
+ page_size,
807
+ max_seqlen_k_new,
808
+ causal,
809
+ window_size[0], window_size[1],
810
+ has_softcap,
811
+ num_splits,
812
+ pack_gqa,
813
+ sm_margin,
814
+ )
815
+ return scheduler_metadata
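Finally, a sketch of the paged KV cache layout accepted by flash_attn_with_kvcache, following its docstring: the cache becomes a pool of fixed-size pages and page_table maps each sequence to its pages. Page counts, lengths, device, and dtype are illustrative assumptions.

import torch

batch, nheads, nheads_k, headdim = 2, 8, 2, 128
page_block_size = 256                      # docstring: must be a multiple of 256
num_blocks = 8                             # total pool of pages shared by the batch

# Paged layout: (num_blocks, page_block_size, nheads_k, headdim) instead of per-batch rows.
k_cache = torch.zeros(num_blocks, page_block_size, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)
v_cache = torch.zeros(num_blocks, page_block_size, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)

# page_table[b, i] holds the pool index of the i-th page of sequence b.
page_table = torch.tensor([[0, 1, 2, 3],
                           [4, 5, 6, 7]], dtype=torch.int32, device="cuda")
cache_seqlens = torch.tensor([300, 700], dtype=torch.int32, device="cuda")

q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.bfloat16)
out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    page_table=page_table,
    causal=True,
)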