Commit eb8ddce — committed by danieldk (HF Staff) · 0 parent(s) (initial commit)

Convert FA3 to Kernel Hub format

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.

Files changed (50):
  1. build.toml +593 -0
  2. flake.lock +168 -0
  3. flake.nix +17 -0
  4. flash-attn/block.h +139 -0
  5. flash-attn/copy_sm90_bulk_reduce.hpp +49 -0
  6. flash-attn/cuda_check.h +19 -0
  7. flash-attn/epilogue_bwd.hpp +533 -0
  8. flash-attn/epilogue_fwd.hpp +484 -0
  9. flash-attn/flash.h +218 -0
  10. flash-attn/flash_api.cpp +1720 -0
  11. flash-attn/flash_bwd_kernel_sm80.h +173 -0
  12. flash-attn/flash_bwd_kernel_sm90.h +282 -0
  13. flash-attn/flash_bwd_launch_template.h +390 -0
  14. flash-attn/flash_bwd_postprocess_kernel.h +256 -0
  15. flash-attn/flash_bwd_preprocess_kernel.h +252 -0
  16. flash-attn/flash_fwd_combine.cu +13 -0
  17. flash-attn/flash_fwd_combine_kernel.h +482 -0
  18. flash-attn/flash_fwd_combine_launch_template.h +80 -0
  19. flash-attn/flash_fwd_kernel_sm80.h +215 -0
  20. flash-attn/flash_fwd_kernel_sm90.h +458 -0
  21. flash-attn/flash_fwd_launch_template.h +223 -0
  22. flash-attn/flash_prepare_scheduler.cu +124 -0
  23. flash-attn/heuristics.h +59 -0
  24. flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu +18 -0
  25. flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu +12 -0
  26. flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu +18 -0
  27. flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu +12 -0
  28. flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu +6 -0
  29. flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu +18 -0
  30. flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu +12 -0
  31. flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu +18 -0
  32. flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu +12 -0
  33. flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu +6 -0
  34. flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu +18 -0
  35. flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu +12 -0
  36. flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu +18 -0
  37. flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu +12 -0
  38. flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu +6 -0
  39. flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu +18 -0
  40. flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu +12 -0
  41. flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu +18 -0
  42. flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu +12 -0
  43. flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu +6 -0
  44. flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu +18 -0
  45. flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu +12 -0
  46. flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu +18 -0
  47. flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu +12 -0
  48. flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu +6 -0
  49. flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu +18 -0
  50. flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu +12 -0
build.toml ADDED
@@ -0,0 +1,593 @@
1
+ [general]
2
+ name = "flash_attn3"
3
+ universal = false
4
+ cuda-minver = "12.4"
5
+ cuda-maxver = "12.4"
6
+
7
+ [torch]
8
+ src = [
9
+ "torch-ext/pytorch_shim.h",
10
+ "torch-ext/torch_binding.cpp",
11
+ "torch-ext/torch_binding.h",
12
+ ]
13
+
14
+ [kernel.flash_attn]
15
+ backend = "cuda"
16
+ cuda-capabilities = ["8.0", "9.0a"]
17
+ cuda-flags = [
18
+ "-O3",
19
+ "-std=c++17",
20
+ "--ftemplate-backtrace-limit=0", # To debug template code
21
+ "--use_fast_math",
22
+ "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
23
+ "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
24
+ "-DCUTLASS_ENABLE_GDC_FOR_SM90",
25
+ "--expt-relaxed-constexpr",
26
+ "--expt-extended-lambda",
27
+ "--use_fast_math",
28
+ "-DNDEBUG",
29
+ ]
30
+
31
+ src = [
32
+ "flash-attn/cuda_check.h",
33
+ "flash-attn/flash_api.cpp",
34
+ "flash-attn/flash_fwd_combine.cu",
35
+ "flash-attn/flash_fwd_combine_kernel.h",
36
+ "flash-attn/flash_fwd_combine_launch_template.h",
37
+ "flash-attn/flash.h",
38
+ "flash-attn/flash_prepare_scheduler.cu",
39
+ "flash-attn/heuristics.h",
40
+ "flash-attn/seqlen.h",
41
+ "flash-attn/static_switch.h",
42
+ "flash-attn/tile_size.h",
43
+ "flash-attn/utils.h",
44
+ ]
45
+ depends = ["torch", "cutlass_3_9"]
46
+
47
+ [kernel.flash_attn_sm80]
48
+ backend = "cuda"
49
+ cuda-capabilities = ["8.0", "9.0a"]
50
+ cuda-flags = [
51
+ "-O3",
52
+ "-std=c++17",
53
+ "--ftemplate-backtrace-limit=0", # To debug template code
54
+ "--use_fast_math",
55
+ "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
56
+ "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
57
+ "-DCUTLASS_ENABLE_GDC_FOR_SM90",
58
+ "--expt-relaxed-constexpr",
59
+ "--expt-extended-lambda",
60
+ "--use_fast_math",
61
+ "-DNDEBUG",
62
+ ]
63
+ src = [
64
+ "flash-attn/block.h",
65
+ "flash-attn/copy_sm90_bulk_reduce.hpp",
66
+ "flash-attn/epilogue_bwd.hpp",
67
+ "flash-attn/epilogue_fwd.hpp",
68
+ "flash-attn/flash.h",
69
+ "flash-attn/flash_bwd_kernel_sm80.h",
70
+ "flash-attn/flash_bwd_kernel_sm90.h",
71
+ "flash-attn/flash_bwd_launch_template.h",
72
+ "flash-attn/flash_bwd_postprocess_kernel.h",
73
+ "flash-attn/flash_bwd_preprocess_kernel.h",
74
+ "flash-attn/flash_fwd_launch_template.h",
75
+ "flash-attn/flash_fwd_kernel_sm80.h",
76
+ "flash-attn/flash_fwd_kernel_sm90.h",
77
+ "flash-attn/heuristics.h",
78
+ "flash-attn/mainloop_bwd_sm80.hpp",
79
+ "flash-attn/mainloop_fwd_sm80.hpp",
80
+ "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
81
+ "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
82
+ "flash-attn/mask.h",
83
+ "flash-attn/named_barrier.hpp",
84
+ "flash-attn/pack_gqa.h",
85
+ "flash-attn/paged_kv.h",
86
+ "flash-attn/rotary.h",
87
+ "flash-attn/sm90_pipeline_no_cluster.hpp",
88
+ "flash-attn/softmax.h",
89
+ "flash-attn/tile_size.h",
90
+ "flash-attn/tile_scheduler.hpp",
91
+
92
+ "flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu",
93
+ "flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu",
94
+ "flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu",
95
+ "flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu",
96
+ "flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu",
97
+ "flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu",
98
+ "flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu",
99
+ "flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu",
100
+ "flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu",
101
+ "flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu",
102
+ "flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu",
103
+ "flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm80.cu",
104
+ "flash-attn/instantiations/flash_bwd_hdim64_bf16_sm80.cu",
105
+ "flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm80.cu",
106
+ "flash-attn/instantiations/flash_bwd_hdim64_fp16_sm80.cu",
107
+ "flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm80.cu",
108
+ "flash-attn/instantiations/flash_bwd_hdim96_bf16_sm80.cu",
109
+ "flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm80.cu",
110
+ "flash-attn/instantiations/flash_bwd_hdim96_fp16_sm80.cu",
111
+ "flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm80.cu",
112
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu",
113
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu",
114
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu",
115
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu",
116
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm80.cu",
117
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu",
118
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu",
119
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu",
120
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu",
121
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu",
122
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu",
123
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu",
124
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_sm80.cu",
125
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu",
126
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu",
127
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu",
128
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu",
129
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu",
130
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu",
131
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu",
132
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_sm80.cu",
133
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu",
134
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu",
135
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu",
136
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu",
137
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu",
138
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu",
139
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu",
140
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_sm80.cu",
141
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu",
142
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu",
143
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu",
144
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu",
145
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu",
146
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu",
147
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu",
148
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_sm80.cu",
149
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu",
150
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu",
151
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu",
152
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu",
153
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu",
154
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu",
155
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu",
156
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_sm80.cu",
157
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu",
158
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu",
159
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu",
160
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu",
161
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu",
162
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu",
163
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu",
164
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_sm80.cu",
165
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu",
166
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu",
167
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu",
168
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu",
169
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu",
170
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu",
171
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu",
172
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_sm80.cu",
173
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu",
174
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu",
175
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu",
176
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu",
177
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu",
178
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu",
179
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu",
180
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_sm80.cu",
181
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu",
182
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu",
183
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu",
184
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu",
185
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu",
186
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu",
187
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu",
188
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_sm80.cu",
189
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu",
190
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu",
191
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu"
192
+ ]
193
+ include = ["flash-attn"]
194
+ depends = ["torch", "cutlass_3_9"]
195
+
196
+ [kernel.flash_attn_sm90]
197
+ backend = "cuda"
198
+ cuda-capabilities = ["8.0", "9.0a"]
199
+ cuda-flags = [
200
+ "-O3",
201
+ "-std=c++17",
202
+ "--ftemplate-backtrace-limit=0", # To debug template code
203
+ "--use_fast_math",
204
+ "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
205
+ "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
206
+ "-DCUTLASS_ENABLE_GDC_FOR_SM90",
207
+ "--expt-relaxed-constexpr",
208
+ "--expt-extended-lambda",
209
+ "--use_fast_math",
210
+ "-DNDEBUG",
211
+ ]
212
+ src = [
213
+ "flash-attn/block.h",
214
+ "flash-attn/copy_sm90_bulk_reduce.hpp",
215
+ "flash-attn/epilogue_bwd.hpp",
216
+ "flash-attn/epilogue_fwd.hpp",
217
+ "flash-attn/flash.h",
218
+ "flash-attn/flash_bwd_kernel_sm80.h",
219
+ "flash-attn/flash_bwd_kernel_sm90.h",
220
+ "flash-attn/flash_bwd_launch_template.h",
221
+ "flash-attn/flash_bwd_postprocess_kernel.h",
222
+ "flash-attn/flash_bwd_preprocess_kernel.h",
223
+ "flash-attn/flash_fwd_launch_template.h",
224
+ "flash-attn/flash_fwd_kernel_sm80.h",
225
+ "flash-attn/flash_fwd_kernel_sm90.h",
226
+ "flash-attn/heuristics.h",
227
+ "flash-attn/mainloop_bwd_sm80.hpp",
228
+ "flash-attn/mainloop_fwd_sm80.hpp",
229
+ "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
230
+ "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
231
+ "flash-attn/mask.h",
232
+ "flash-attn/named_barrier.hpp",
233
+ "flash-attn/pack_gqa.h",
234
+ "flash-attn/paged_kv.h",
235
+ "flash-attn/rotary.h",
236
+ "flash-attn/sm90_pipeline_no_cluster.hpp",
237
+ "flash-attn/softmax.h",
238
+ "flash-attn/tile_size.h",
239
+ "flash-attn/tile_scheduler.hpp",
240
+
241
+ "flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu",
242
+ "flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu",
243
+ "flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu",
244
+ "flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu",
245
+ "flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu",
246
+ "flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu",
247
+ "flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu",
248
+ "flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu",
249
+ "flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu",
250
+ "flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu",
251
+ "flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu",
252
+ "flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm90.cu",
253
+ "flash-attn/instantiations/flash_bwd_hdim64_bf16_sm90.cu",
254
+ "flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm90.cu",
255
+ "flash-attn/instantiations/flash_bwd_hdim64_fp16_sm90.cu",
256
+ "flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm90.cu",
257
+ "flash-attn/instantiations/flash_bwd_hdim96_bf16_sm90.cu",
258
+ "flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm90.cu",
259
+ "flash-attn/instantiations/flash_bwd_hdim96_fp16_sm90.cu",
260
+ "flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm90.cu",
261
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu",
262
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu",
263
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu",
264
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu",
265
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu",
266
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm90.cu",
267
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu",
268
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu",
269
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu",
270
+ "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu",
271
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu",
272
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu",
273
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu",
274
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu",
275
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu",
276
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_sm90.cu",
277
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu",
278
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu",
279
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu",
280
+ "flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu",
281
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu",
282
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu",
283
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu",
284
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu",
285
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu",
286
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_sm90.cu",
287
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu",
288
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu",
289
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu",
290
+ "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu",
291
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu",
292
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu",
293
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu",
294
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu",
295
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu",
296
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu",
297
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu",
298
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu",
299
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu",
300
+ "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu",
301
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu",
302
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu",
303
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu",
304
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu",
305
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu",
306
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu",
307
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu",
308
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu",
309
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu",
310
+ "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu",
311
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu",
312
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu",
313
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu",
314
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu",
315
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu",
316
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu",
317
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu",
318
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu",
319
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu",
320
+ "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu",
321
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu",
322
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu",
323
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu",
324
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu",
325
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu",
326
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_sm90.cu",
327
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu",
328
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu",
329
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu",
330
+ "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu",
331
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu",
332
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu",
333
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu",
334
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu",
335
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu",
336
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_sm90.cu",
337
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu",
338
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu",
339
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu",
340
+ "flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu",
341
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu",
342
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu",
343
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu",
344
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu",
345
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu",
346
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_sm90.cu",
347
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu",
348
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu",
349
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu",
350
+ "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu",
351
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu",
352
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu",
353
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu",
354
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu",
355
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu",
356
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_sm90.cu",
357
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu",
358
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu",
359
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu",
360
+ "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu",
361
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu",
362
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu",
363
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu",
364
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu",
365
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu",
366
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_sm90.cu",
367
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu",
368
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu",
369
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu",
370
+ "flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu",
371
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu",
372
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu",
373
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu",
374
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu",
375
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu",
376
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_sm90.cu",
377
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu",
378
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu",
379
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu",
380
+ "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu",
381
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu",
382
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu",
383
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu",
384
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu",
385
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu",
386
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu",
387
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu",
388
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu",
389
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu",
390
+ "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu",
391
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu",
392
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu",
393
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu",
394
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu",
395
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu",
396
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu",
397
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu",
398
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu",
399
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu",
400
+ "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu",
401
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu",
402
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu",
403
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu",
404
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu",
405
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu",
406
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu",
407
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu",
408
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu",
409
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu",
410
+ "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu",
411
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu",
412
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu",
413
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu",
414
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu",
415
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu",
416
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu",
417
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu",
418
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu",
419
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu",
420
+ "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu",
421
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu",
422
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu",
423
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu",
424
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu",
425
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu",
426
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_sm90.cu",
427
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu",
428
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu",
429
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu",
430
+ "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu",
431
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu",
432
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu",
433
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu",
434
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu",
435
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu",
436
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_sm90.cu",
437
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu",
438
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu",
439
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu",
440
+ "flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu",
441
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu",
442
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu",
443
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu",
444
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu",
445
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu",
446
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_sm90.cu",
447
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu",
448
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu",
449
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu",
450
+ "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu",
451
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu",
452
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu",
453
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu",
454
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu",
455
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu",
456
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_sm90.cu",
457
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu",
458
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu",
459
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu",
460
+ "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu",
461
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu",
462
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu",
463
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu",
464
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu",
465
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu",
466
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_sm90.cu",
467
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu",
468
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu",
469
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu",
470
+ "flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu",
471
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu",
472
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu",
473
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu",
474
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu",
475
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu",
476
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_sm90.cu",
477
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu",
478
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu",
479
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu",
480
+ "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu",
481
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu",
482
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu",
483
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu",
484
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu",
485
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu",
486
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_sm90.cu",
487
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu",
488
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu",
489
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu",
490
+ "flash-attn/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu",
491
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu",
492
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu",
493
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu",
494
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu",
495
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu",
496
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_sm90.cu",
497
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu",
498
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu",
499
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu",
500
+ "flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu",
501
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu",
502
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu",
503
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu",
504
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu",
505
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu",
506
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_sm90.cu",
507
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu",
508
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu",
509
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu",
510
+ "flash-attn/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu",
511
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu",
512
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu",
513
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu",
514
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu",
515
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu",
516
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu",
517
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu",
518
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu",
519
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu",
520
+ "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu",
521
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu",
522
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu",
523
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu",
524
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu",
525
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu",
526
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu",
527
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu",
528
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu",
529
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu",
530
+ "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu",
531
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu",
532
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu",
533
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu",
534
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu",
535
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu",
536
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu",
537
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu",
538
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu",
539
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu",
540
+ "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu",
541
+ ]
542
+ include = ["flash-attn"]
543
+ depends = ["torch", "cutlass_3_9"]
544
+
545
+ # [kernel.flash_attn_sm100]
546
+ # backend = "cuda"
547
+ # cuda-capabilities = ["8.0", "9.0a", "10.0"]
548
+ # cuda-flags = [
549
+ # "-O3",
550
+ # "-std=c++17",
551
+ # "--ftemplate-backtrace-limit=0", # To debug template code
552
+ # "--use_fast_math",
553
+ # "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
554
+ # "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
555
+ # "-DCUTLASS_ENABLE_GDC_FOR_SM90",
556
+ # "--expt-relaxed-constexpr",
557
+ # "--expt-extended-lambda",
558
+ # "--use_fast_math",
559
+ # "-DNDEBUG",
560
+ # ]
561
+ # src = [
562
+ # "flash-attn/block.h",
563
+ # "flash-attn/copy_sm90_bulk_reduce.hpp",
564
+ # "flash-attn/epilogue_bwd.hpp",
565
+ # "flash-attn/epilogue_fwd.hpp",
566
+ # "flash-attn/flash.h",
567
+ # "flash-attn/flash_bwd_kernel_sm80.h",
568
+ # "flash-attn/flash_bwd_kernel_sm90.h",
569
+ # "flash-attn/flash_bwd_launch_template.h",
570
+ # "flash-attn/flash_bwd_postprocess_kernel.h",
571
+ # "flash-attn/flash_bwd_preprocess_kernel.h",
572
+ # "flash-attn/flash_fwd_launch_template.h",
573
+ # "flash-attn/flash_fwd_kernel_sm80.h",
574
+ # "flash-attn/flash_fwd_kernel_sm90.h",
575
+ # "flash-attn/heuristics.h",
576
+ # "flash-attn/mainloop_bwd_sm80.hpp",
577
+ # "flash-attn/mainloop_fwd_sm80.hpp",
578
+ # "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
579
+ # "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
580
+ # "flash-attn/mask.h",
581
+ # "flash-attn/named_barrier.hpp",
582
+ # "flash-attn/pack_gqa.h",
583
+ # "flash-attn/paged_kv.h",
584
+ # "flash-attn/rotary.h",
585
+ # "flash-attn/sm90_pipeline_no_cluster.hpp",
586
+ # "flash-attn/softmax.h",
587
+ # "flash-attn/tile_size.h",
588
+ # "flash-attn/tile_scheduler.hpp",
589
+ #
590
+ # "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm100.cu",
591
+ # ]
592
+ # include = ["flash-attn"]
593
+ # depends = ["torch", "cutlass_3_9"]
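The [torch] and [kernel.*] tables above are what the kernel-builder tooling compiles into a Torch extension that can be fetched directly from the Kernel Hub. The Python sketch below illustrates how a kernel packaged this way is typically consumed; the repository id kernels-community/flash-attn3, the exported flash_attn_func name, and its signature are assumptions for illustration and are not taken from this commit.

import torch
from kernels import get_kernel  # Hugging Face `kernels` package

# Fetches a prebuilt binary matching the local torch/CUDA/GPU instead of compiling locally.
flash_attn3 = get_kernel("kernels-community/flash-attn3")  # hypothetical repo id

q = torch.randn(1, 4096, 16, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 4096, 16, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 4096, 16, 128, device="cuda", dtype=torch.bfloat16)

out = flash_attn3.flash_attn_func(q, k, v, causal=True)  # exported name/signature assumed, may differ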
flake.lock ADDED
@@ -0,0 +1,168 @@
+ {
+   "nodes": {
+     "flake-compat": {
+       "locked": {
+         "lastModified": 1747046372,
+         "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
+         "owner": "edolstra",
+         "repo": "flake-compat",
+         "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
+         "type": "github"
+       },
+       "original": {
+         "owner": "edolstra",
+         "repo": "flake-compat",
+         "type": "github"
+       }
+     },
+     "flake-compat_2": {
+       "locked": {
+         "lastModified": 1733328505,
+         "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
+         "owner": "edolstra",
+         "repo": "flake-compat",
+         "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
+         "type": "github"
+       },
+       "original": {
+         "owner": "edolstra",
+         "repo": "flake-compat",
+         "type": "github"
+       }
+     },
+     "flake-utils": {
+       "inputs": {
+         "systems": "systems"
+       },
+       "locked": {
+         "lastModified": 1731533236,
+         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+         "type": "github"
+       },
+       "original": {
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "type": "github"
+       }
+     },
+     "flake-utils_2": {
+       "inputs": {
+         "systems": "systems_2"
+       },
+       "locked": {
+         "lastModified": 1731533236,
+         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+         "type": "github"
+       },
+       "original": {
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "type": "github"
+       }
+     },
+     "hf-nix": {
+       "inputs": {
+         "flake-compat": "flake-compat_2",
+         "flake-utils": "flake-utils_2",
+         "nixpkgs": "nixpkgs"
+       },
+       "locked": {
+         "lastModified": 1750234878,
+         "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=",
+         "owner": "huggingface",
+         "repo": "hf-nix",
+         "rev": "c7132f90763d756da3e77da62e01be0a4546dc57",
+         "type": "github"
+       },
+       "original": {
+         "owner": "huggingface",
+         "repo": "hf-nix",
+         "type": "github"
+       }
+     },
+     "kernel-builder": {
+       "inputs": {
+         "flake-compat": "flake-compat",
+         "flake-utils": "flake-utils",
+         "hf-nix": "hf-nix",
+         "nixpkgs": [
+           "kernel-builder",
+           "hf-nix",
+           "nixpkgs"
+         ]
+       },
+       "locked": {
+         "lastModified": 1750275112,
+         "narHash": "sha256-gqAxmLLt0tYvuRYumOZHQgryMeEFdt6j3nEC8B5rT14=",
+         "owner": "huggingface",
+         "repo": "kernel-builder",
+         "rev": "1b63210b2a1fc3cda2e3a579e7aa8f8c8532626f",
+         "type": "github"
+       },
+       "original": {
+         "owner": "huggingface",
+         "repo": "kernel-builder",
+         "type": "github"
+       }
+     },
+     "nixpkgs": {
+       "locked": {
+         "lastModified": 1747820358,
+         "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
+         "owner": "danieldk",
+         "repo": "nixpkgs",
+         "rev": "d3c1681180717528068082103bf323147de6ab0b",
+         "type": "github"
+       },
+       "original": {
+         "owner": "danieldk",
+         "ref": "cudatoolkit-12.9-kernel-builder",
+         "repo": "nixpkgs",
+         "type": "github"
+       }
+     },
+     "root": {
+       "inputs": {
+         "kernel-builder": "kernel-builder"
+       }
+     },
+     "systems": {
+       "locked": {
+         "lastModified": 1681028828,
+         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+         "owner": "nix-systems",
+         "repo": "default",
+         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+         "type": "github"
+       },
+       "original": {
+         "owner": "nix-systems",
+         "repo": "default",
+         "type": "github"
+       }
+     },
+     "systems_2": {
+       "locked": {
+         "lastModified": 1681028828,
+         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+         "owner": "nix-systems",
+         "repo": "default",
+         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+         "type": "github"
+       },
+       "original": {
+         "owner": "nix-systems",
+         "repo": "default",
+         "type": "github"
+       }
+     }
+   },
+   "root": "root",
+   "version": 7
+ }
flake.nix ADDED
@@ -0,0 +1,17 @@
+ {
+   description = "Flake for Hopper Flash Attention kernel";
+
+   inputs = {
+     kernel-builder.url = "github:huggingface/kernel-builder";
+   };
+
+   outputs =
+     {
+       self,
+       kernel-builder,
+     }:
+     kernel-builder.lib.genFlakeOutputs {
+       path = ./.;
+       rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+     };
+ }
flash-attn/block.h ADDED
@@ -0,0 +1,139 @@
+ /******************************************************************************
+  * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+  ******************************************************************************/
+
+ #pragma once
+
+ namespace flash {
+
+ template <class SeqlenInfo_t, int kBlockM, int kBlockN, bool Is_causal, bool Is_local, bool PackGQA=false, bool Split=false>
+ struct BlockMN {
+
+     static
+     CUTLASS_DEVICE
+     cute::tuple<int, int> get_n_block_min_max(
+             SeqlenInfo_t const& seqlen_info,
+             int const m_block, int const bidb, int const split_idx, int const num_splits,
+             int const window_size_left, int const window_size_right,
+             cutlass::FastDivmod const& attention_chunk_divmod,
+             cutlass::FastDivmod const& qhead_per_khead_divmod) {
+
+         int const seqlen_k = seqlen_info.seqlen_k;
+         int const seqlen_q = seqlen_info.seqlen_q;
+         int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
+         if constexpr (Is_causal || Is_local) {
+             int m_idx_max = (m_block + 1) * kBlockM;
+             // TODO: check off-by-1 error
+             if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; }
+             int const n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q;
+             int n_idx_right = !Is_local ? n_idx : n_idx + window_size_right;
+             if (Is_local && attention_chunk_divmod.divisor > 0) {
+                 n_idx_right = std::min(n_idx_right, flash::round_up(attention_chunk_divmod, n_idx));
+             }
+             n_block_max = std::min(n_block_max, cute::ceil_div(n_idx_right, kBlockN));
+         }
+         int n_block_min = 0;
+         if constexpr (Is_local) {
+             int m_idx_min = m_block * kBlockM;
+             if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); }
+             int const n_idx = m_idx_min + seqlen_k - seqlen_q;
+             int n_idx_left = n_idx - window_size_left;
+             if (attention_chunk_divmod.divisor > 0) {
+                 n_idx_left = std::max(n_idx_left, flash::round_down(attention_chunk_divmod, n_idx));
+             }
+             n_block_min = std::max(int(0), n_idx_left / kBlockN);
+         }
+         // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
+         if constexpr (Split) {
+             uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
+             int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
+             int split_idx_actual = split_idx & 0x0000FFFF;
+             int num_splits_actual = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
+             int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits_actual);
+             n_block_min = n_block_min + split_idx_actual * num_n_blocks_per_split;
+             n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max);
+             // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); }
+         }
+         // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
+         return {n_block_min, n_block_max};
+     }
+
+     static
+     CUTLASS_DEVICE
+     cute::tuple<int, int> get_n_block_k_new_min_max(
+             SeqlenInfo_t const& seqlen_info,
+             int const m_block, int const bidb, int const split_idx, int const num_splits,
+             int const window_size_left, int const window_size_right,
+             cutlass::FastDivmod const& attention_chunk_divmod,
+             cutlass::FastDivmod const& qhead_per_khead_divmod) {
+
+         auto [n_block_min, n_block_max] = get_n_block_min_max(
+             seqlen_info, m_block, bidb, split_idx, num_splits,
+             window_size_left, window_size_right, attention_chunk_divmod, qhead_per_khead_divmod);
+         int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0);
+         int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);
+         int const n_block_new_min = idx_k_new_min / kBlockN;
+         int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min;
+         // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);}
+         return {n_block_new_min, n_block_new_max};
+     }
+
+     static
+     CUTLASS_DEVICE
+     cute::tuple<int, int> get_m_block_min_max(
+             SeqlenInfo_t const& seqlen_info,
+             int const n_block, int const bidb,
+             int const window_size_left, int const window_size_right, int const sink_token_length) {
+         // TODO: support attention_chunk
+         int const seqlen_q = seqlen_info.seqlen_q;
+         int const seqlen_k = seqlen_info.seqlen_k;
+         int m_block_max = cute::ceil_div(seqlen_q, kBlockM);
+         if constexpr (Is_local) {
+             if (n_block >= cute::ceil_div(sink_token_length, kBlockN)) {
+                 m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left, kBlockM));
+             }
+         }
+         int m_block_min = 0;
+         if constexpr (Is_causal || Is_local) {
+             m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM);
+         }
+         return {m_block_min, m_block_max};
+     }
+
+     // If we have separate iterations with causal or local masking at the start, where do we stop
+     static
+     CUTLASS_DEVICE
+     int get_n_block_min_causal_local_mask(
+             SeqlenInfo_t const& seqlen_info,
+             int const m_block, int const n_block_min, int const window_size_right,
+             cutlass::FastDivmod const& attention_chunk_divmod,
+             cutlass::FastDivmod const& qhead_per_khead_divmod) {
+         int const m_idx_min = !PackGQA ? m_block * kBlockM : qhead_per_khead_divmod.divide(m_block * kBlockM);
+         int const n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q;
+         int n_idx_right = !Is_local ? n_idx : n_idx + window_size_right;
+         if (Is_local && attention_chunk_divmod.divisor > 0) {
+             n_idx_right = std::min(n_idx_right, flash::round_up(attention_chunk_divmod, n_idx));
+         }
+         return std::max(n_block_min, n_idx_right / kBlockN);
+     }
+
+     // If we have separate iterations with local masking at the end, where do we stop the non-masked iterations
+     static
+     CUTLASS_DEVICE
+     int get_n_block_min_before_local_mask(
+             SeqlenInfo_t const& seqlen_info,
+             int const m_block, int const n_block_min, int const window_size_left,
+             cutlass::FastDivmod const& attention_chunk_divmod,
+             cutlass::FastDivmod const& qhead_per_khead_divmod) {
+         int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1;
+         int const n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q;
+         int n_idx_left = !Is_local ? n_idx : n_idx - window_size_left;
+         if (Is_local && attention_chunk_divmod.divisor > 0) {
+             n_idx_left = std::max(n_idx_left, flash::round_down(attention_chunk_divmod, n_idx));
+         }
+         return !Is_local ? n_block_min : std::max(n_block_min, cute::ceil_div(n_idx_left, kBlockN));
+     }
+
+ };
+
+ } // namespace flash
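BlockMN above is pure index arithmetic: for one query tile (m_block) it decides which key/value tiles (the n_block range) the kernel must visit under causal or local masking, and how that range is carved up when the work is split across CTAs. The Python lines below re-derive the causal case and the split_idx bit-packing purely as a reading aid for the header; they are not part of the kernel and gloss over PackGQA and local attention.

def ceil_div(a, b):
    return (a + b - 1) // b

def n_block_range_causal(seqlen_q, seqlen_k, m_block, kBlockM, kBlockN):
    # Mirrors get_n_block_min_max with Is_causal=True, Is_local=False, PackGQA=False.
    n_block_max = ceil_div(seqlen_k, kBlockN)
    m_idx_max = (m_block + 1) * kBlockM          # one past the last query row of this tile
    n_idx = m_idx_max + seqlen_k - seqlen_q      # bottom-right-aligned causal diagonal
    return 0, min(n_block_max, ceil_div(n_idx, kBlockN))

def unpack_split_idx(split_idx):
    # Mirrors the Split branch: the low 16 bits carry the split index,
    # the high 16 bits carry the dynamically chosen number of splits (0 = use the static value).
    return split_idx & 0xFFFF, (split_idx >> 16) & 0xFFFF

# A 4096x4096 causal problem with 128x128 tiles: query tile 3 only needs KV tiles [0, 4).
print(n_block_range_causal(4096, 4096, 3, 128, 128))  # (0, 4)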
flash-attn/copy_sm90_bulk_reduce.hpp ADDED
@@ -0,0 +1,49 @@
+ /******************************************************************************
+  * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+  ******************************************************************************/
+
+ #pragma once
+
+ #include <cute/arch/copy_sm90_tma.hpp>
+
+ namespace cute
+ {
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ struct SM90_BULK_REDUCE_ADD
+ {
+     CUTE_HOST_DEVICE static void
+     copy(float const* smem_ptr,
+          float * gmem_ptr, int32_t store_bytes)
+     {
+ #if defined(CUTE_ARCH_TMA_SM90_ENABLED)
+         uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
+         asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n"
+             :
+             : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes)
+             : "memory");
+ #else
+         CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED.");
+ #endif
+     }
+
+     CUTE_HOST_DEVICE static void
+     copy(float const* smem_ptr,
+          float * gmem_ptr, int32_t store_bytes, uint64_t cache_hint)
+     {
+ #if defined(CUTE_ARCH_TMA_SM90_ENABLED)
+         uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
+         asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [%0], [%1], %2, %3;\n"
+             :
+             : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes), "l"(cache_hint)
+             : "memory");
+ #else
+         CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED.");
+ #endif
+     }
+ };
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ } // end namespace cute
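The PTX wrapped by SM90_BULK_REDUCE_ADD is a bulk asynchronous store that adds a shared-memory tile into global memory rather than overwriting it, so float32 partial results written at different times or by different thread blocks accumulate in the destination buffer. The NumPy sketch below illustrates only that accumulation semantics; it says nothing about the TMA/bulk-async machinery and is not part of the kernel.

import numpy as np

# A global float32 accumulator and two per-block partial tiles (the "smem" contents).
gmem = np.zeros(256, dtype=np.float32)
partials = [np.full(256, 0.5, dtype=np.float32), np.full(256, 1.5, dtype=np.float32)]

# Each reduce-add behaves like an elementwise += into global memory,
# so contributions sum instead of the last writer winning.
for smem in partials:
    gmem += smem

assert np.allclose(gmem, 2.0)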
flash-attn/cuda_check.h ADDED
@@ -0,0 +1,19 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include <assert.h>
+#include <stdlib.h>
+
+#define CHECK_CUDA(call)                                                                                  \
+    do {                                                                                                  \
+        cudaError_t status_ = call;                                                                       \
+        if (status_ != cudaSuccess) {                                                                     \
+            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
+            exit(1);                                                                                      \
+        }                                                                                                 \
+    } while(0)
+
+#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
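These macros are the error-handling convention used by the host-side launch code. A small standalone usage example (illustrative only; fill_ones is a placeholder kernel, not part of this repository):

#include <cuda_runtime.h>
#include <stdio.h>
#include "cuda_check.h"   // the header above

__global__ void fill_ones(float* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) { out[i] = 1.f; }
}

int main() {
    int const n = 1024;
    float* d_out = nullptr;
    CHECK_CUDA(cudaMalloc(&d_out, n * sizeof(float)));
    fill_ones<<<(n + 255) / 256, 256>>>(d_out, n);
    CHECK_CUDA_KERNEL_LAUNCH();            // wraps cudaGetLastError() right after the launch
    CHECK_CUDA(cudaDeviceSynchronize());   // surfaces asynchronous kernel failures
    CHECK_CUDA(cudaFree(d_out));
    return 0;
}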
flash-attn/epilogue_bwd.hpp ADDED
@@ -0,0 +1,533 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cutlass/cutlass.h"
8
+ #include "cutlass/barrier.h"
9
+ #include "cute/tensor.hpp"
10
+
11
+ #include "cutlass/gemm/collective/builders/sm90_common.inl"
12
+
13
+ #include "seqlen.h"
14
+ #include "named_barrier.hpp"
15
+ #include "utils.h"
16
+
17
+ namespace flash {
18
+
19
+ using namespace cute;
20
+
21
+ template <class TileShape_MNK_, class Element_, class ArchTag_,
22
+ int NumEpilogueThreads_, bool Varlen_, bool dKV_swapAB_, int AtomLayoutKdKV=1>
23
+ struct CollectiveEpilogueBwd {
24
+
25
+ using TileShape_MNK = TileShape_MNK_;
26
+ using Element = Element_;
27
+ using ArchTag = ArchTag_;
28
+ static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
29
+ static constexpr bool Varlen = Varlen_;
30
+ static constexpr bool dKV_swapAB = dKV_swapAB_;
31
+ static constexpr bool Use_TMA = !Varlen && ArchTag::kMinComputeCapability >= 90;
32
+
33
+ static_assert(ArchTag::kMinComputeCapability >= 80);
34
+
35
+ using GmemTiledCopydKVTMA = cute::SM90_TMA_STORE;
36
+
37
+ // These are for storing the output tensor without TMA (e.g., for setting output to zero)
38
+ static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
39
+ static_assert(get<2>(TileShape_MNK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
40
+ static constexpr int kHeadDim = get<2>(TileShape_MNK{});
41
+ static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, NumEpilogueThreads);
42
+ static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
43
+ using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
44
+ Stride<Int<kGmemThreadsPerRow>, _1>>;
45
+ using GmemTiledCopydKV = decltype(
46
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
47
+ GmemLayoutAtom{},
48
+ Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per store
49
+
50
+ using SmemLayoutAtomdKVTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
51
+ // TODO: do we have to change this if dKV_swapAB is true?
52
+ decltype(cute::get<1>(TileShape_MNK{})), Int<CUTE_STATIC_V(cute::get<2>(TileShape_MNK{})) / AtomLayoutKdKV>>());
53
+ using SmemLayoutdKVTMA = decltype(tile_to_shape(SmemLayoutAtomdKVTMA{}, select<1, 2>(TileShape_MNK{})));
54
+ using SmemLayoutdKVtTMA =
55
+ decltype(cute::composition(SmemLayoutdKVTMA{},
56
+ make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
57
+ make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));
58
+
59
+ // If we don't use TMA
60
+ static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : (kHeadDim % 32 == 0 ? 32 : 16);
61
+ static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
62
+ using SmemLayoutAtomdKVSTG =
63
+ decltype(composition(Swizzle<kSwizzle, 3, 3>{},
64
+ Layout<Shape<Int<8>, Int<kBlockKSmem>>,
65
+ Stride<Int<kBlockKSmem>, _1>>{}));
66
+
67
+ using SmemLayoutAtomdKV = std::conditional_t<Use_TMA, SmemLayoutAtomdKVTMA, SmemLayoutAtomdKVSTG>;
68
+ using SmemLayoutdKV = decltype(tile_to_shape(SmemLayoutAtomdKV{}, select<1, 2>(TileShape_MNK{})));
69
+ using SmemLayoutdKVt =
70
+ decltype(cute::composition(SmemLayoutdKV{},
71
+ make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
72
+ make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));
73
+
74
+ using SmemCopyAtomdKV = Copy_Atom<
75
+ std::conditional_t<
76
+ ArchTag::kMinComputeCapability >= 90,
77
+ std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
78
+ AutoVectorizingCopyWithAssumedAlignment<128>
79
+ >,
80
+ Element>;
81
+
82
+ static constexpr size_t SmemAlignmentdKV = ArchTag::kMinComputeCapability >= 90 ? cutlass::detail::alignment_for_swizzle(SmemLayoutdKV{}) : 128;
83
+ static_assert(SmemAlignmentdKV >= 128, "Require at least 128B alignment");
84
+
85
+ struct TensorStorage : cute::aligned_struct<SmemAlignmentdKV> {
86
+ cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dk;
87
+ cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dv;
88
+ };
89
+
90
+ using ShapedKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_k, d, head, batch)
91
+ using StridedKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
92
+
93
+ using TMA_dKV = std::conditional_t<
94
+ Use_TMA,
95
+ decltype(make_tma_copy(
96
+ GmemTiledCopydKVTMA{},
97
+ make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapedKV{}, StridedKV{}),
98
+ SmemLayoutdKVTMA{},
99
+ select<1, 2>(TileShape_MNK{}),
100
+ _1{})), // no mcast for dKV
101
+ std::nullptr_t
102
+ >;
103
+
104
+ // Host side kernel arguments
105
+ struct Arguments {
106
+ Element* ptr_dK;
107
+ ShapedKV const shape_dK;
108
+ StridedKV const stride_dK;
109
+ Element* ptr_dV;
110
+ ShapedKV const shape_dV;
111
+ StridedKV const stride_dV;
112
+ int const num_heads_q;
113
+ int* dk_semaphore;
114
+ int* dv_semaphore;
115
+ int const* cu_seqlens;
116
+ int const* seqused;
117
+ };
118
+
119
+ // Device side kernel params
120
+ struct Params {
121
+ Element* ptr_dK;
122
+ ShapedKV const shape_dK;
123
+ StridedKV const stride_dK;
124
+ Element* ptr_dV;
125
+ ShapedKV const shape_dV;
126
+ StridedKV const stride_dV;
127
+ TMA_dKV tma_store_dK, tma_store_dV;
128
+ int const* cu_seqlens = nullptr;
129
+ int const* seqused = nullptr;
130
+ };
131
+
132
+ static Params
133
+ to_underlying_arguments(Arguments const& args) {
134
+ Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK);
135
+ Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dV, args.stride_dV);
136
+ TMA_dKV tma_store_dK = [&] {
137
+ if constexpr (Use_TMA) {
138
+ return make_tma_copy(GmemTiledCopydKVTMA{}, mdK, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV
139
+ } else {
140
+ return nullptr;
141
+ }
142
+ }();
143
+ TMA_dKV tma_store_dV = [&] {
144
+ if constexpr (Use_TMA) {
145
+ return make_tma_copy(GmemTiledCopydKVTMA{}, mdV, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV
146
+ } else {
147
+ return nullptr;
148
+ }
149
+ }();
150
+ return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.shape_dV, args.stride_dV,
151
+ tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused};
152
+ }
153
+
154
+ /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
155
+ CUTLASS_DEVICE
156
+ static void prefetch_tma_descriptors(Params const& params) {
157
+ if constexpr (Use_TMA) {
158
+ cute::prefetch_tma_descriptor(params.tma_store_dK.get_tma_descriptor());
159
+ cute::prefetch_tma_descriptor(params.tma_store_dV.get_tma_descriptor());
160
+ }
161
+ }
162
+
163
+ template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
164
+ CUTLASS_DEVICE void
165
+ store(Params const& params,
166
+ FrgTensorO const& tdKrdK,
167
+ FrgTensorO const& tdVrdV,
168
+ SharedStorage& shared_storage,
169
+ TiledMma tiled_mma,
170
+ int thread_idx,
171
+ cute::tuple<int32_t, int32_t, int32_t> const& block_coord
172
+ ) {
173
+
174
+ auto [n_block, bidh, bidb] = block_coord;
175
+ Tensor sdK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKV{}));
176
+ Tensor sdV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKV{}));
177
+ Tensor sdKt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKVt{}));
178
+ Tensor sdVt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKVt{}));
179
+ auto smem_tiled_copy_dKV = make_tiled_copy_C(SmemCopyAtomdKV{}, tiled_mma);
180
+ auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(thread_idx);
181
+
182
+ Tensor tdVrdV_out = make_tensor_like<Element>(tdVrdV);
183
+ flash::convert_type_out(tdVrdV, tdVrdV_out);
184
+ Tensor tdKrdK_out = make_tensor_like<Element>(tdKrdK);
185
+ flash::convert_type_out(tdKrdK, tdKrdK_out);
186
+ Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N)
187
+ Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N)
188
+ // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_dKV); print(sdK); printf("\n"); print(sdKt); printf("\n"); }
189
+ Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdK, sdKt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
190
+ Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdV, sdVt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
191
+
192
+ // Make sure all WGs have finished reading K and V
193
+ flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
194
+ cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
195
+ cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
196
+ if constexpr (Use_TMA) {
197
+ cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
198
+ cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
199
+ cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
200
+
201
+ Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK);
202
+ Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dV);
203
+ Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
204
+ Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
205
+ auto block_tma_dK = params.tma_store_dK.get_slice(_0{});
206
+ auto block_tma_dV = params.tma_store_dV.get_slice(_0{});
207
+ Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K)
208
+ Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K)
209
+ Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_M, TMA_K)
210
+ Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
211
+ int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
212
+ if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
213
+ cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
214
+ cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
215
+ if (cute::elect_one_sync()) {
216
+ cute::copy(params.tma_store_dV, tdVsdV, tdVgdV);
217
+ cute::copy(params.tma_store_dK, tdKsdK, tdKgdK);
218
+ tma_store_arrive();
219
+ }
220
+ }
221
+ tma_store_wait<0>();
222
+ // // Tell warp 0 that smem_k and smem_v are ready
223
+ // cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
224
+
225
+ } else {
226
+ flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
227
+ static constexpr int kBlockN = get<1>(TileShape_MNK{});
228
+ flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
229
+ bool const is_varlen = Varlen && params.cu_seqlens;
230
+ Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
231
+ Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
232
+ Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
233
+ Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
234
+
235
+ GmemTiledCopydKV gmem_tiled_copy_dKV;
236
+ auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
237
+ Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
238
+ Tensor tdKVsdV = gmem_thr_copy_dKV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
239
+ Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
240
+ Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK); // (TMA, TMA_M, TMA_K)
241
+ Tensor tdKVrdV = make_fragment_like(tdKVgdV);
242
+ Tensor tdKVrdK = make_fragment_like(tdKVgdK);
243
+ Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k)
244
+ // Repeat the partitioning with identity layouts
245
+ Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
246
+ Tensor tdKVpdV = make_tensor<bool>(make_shape(size<2>(tdKVgdV)));
247
+ Tensor tdKVpdK = make_tensor<bool>(make_shape(size<2>(tdKVgdK)));
248
+ #pragma unroll
249
+ for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); }
250
+ #pragma unroll
251
+ for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
252
+ // Need to check OOB when reading from smem if kBlockN isn't evenly tiled
253
+ static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;
254
+ flash::copy</*Is_even_MN=*/EvenN, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
255
+ gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdV, kBlockN);
256
+ flash::copy</*Is_even_MN=*/EvenN, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
257
+ gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdK, kBlockN);
258
+ // // Tell warp 0 that smem_k and smem_v are ready
259
+ // cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_k/v
260
+ // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
261
+ // Construct identity layout for gdKV
262
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
263
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
264
+ gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)
265
+ );
266
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
267
+ gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdK, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)
268
+ );
269
+ }
270
+ }
271
+
272
+ CUTLASS_DEVICE void
273
+ store_tail() {
274
+ // if constexpr (Use_TMA) { tma_store_wait<0>(); }
275
+ }
276
+
277
+ // Write 0 to dK and dV
278
+ CUTLASS_DEVICE void
279
+ store_zero(
280
+ Params const& params,
281
+ int thread_idx,
282
+ cute::tuple<int32_t, int32_t, int32_t> const& block_coord
283
+ ) {
284
+ static constexpr int kBlockN = get<1>(TileShape_MNK{});
285
+ auto [n_block, bidh, bidb] = block_coord;
286
+ flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
287
+ bool const is_varlen = Varlen && params.cu_seqlens;
288
+ Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
289
+ Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
290
+ Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
291
+ Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
292
+
293
+ GmemTiledCopydKV gmem_tiled_copy_dKV;
294
+ auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
295
+ Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
296
+ Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
297
+ Tensor tdKVrdKV = make_fragment_like(tdKVgdK);
298
+ clear(tdKVrdKV);
299
+ // Construct identity layout for gdKV
300
+ Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k)
301
+ // Repeat the partitioning with identity layouts
302
+ Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
303
+ Tensor tdKVpdK = make_tensor<bool>(make_shape(size<2>(tdKVgdK)));
304
+ Tensor tdKVpdV = make_tensor<bool>(make_shape(size<2>(tdKVgdV)));
305
+ #pragma unroll
306
+ for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
307
+ #pragma unroll
308
+ for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); }
309
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
310
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
311
+ gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdK, seqlen_info.seqlen - n_block * kBlockN
312
+ );
313
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
314
+ gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdV, seqlen_info.seqlen - n_block * kBlockN
315
+ );
316
+ }
317
+
318
+ };
319
+
320
+ template <class TileShape_MNK_, class ElementAccum, class ArchTag_,
321
+ int NumEpilogueThreads_, bool Varlen_, bool Deterministic>
322
+ struct CollectiveEpilogueBwdGQA {
323
+
324
+ using TileShape_MNK = TileShape_MNK_;
325
+ using Element = ElementAccum;
326
+ using ArchTag = ArchTag_;
327
+ static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
328
+ static constexpr bool Varlen = Varlen_;
329
+ static constexpr bool Use_TMA = ArchTag::kMinComputeCapability >= 90;
330
+
331
+ static_assert(ArchTag::kMinComputeCapability >= 80);
332
+
333
+ static constexpr int kBlockN = get<1>(TileShape_MNK{});
334
+ static constexpr int kHeadDim = get<2>(TileShape_MNK{});
335
+ static_assert(NumEpilogueThreads % cutlass::NumThreadsPerWarp == 0, "NumEpilogueThreads must be a multiple of NumThreadsPerWarp");
336
+ static constexpr int NumWarpGroups = NumEpilogueThreads / cutlass::NumThreadsPerWarpGroup;
337
+ // Thread layout, 256 or 384 threads per row
338
+ // We split into NumWarpGroups so that we can use the same postprocessing kernel as dQ
339
+ using R2SLayoutAtomdKVaccum = Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumWarpGroups>>>;
340
+ using R2STiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdKVaccum{},
341
+ Layout<Shape < _4>>{})); // Val layout, 4 vals per store
342
+ // For Sm80
343
+ using R2GLayoutAtomdKVaccum = Layout<Shape<Int<NumEpilogueThreads>>>;
344
+ using R2GTiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2GLayoutAtomdKVaccum{},
345
+ Layout<Shape < _1>>{})); // Val layout, 1 vals per store
346
+
347
+ using SmemLayoutdKVaccum = Layout<Shape<Int<kBlockN * kHeadDim / NumWarpGroups>, Int<NumWarpGroups>>>;
348
+ using SmemLayoutdKVaccumFlat = Layout<Shape<Int<kBlockN * kHeadDim>>>;
349
+
350
+ // Strangely without this SmemAlignment, the total smem for hdim 128 (80 x 128) is 228KB even though we
351
+ // only need 227KB. We use the same alignment as the non-GQA epilogue to avoid this issue.
352
+ static constexpr int SmemAlignment = kHeadDim % 64 == 0 ? 1024 : (kHeadDim % 32 == 0 ? 512 : 256);
353
+ struct TensorStorageTMA : cute::aligned_struct<SmemAlignment> {
354
+ cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdKVaccum>, SmemAlignment> smem_dkv;
355
+ };
356
+ struct TensorStorageSTG {
357
+ cute::array<ElementAccum, 0> smem_dkv;
358
+ };
359
+ using TensorStorage = std::conditional_t<Use_TMA, TensorStorageTMA, TensorStorageSTG>;
360
+
361
+ using ShapedKV = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_k_rounded * d, head, batch)
362
+ using StridedKV = cute::Stride<_1, int64_t, int64_t>;
363
+
364
+ // Host side kernel arguments
365
+ struct Arguments {
366
+ ElementAccum* ptr_dKaccum;
367
+ ShapedKV const shape_dKaccum;
368
+ StridedKV const stride_dKaccum;
369
+ ElementAccum* ptr_dVaccum;
370
+ ShapedKV const shape_dVaccum;
371
+ StridedKV const stride_dVaccum;
372
+ int num_heads_q;
373
+ int* dk_semaphore;
374
+ int* dv_semaphore;
375
+ int const* cu_seqlens;
376
+ int const* seqused;
377
+ };
378
+
379
+ // Device side kernel params
380
+ struct Params {
381
+ ElementAccum* ptr_dKaccum;
382
+ ShapedKV const shape_dKaccum;
383
+ StridedKV const stride_dKaccum;
384
+ ElementAccum* ptr_dVaccum;
385
+ ShapedKV const shape_dVaccum;
386
+ StridedKV const stride_dVaccum;
387
+ cutlass::FastDivmod qhead_per_khead_divmod;
388
+ int* dk_semaphore;
389
+ int* dv_semaphore;
390
+ int const* cu_seqlens = nullptr;
391
+ int const* seqused = nullptr;
392
+ };
393
+
394
+ static Params
395
+ to_underlying_arguments(Arguments const& args) {
396
+ if constexpr (Deterministic) {
397
+ assert(args.dk_semaphore != nullptr);
398
+ assert(args.dv_semaphore != nullptr);
399
+ }
400
+ return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.shape_dVaccum, args.stride_dVaccum,
401
+ cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))),
402
+ args.dk_semaphore, args.dv_semaphore,
403
+ args.cu_seqlens, args.seqused};
404
+ }
405
+
406
+ /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
407
+ CUTLASS_DEVICE
408
+ static void prefetch_tma_descriptors(Params const& params) {
409
+ }
410
+
411
+ template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
412
+ CUTLASS_DEVICE void
413
+ store(Params const& params,
414
+ FrgTensorO const& tdKrdK,
415
+ FrgTensorO const& tdVrdV,
416
+ SharedStorage& shared_storage,
417
+ TiledMma tiled_mma,
418
+ int thread_idx,
419
+ cute::tuple<int32_t, int32_t, int32_t> const& block_coord
420
+ ) {
421
+
422
+ auto [n_block, bidh, bidb] = block_coord;
423
+ int bidh_idx_in_group;
424
+ int bidh_kv = params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh);
425
+ Tensor sdKV = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccum{});
426
+ Tensor sdKV_flat = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccumFlat{});
427
+ static constexpr int dKV_TMA_num_bytes = CUTE_STATIC_V(size(sdKV_flat)) * sizeof(ElementAccum);
428
+
429
+ flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused};
430
+ bool const is_varlen = Varlen && params.cu_seqlens;
431
+ Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
432
+ Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dVaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
433
+ Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block)); // (M * K)
434
+ Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block)); // (M * K)
435
+
436
+ R2STiledCopydKVaccum r2s_tiled_copy_dKVaccum;
437
+ auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx);
438
+ Tensor tdKVsdKVaccum = r2s_thr_copy_dKVaccum.partition_D(sdKV);
439
+
440
+ // Only used if !Use_TMA
441
+ R2GTiledCopydKVaccum r2g_tiled_copy_dKVaccum;
442
+ auto r2g_thr_copy_dKVaccum = r2g_tiled_copy_dKVaccum.get_thread_slice(thread_idx);
443
+
444
+ // Make sure all WGs have finished reading K and V, otherwise we get racy dQ
445
+ // because smem_q could be changed.
446
+ flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
447
+ if constexpr (Use_TMA) {
448
+ Tensor taccdKVrdV = r2s_thr_copy_dKVaccum.retile_S(tdVrdV); // ((Atom,AtomNum), MMA_M, MMA_N)
449
+ cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum);
450
+ }
451
+
452
+ // int const num_batch = params.num_batch;
453
+ int const num_batch = get<2>(params.shape_dKaccum);
454
+ int const num_head_kv = get<1>(params.shape_dKaccum);
455
+ int *lock_ptr = !Deterministic ? nullptr : params.dv_semaphore + bidb * num_head_kv + bidh_kv;
456
+ using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;
457
+
458
+ // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}
459
+
460
+ if constexpr (Deterministic) {
461
+ Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
462
+ }
463
+ // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore);}
464
+ if constexpr (Use_TMA) {
465
+ cutlass::arch::fence_view_async_shared();
466
+ cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
467
+ if (thread_idx == 0) {
468
+ SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdVaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
469
+ tma_store_arrive();
470
+ tma_store_wait<0>();
471
+ }
472
+ } else {
473
+ Tensor tdVrdV_atomic = r2g_thr_copy_dKVaccum.retile_S(tdVrdV);
474
+ Tensor tdVgdV_atomic = r2g_thr_copy_dKVaccum.partition_D(gdVaccum);
475
+ static_assert(CUTE_STATIC_V(size(tdVrdV_atomic)) == CUTE_STATIC_V(size(tdVgdV_atomic)));
476
+ #pragma unroll
477
+ for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdV_atomic(i), tdVrdV_atomic(i)); }
478
+ }
479
+ if constexpr (Deterministic) {
480
+ Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
481
+ }
482
+
483
+ if constexpr (Use_TMA) {
484
+ cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
485
+ Tensor taccdKVrdK = r2s_thr_copy_dKVaccum.retile_S(tdKrdK); // ((Atom,AtomNum), MMA_M, MMA_N)
486
+ cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdK, tdKVsdKVaccum);
487
+ }
488
+ lock_ptr = !Deterministic ? nullptr : params.dk_semaphore + bidb * num_head_kv + bidh_kv;
489
+ // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}
490
+
491
+ if constexpr (Deterministic) {
492
+ Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
493
+ }
494
+ // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore);}
495
+ if constexpr (Use_TMA) {
496
+ cutlass::arch::fence_view_async_shared();
497
+ cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
498
+ if (thread_idx == 0) {
499
+ SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdKaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
500
+ tma_store_arrive();
501
+ tma_store_wait<0>();
502
+ }
503
+ } else {
504
+ Tensor tdKrdK_atomic = r2g_thr_copy_dKVaccum.retile_S(tdKrdK);
505
+ Tensor tdKgdK_atomic = r2g_thr_copy_dKVaccum.partition_D(gdKaccum);
506
+ static_assert(CUTE_STATIC_V(size(tdKrdK_atomic)) == CUTE_STATIC_V(size(tdKgdK_atomic)));
507
+ #pragma unroll
508
+ for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdK_atomic(i), tdKrdK_atomic(i)); }
509
+ }
510
+ if constexpr (Deterministic) {
511
+ Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
512
+ }
513
+ // // Tell warp 0 that smem_k and smem_v are ready
514
+ // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
515
+ }
516
+
517
+ CUTLASS_DEVICE void
518
+ store_tail() {
519
+ }
520
+
521
+ // Write 0 to dK and dV
522
+ CUTLASS_DEVICE void
523
+ store_zero(
524
+ Params const& params,
525
+ int thread_idx,
526
+ cute::tuple<int32_t, int32_t, int32_t> const& block_coord
527
+ ) {
528
+ // Don't need to do anything since dKaccum and dVaccum are already zero-initialized
529
+ }
530
+
531
+ };
532
+
533
+ } // namespace flash
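When Deterministic is set, the GQA epilogue above serializes dK/dV accumulation across query-head groups with per-(batch, kv-head) semaphores, so the floating-point additions always happen in the same order across runs. A much-simplified CUDA sketch of that ordering idea, for illustration only (the real code uses cutlass::GenericBarrier::wait_eq / arrive_inc and register fragments rather than a raw spin loop; names here are hypothetical):

__device__ void ordered_accumulate(float* dkv_accum, float const* contrib,
                                   int n_elems, int* semaphore, int group_idx) {
    if (threadIdx.x == 0) {
        // Wait until every group with a smaller index has already accumulated.
        while (atomicAdd(semaphore, 0) != group_idx) { /* spin */ }
    }
    __syncthreads();
    for (int i = threadIdx.x; i < n_elems; i += blockDim.x) {
        atomicAdd(&dkv_accum[i], contrib[i]);
    }
    __threadfence();   // make the additions visible before releasing the next group
    __syncthreads();
    if (threadIdx.x == 0) { atomicAdd(semaphore, 1); }   // let group_idx + 1 proceed
}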
flash-attn/epilogue_fwd.hpp ADDED
@@ -0,0 +1,484 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <cutlass/cutlass.h>
8
+ #include <cutlass/fast_math.h> // For FastDivMod
9
+ #include "cute/tensor.hpp"
10
+
11
+ #include "cutlass/gemm/collective/builders/sm90_common.inl"
12
+ #include "cutlass/epilogue/collective/builders/sm90_common.inl"
13
+
14
+ #include "seqlen.h"
15
+ #include "named_barrier.hpp"
16
+ #include "pack_gqa.h"
17
+ #include "utils.h"
18
+
19
+ namespace flash {
20
+
21
+ using namespace cute;
22
+
23
+ template <class TileShape_MNK_PV_, class ClusterShape_, class Element_, class ArchTag_,
24
+ int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool Split_, bool FP8PermuteCol=false>
25
+ struct CollectiveEpilogueFwd {
26
+
27
+ using TileShape_MNK_PV = TileShape_MNK_PV_;
28
+ using ClusterShape = ClusterShape_;
29
+ using Element = Element_;
30
+ using ElementPartial = float;
31
+ using ArchTag = ArchTag_;
32
+ static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
33
+ static constexpr bool Varlen = Varlen_;
34
+ static constexpr bool PackGQA = PackGQA_;
35
+ static constexpr bool Split = Split_;
36
+ static constexpr bool Use_smem = !(Split && !Varlen);
37
+ static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA;
38
+
39
+ static_assert(ArchTag::kMinComputeCapability >= 80);
40
+ static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1);
41
+ static_assert(sizeof(Element) <= 2);
42
+
43
+ static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
44
+ static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{});
45
+
46
+ static constexpr bool LargeHeadDimV = kHeadDimV > 256;
47
+
48
+ using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
49
+
50
+ // These are for storing the output tensor without TMA (e.g., for setting output to zero)
51
+ static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element);
52
+ static_assert(kHeadDimV % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore");
53
+ // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements
54
+ // in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times
55
+ // we need to call divmod.
56
+ static constexpr int kBytePerRow = kHeadDimV * sizeof(Element);
57
+ static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
58
+ static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore;
59
+ // If PackGQA, we split the work of compute O_ptr among threads in the same row, so we need this to within a warp
60
+ static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0);
61
+ static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
62
+ using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
63
+ Stride<Int<kGmemThreadsPerRow>, _1>>;
64
+ static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumEpilogueThreads / kGmemThreadsPerRow");
65
+ using GmemTiledCopyO = decltype(
66
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
67
+ GmemLayoutAtom{},
68
+ Layout<Shape<_1, Int<kGmemElemsPerStore>>>{})); // Val layout, 8 or 16 vals per store
69
+
70
+ using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
71
+ decltype(cute::get<0>(TileShape_MNK_PV{})), decltype(cute::get<1>(TileShape_MNK_PV{}))>());
72
+ using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 1>(TileShape_MNK_PV{})));
73
+ static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1));
74
+ static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4);
75
+ using SmemLayoutAtomO = decltype(
76
+ composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},
77
+ Layout<Shape<_8, Int<kBlockKGmem>>,
78
+ Stride<Int<kBlockKGmem>, _1>>{}));
79
+ using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_MNK_PV{})));
80
+ using SmemLayoutO = std::conditional_t<ArchTag::kMinComputeCapability >= 90, SmemLayoutOTMA, SmemLayoutOSTS>;
81
+
82
+ using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch, num_splits)
83
+ using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
84
+ using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen_q, head, batch, num_splits)
85
+ // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits)
86
+ using ShapeOPacked = std::conditional_t<!PackGQA, ShapeO, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t, int32_t>>;
87
+ using StrideOPacked = std::conditional_t<!PackGQA, StrideO, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t, int64_t>>;
88
+ // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits)
89
+ using ShapeLSEPacked = std::conditional_t<!PackGQA, cute::Shape<int32_t, int32_t, int32_t, int32_t>, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;
90
+ using StrideLSEPacked = std::conditional_t<!PackGQA, StrideLSE, cute::Stride<cute::Stride<int64_t, _1>, int64_t, int64_t, int64_t>>;
91
+
92
+ using CopyOpR2S = std::conditional_t<
93
+ ArchTag::kMinComputeCapability >= 90,
94
+ // cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16)
95
+ decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<StrideO, Element>()),
96
+ AutoVectorizingCopyWithAssumedAlignment<128>
97
+ >;
98
+ using SmemCopyAtomO = Copy_Atom<CopyOpR2S, Element>;
99
+
100
+ // static constexpr size_t SmemAlignmentO = cutlass::detail::alignment_for_swizzle(SmemLayoutO{});
101
+ // static_assert(SmemAlignmentO >= 128, "Require at least 128B alignment");
102
+ // struct TensorStorage : cute::aligned_struct<SmemAlignmentO> {
103
+ // cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0, SmemAlignmentO> smem_o;
104
+ // };
105
+ struct TensorStorage : cute::aligned_struct<128> {
106
+ cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0> smem_o;
107
+ };
108
+
109
+ using TMA_O = std::conditional_t<
110
+ Use_TMA_O,
111
+ decltype(make_tma_copy(
112
+ GmemTiledCopyOTMA{},
113
+ make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeO{}, StrideO{}),
114
+ SmemLayoutOTMA{},
115
+ select<0, 1>(TileShape_MNK_PV{}),
116
+ _1{})), // no mcast for O
117
+ std::nullptr_t
118
+ >;
119
+
120
+ // Host side kernel arguments
121
+ struct Arguments {
122
+ Element* ptr_O;
123
+ ShapeO const shape_O;
124
+ StrideO const stride_O;
125
+ ElementPartial* ptr_O_partial;
126
+ StrideO const stride_O_partial;
127
+ float* ptr_LSE;
128
+ StrideLSE const stride_LSE;
129
+ float* ptr_LSE_partial;
130
+ StrideLSE const stride_LSE_partial;
131
+ int32_t const nheads_kv;
132
+ int const* cu_seqlens = nullptr;
133
+ int const* seqused = nullptr;
134
+ };
135
+
136
+ // Device side kernel params
137
+ struct Params {
138
+ Element* ptr_O;
139
+ ShapeO const shape_O;
140
+ StrideO const stride_O;
141
+ ShapeOPacked const shape_O_packed;
142
+ StrideOPacked const stride_O_packed;
143
+ ElementPartial* ptr_O_partial;
144
+ StrideO const stride_O_partial;
145
+ StrideOPacked const stride_O_partial_packed;
146
+ float* ptr_LSE;
147
+ StrideLSE const stride_LSE;
148
+ ShapeLSEPacked const shape_LSE_packed;
149
+ StrideLSEPacked const stride_LSE_packed;
150
+ float* ptr_LSE_partial;
151
+ StrideLSE const stride_LSE_partial;
152
+ StrideLSEPacked const stride_LSE_partial_packed;
153
+ cutlass::FastDivmod qhead_per_khead_divmod;
154
+ TMA_O tma_store_O;
155
+ int const* cu_seqlens = nullptr;
156
+ int const* seqused = nullptr;
157
+ };
158
+
159
+ static Params
160
+ to_underlying_arguments(Arguments const& args) {
161
+ Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
162
+ TMA_O tma_store_O = [&]{
163
+ if constexpr (Use_TMA_O) {
164
+ return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast
165
+ } else {
166
+ return nullptr;
167
+ }
168
+ }();
169
+ // If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits)
170
+ int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv);
171
+ auto const shape_O_packed = cute::conditional_return<!PackGQA>(
172
+ args.shape_O,
173
+ make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
174
+ );
175
+ auto const stride_O_packed = cute::conditional_return<!PackGQA>(
176
+ args.stride_O,
177
+ make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O))
178
+ );
179
+ auto const stride_O_partial_packed = cute::conditional_return<!PackGQA>(
180
+ args.stride_O_partial,
181
+ make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial))
182
+ );
183
+ // If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits)
184
+ auto const shape_LSE_packed = cute::conditional_return<!PackGQA>(
185
+ select<0, 2, 3, 4>(args.shape_O),
186
+ make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
187
+ );
188
+ auto const stride_LSE_packed = cute::conditional_return<!PackGQA>(
189
+ args.stride_LSE,
190
+ make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE))
191
+ );
192
+ auto const stride_LSE_partial_packed = cute::conditional_return<!PackGQA>(
193
+ args.stride_LSE_partial,
194
+ make_stride(make_stride(get<1>(args.stride_LSE_partial), get<0>(args.stride_LSE_partial)), get<1>(args.stride_LSE_partial) * qhead_per_khead, get<2>(args.stride_LSE_partial), get<3>(args.stride_LSE_partial))
195
+ );
196
+ return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed,
197
+ args.ptr_O_partial, args.stride_O_partial, stride_O_partial_packed,
198
+ args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed,
199
+ args.ptr_LSE_partial, args.stride_LSE_partial, stride_LSE_partial_packed,
200
+ cutlass::FastDivmod(qhead_per_khead),
201
+ tma_store_O, args.cu_seqlens, args.seqused};
202
+ }
203
+
204
+ /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
205
+ CUTLASS_DEVICE
206
+ static void prefetch_tma_descriptors(Params const& params) {
207
+ if constexpr (Use_TMA_O) {
208
+ cute::prefetch_tma_descriptor(params.tma_store_O.get_tma_descriptor());
209
+ }
210
+ }
211
+
212
+ template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
213
+ CUTLASS_DEVICE void
214
+ store(Params const& params,
215
+ FrgTensorO& tOrO,
216
+ FrgTensorLSE const& lse,
217
+ SharedStorage& shared_storage,
218
+ TiledMma tiled_mma,
219
+ int thread_idx,
220
+ cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
221
+ ) {
222
+
223
+ auto [m_block, bidh, bidb, split_idx] = block_coord;
224
+ int num_splits = get<4>(params.shape_O_packed);
225
+ if constexpr (Split && Varlen) {
226
+ uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
227
+ int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
228
+ num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
229
+ split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx
230
+ }
231
+ bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);
232
+
233
+ Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{});
234
+ // Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO);
235
+
236
+ static constexpr bool NeedFP8Permute = FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4);
237
+ // If we will possibly need tOrO in FP32, we'd want to permute tOrO before type conversion.
238
+ // Otherwise we can permute after conversion.
239
+ if constexpr (NeedFP8Permute && Split) { flash::permute_output_fp8_Vcolmajor(tOrO); }
240
+ Tensor tOrO_out = make_tensor_like<Element>(tOrO);
241
+ flash::convert_type_out(tOrO, tOrO_out);
242
+ if constexpr (NeedFP8Permute && !Split) { flash::permute_output_fp8_Vcolmajor(tOrO_out); }
243
+
244
+ // Make sure all WGs have finished reading V
245
+ // Technically we don't need this if we're not using smem, but the mainloop makes the assumption that
246
+ // all epilogue threads sync at least once during the epilogue (so that we can start loading Q with
247
+ // cp.async if we need).
248
+ flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
249
+
250
+ // Step 1: Write O from rmem -> smem
251
+ if constexpr (Use_smem) {
252
+ auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
253
+ auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
254
+ Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N)
255
+ Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
256
+ // Tensor taccOsO = smem_thr_copy_O.partition_D(sO_pi); // ((Atom,AtomNum),PIPE_M,PIPE_N)
257
+ cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
258
+ if constexpr (Use_TMA_O) {
259
+ cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
260
+ cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
261
+ cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
262
+ } else {
263
+ flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
264
+ }
265
+ } else {
266
+ if constexpr (ArchTag::kMinComputeCapability >= 90) {
267
+ #pragma unroll
268
+ for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
269
+ shared_storage.pipelines.barrier_O.arrive(cta_id);
270
+ }
271
+ }
272
+ }
273
+
274
+ flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
275
+ bool is_varlen = Varlen && params.cu_seqlens;
276
+ int offset_o = seqlen_info.offset;
277
+ int seqlen_o = seqlen_info.seqlen;
278
+ int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);
279
+
280
+ // Step 2: Write LSE from rmem -> gmem
281
+ auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
282
+ // (MMA,MMA_M,MMA_K)
283
+ Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
284
+ static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
285
+ static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
286
+ Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()));
287
+ Tensor taccOcO_row = taccOcO_rowcol(_, _0{});
288
+ CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
289
+
290
+ using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;
291
+ using PackGQApartial_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, ElementPartial>;
292
+
293
+ Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
294
+ params.shape_LSE_packed,
295
+ !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);
296
+ // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); }
297
+ if (!LargeHeadDimV || warp_group_idx == 0) {
298
+ if constexpr (!PackGQA) {
299
+ #pragma unroll
300
+ for (int mi = 0; mi < size(lse); ++mi) {
301
+ int const row = m_block * kBlockM + get<0>(taccOcO_row(mi));
302
+ if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); }
303
+ }
304
+ } else {
305
+ PackGQA_t::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
306
+ }
307
+ }
308
+
309
+ // Step 3: Write O from smem -> gmem
310
+ if constexpr (Use_TMA_O) {
311
+ Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx);
312
+ Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
313
+ auto block_tma_O = params.tma_store_O.get_slice(_0{});
314
+ Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K)
315
+ Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K)
316
+ int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
317
+ if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
318
+ cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
319
+ cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
320
+ if (cute::elect_one_sync()) {
321
+ cute::copy(params.tma_store_O, tOsO, tOgO);
322
+ tma_store_arrive();
323
+ tma_store_wait<0>();
324
+ #pragma unroll
325
+ for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
326
+ shared_storage.pipelines.barrier_O.arrive(cta_id);
327
+ }
328
+ }
329
+ }
330
+ } else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence
331
+ if (!is_split) {
332
+ Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{});
333
+ Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
334
+ // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast<int>(&mO(0)) - reinterpret_cast<int>(params.ptr_O)); }
335
+ GmemTiledCopyO gmem_tiled_copy_O;
336
+ auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
337
+ Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
338
+ // Tensor tOsO = gmem_thr_copy_O.partition_S(sO_pi); // ((Atom,AtomNum),ATOM_M,ATOM_N)
339
+ Tensor tOrO = make_fragment_like(tOsO);
340
+ cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
341
+ if constexpr (ArchTag::kMinComputeCapability >= 90) {
342
+ cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_v
343
+ #pragma unroll
344
+ for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
345
+ shared_storage.pipelines.barrier_O.arrive(cta_id);
346
+ }
347
+ }
348
+ if constexpr (!PackGQA) {
349
+ // (BLK_M,BLK_K) -> (blk_m,blk_k)
350
+ Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
351
+ Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOsO)));
352
+ #pragma unroll
353
+ for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
354
+ Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
355
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
356
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
357
+ gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
358
+ );
359
+ } else {
360
+ // If PackGQA, we split the work of compute O_ptr among threads in the same row
361
+ PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
362
+ }
363
+ } else {
364
+ Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx);
365
+ Tensor gOpartial = local_tile(mOpartial, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
366
+ // We already arrived on barrier_O earlier if !Use_smem
367
+ if constexpr (Use_smem) {
368
+ if constexpr (ArchTag::kMinComputeCapability >= 90) {
369
+ #pragma unroll
370
+ for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
371
+ shared_storage.pipelines.barrier_O.arrive(cta_id);
372
+ }
373
+ }
374
+ }
375
+ if constexpr (!PackGQA) {
376
+ static constexpr int kGmemElemsPerStoreDirect = 2;
377
+ cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial> gmem_copy_direct;
378
+ // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
379
+ Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout()));
380
+ Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
381
+ Tensor tOgO = thread_mma.partition_C(gOpartial);
382
+ Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout()));
383
+ Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
384
+ Tensor taccOcO_col = taccOcO_rowcol(_0{}, _);
385
+ #pragma unroll
386
+ for (int m = 0; m < size(taccOcO_row); ++m) {
387
+ if (get<0>(taccOcO_row(m)) < seqlen_o - m_block * kBlockM) {
388
+ #pragma unroll
389
+ for (int k = 0; k < size(taccOcO_col) / kGmemElemsPerStoreDirect; ++k) {
390
+ if (get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)) < get<1>(params.shape_O)) {
391
+ cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), tOgO_copy(_, m, k));
392
+ }
393
+ }
394
+ }
395
+ }
396
+ } else {
397
+ PackGQApartial_t::store_O_direct(mOpartial, tOrO, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
398
+ }
399
+ }
400
+ }
401
+ }
402
+
403
+ CUTLASS_DEVICE void
404
+ store_tail() {
405
+ // Don't need to do tma_store_wait<0>() here since we already did in @store
406
+ }
407
+
408
+ // Write 0 to output and -inf to LSE
409
+ CUTLASS_DEVICE void
410
+ store_zero(
411
+ Params const& params,
412
+ int thread_idx,
413
+ cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
414
+ ) {
415
+ static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
416
+ auto [m_block, bidh, bidb, split_idx] = block_coord;
417
+ int num_splits = get<4>(params.shape_O_packed);
418
+ if constexpr (Split && Varlen) {
419
+ uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
420
+ int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
421
+ num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
422
+ split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx
423
+ }
424
+ bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);
425
+
426
+ flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
427
+ bool const is_varlen = Varlen && params.cu_seqlens;
428
+ int offset_o = seqlen_info.offset;
429
+ int seqlen_o = seqlen_info.seqlen;
430
+ int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor;
431
+ Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
432
+ params.shape_LSE_packed,
433
+ !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);
434
+ Tensor gLSE = local_tile(mLSE, Shape<Int<kBlockM>>{}, make_coord(m_block));
435
+
436
+ static_assert(kBlockM <= NumEpilogueThreads);
437
+ if (thread_idx < kBlockM) {
438
+ const int row = m_block * kBlockM + thread_idx;
439
+ if constexpr (!PackGQA) {
440
+ if (row < seqlen_o) { mLSE(row) = -INFINITY; }
441
+ } else {
442
+ if (row < seqlen_o * qhead_per_khead) {
443
+ int m_idx, h_idx;
444
+ m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row);
445
+ // mLSE has shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord"
446
+ mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY;
447
+ }
448
+ }
449
+ }
450
+
451
+ // If split, we don't have to write 0 to mOpartial if the mha_combine kernel is used,
452
+ // since it will not use the value of O if LSE is -inf.
453
+ if (!is_split) {
454
+ Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{});
455
+
456
+ GmemTiledCopyO gmem_tiled_copy_O;
457
+ auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
458
+ Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
459
+ if constexpr (!PackGQA) {
460
+ Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
461
+ #pragma unroll
462
+ for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
463
+ Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
464
+ Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
465
+ Tensor tOrO = make_fragment_like(tOgO);
466
+ cute::clear(tOrO);
467
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
468
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
469
+ gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
470
+ );
471
+ } else {
472
+ // If PackGQA, we split the work of computing O_ptr among threads in the same row
473
+ using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;
474
+ Tensor tOrO = make_tensor<Element>(make_shape(Shape<_1, Int<kGmemElemsPerStore>>{}, size<1>(tOcO), size<2>(tOcO)));
475
+ cute::clear(tOrO);
476
+ PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
477
+ }
478
+ }
479
+
480
+ }
481
+
482
+ };
483
+
484
+ } // namespace flash
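For reference, a minimal sketch of how the packed split_idx unpacked in store_zero above could be decoded; it assumes only the 16/16 bit layout shown there, and the helper name is hypothetical:

#include <cstdint>
#include <utility>

// Hypothetical helper mirroring the unpacking in store_zero: returns
// {split_idx, num_splits_dynamic}; a dynamic count of 0 means "use the static num_splits".
inline std::pair<int, int> unpack_split_idx(int packed) {
    uint32_t u = static_cast<uint32_t>(packed);
    int num_splits_dynamic = static_cast<int>(u >> 16);      // upper 16 bits
    int split_idx          = static_cast<int>(u & 0xFFFFu);  // lower 16 bits
    return {split_idx, num_splits_dynamic};
}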
flash-attn/flash.h ADDED
@@ -0,0 +1,218 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <cuda.h>
8
+ #include <vector>
9
+
10
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
11
+
12
+ struct Qkv_params {
13
+ using index_t = int64_t;
14
+ // The QKV matrices.
15
+ void *__restrict__ q_ptr;
16
+ void *__restrict__ k_ptr;
17
+ void *__restrict__ v_ptr;
18
+
19
+ // The stride between rows of the Q, K and V matrices.
20
+ index_t q_batch_stride;
21
+ index_t k_batch_stride;
22
+ index_t v_batch_stride;
23
+ index_t q_row_stride;
24
+ index_t k_row_stride;
25
+ index_t v_row_stride;
26
+ index_t q_head_stride;
27
+ index_t k_head_stride;
28
+ index_t v_head_stride;
29
+ index_t v_dim_stride;
30
+
31
+ // The number of heads.
32
+ int h, h_k;
33
+ };
34
+
35
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
36
+
37
+ struct Flash_fwd_params : public Qkv_params {
38
+ using index_t = int64_t;
39
+
40
+ // The O matrix (output).
41
+ void * __restrict__ o_ptr;
42
+ void * __restrict__ oaccum_ptr;
43
+
44
+ // The stride between rows of O.
45
+ index_t o_batch_stride;
46
+ index_t o_row_stride;
47
+ index_t o_head_stride;
48
+
49
+ // The pointer to the softmax sum.
50
+ void * __restrict__ softmax_lse_ptr;
51
+ void * __restrict__ softmax_lseaccum_ptr;
52
+
53
+ // For FP8 scaling
54
+ float * __restrict__ q_descale_ptr;
55
+ float * __restrict__ k_descale_ptr;
56
+ float * __restrict__ v_descale_ptr;
57
+ index_t q_descale_batch_stride;
58
+ index_t q_descale_head_stride;
59
+ index_t k_descale_batch_stride;
60
+ index_t k_descale_head_stride;
61
+ index_t v_descale_batch_stride;
62
+ index_t v_descale_head_stride;
63
+
64
+ // The dimensions.
65
+ int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
66
+ int total_q, total_k, total_knew;
67
+ int b_k; // With a KV cache and cache_batch_idx, K & V might have a larger batch size than Q
68
+ int dv, dv_rounded; // For the case where V headdim is different from Q/K headdim
69
+
70
+ // The scaling factors for the kernel.
71
+ float scale_softmax;
72
+ float softcap;
73
+
74
+ // Arrays of length b+1 holding the starting offset of each sequence.
75
+ int * __restrict__ cu_seqlens_q;
76
+ int * __restrict__ cu_seqlens_k;
77
+ int * __restrict__ cu_seqlens_knew;
78
+ int * __restrict__ leftpad_k;
79
+
80
+ // If provided, the actual length of each q/k sequence.
81
+ int *__restrict__ seqused_q;
82
+ int *__restrict__ seqused_k;
83
+
84
+ // The stride between rows of Oaccum.
85
+ index_t oaccum_split_stride;
86
+ index_t oaccum_batch_stride;
87
+ index_t oaccum_row_stride;
88
+ index_t oaccum_head_stride;
89
+
90
+ // The stride between rows of LSEaccum.
91
+ index_t lseaccum_split_stride;
92
+ index_t lseaccum_batch_stride;
93
+ index_t lseaccum_head_stride;
94
+
95
+ // The K_new and V_new matrices.
96
+ void * __restrict__ knew_ptr;
97
+ void * __restrict__ vnew_ptr;
98
+
99
+ // The stride between rows of the Q, K and V matrices.
100
+ index_t knew_batch_stride;
101
+ index_t vnew_batch_stride;
102
+ index_t knew_row_stride;
103
+ index_t vnew_row_stride;
104
+ index_t knew_head_stride;
105
+ index_t vnew_head_stride;
106
+
107
+ void *__restrict__ qv_ptr;
108
+ index_t qv_batch_stride;
109
+ index_t qv_row_stride;
110
+ index_t qv_head_stride;
111
+
112
+ // The cos and sin matrices for rotary embedding.
113
+ void * __restrict__ rotary_cos_ptr;
114
+ void * __restrict__ rotary_sin_ptr;
115
+ int *__restrict__ seqlens_rotary;
116
+
117
+ // The indices to index into the KV cache.
118
+ int * __restrict__ kv_batch_idx;
119
+
120
+ // Paged KV cache
121
+ int * __restrict__ page_table;
122
+ index_t page_table_batch_stride;
123
+ int page_size;
124
+ int num_pages;
125
+ bool pagedkv_tma;
126
+
127
+ // The dropout probability (probability of keeping an activation).
128
+ float p_dropout;
129
+ // uint32_t p_dropout_in_uint;
130
+ // uint16_t p_dropout_in_uint16_t;
131
+ uint8_t p_dropout_in_uint8_t;
132
+
133
+ // Scale factor of 1 / (1 - p_dropout).
134
+ float rp_dropout;
135
+
136
+ // Local window size
137
+ int window_size_left, window_size_right;
138
+ int attention_chunk;
139
+
140
+ // Pointer to the RNG seed (idx 0) and offset (idx 1).
141
+ uint64_t * rng_state;
142
+
143
+ bool is_bf16;
144
+ bool is_fp32;
145
+ bool is_e4m3;
146
+ bool is_causal;
147
+ bool is_local;
148
+
149
+ bool is_rotary_interleaved;
150
+
151
+ int num_splits; // For split-KV version
152
+ bool pack_gqa;
153
+
154
+ int * __restrict__ tile_count_semaphore;
155
+ // int * __restrict__ num_m_blocks_ptr;
156
+ // int * __restrict__ num_n_blocks_ptr;
157
+ int * __restrict__ num_splits_dynamic_ptr;
158
+ bool skip_scheduler_metadata_computation;
159
+
160
+ int arch;
161
+ int num_sm;
162
+ };
163
+
164
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
165
+
166
+ struct Flash_bwd_params : public Flash_fwd_params {
167
+ using index_t = int64_t;
168
+
169
+ // The dO and dQKV matrices.
170
+ void *__restrict__ do_ptr;
171
+ void *__restrict__ dq_ptr;
172
+ void *__restrict__ dk_ptr;
173
+ void *__restrict__ dv_ptr;
174
+
175
+ // To accumulate dQ
176
+ void *__restrict__ dq_accum_ptr;
177
+ void *__restrict__ dk_accum_ptr;
178
+ void *__restrict__ dv_accum_ptr;
179
+
180
+ // // To accumulate dK and dV in case we're splitting the bwd along the seqlen_q dimension
181
+ // // void *__restrict__ dk_accum_ptr;
182
+ // // void *__restrict__ dv_accum_ptr;
183
+
184
+ // The stride between rows of the dO, dQ, dK and dV matrices.
185
+ index_t do_batch_stride;
186
+ index_t do_row_stride;
187
+ index_t do_head_stride;
188
+ index_t dq_batch_stride;
189
+ index_t dk_batch_stride;
190
+ index_t dv_batch_stride;
191
+ index_t dq_row_stride;
192
+ index_t dk_row_stride;
193
+ index_t dv_row_stride;
194
+ index_t dq_head_stride;
195
+ index_t dk_head_stride;
196
+ index_t dv_head_stride;
197
+
198
+ // The pointer to the softmax d sum.
199
+ void *__restrict__ dsoftmax_sum;
200
+ void *__restrict__ softmax_lse_log2_ptr;
201
+
202
+ int *__restrict__ dq_semaphore;
203
+ int *__restrict__ dk_semaphore;
204
+ int *__restrict__ dv_semaphore;
205
+
206
+ bool deterministic;
207
+ index_t dq_accum_split_stride;
208
+ };
209
+
210
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
211
+
212
+ template <int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKVNonTMA, bool Has_softcap, bool PackGQA>
213
+ void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
214
+ void prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl);
215
+ template <int Arch, typename T, int kHeadDim, bool Has_softcap>
216
+ void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
217
+ template <typename T, typename Tpartial, int kBlockK>
218
+ void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
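A usage sketch for the parameter structs above, assuming a contiguous (batch, seqlen, nheads, headdim) Q buffer; the helper below is hypothetical and only illustrates that the strides are counted in elements, not bytes:

// Hypothetical sketch: element (bi, si, hi, di) of a contiguous [b, s, h, d]
// buffer lives at bi*s*h*d + si*h*d + hi*d + di, so the strides are:
inline void fill_q_strides_contiguous(Qkv_params &p, int64_t b, int64_t s, int64_t h, int64_t d) {
    p.q_row_stride   = h * d;      // step between consecutive sequence positions
    p.q_head_stride  = d;          // step between consecutive heads
    p.q_batch_stride = s * h * d;  // step between consecutive batch entries
    p.h = static_cast<int>(h);     // number of query heads
    (void)b;                       // the batch count itself lives on Flash_fwd_params::b
}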
flash-attn/flash_api.cpp ADDED
@@ -0,0 +1,1720 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <Python.h>
6
+ #include <torch/nn/functional/padding.h>
7
+ #include <ATen/cuda/CUDAContextLight.h>
8
+ #include <c10/cuda/CUDAGuard.h>
9
+
10
+ #include <cutlass/numeric_types.h>
11
+
12
+ #include "flash.h"
13
+ #include "static_switch.h"
14
+ #include "tile_size.h"
15
+ #include "heuristics.h"
16
+ #include "cuda_check.h"
17
+
18
+
19
+ extern "C" {
20
+ /* Creates a dummy empty _C module that can be imported from Python.
21
+ The import from Python will load the .so that this file is compiled
22
+ into, so that the TORCH_LIBRARY static initializers
23
+ below are run. */
24
+ PyObject* PyInit__C(void)
25
+ {
26
+ static struct PyModuleDef module_def = {
27
+ PyModuleDef_HEAD_INIT,
28
+ "_C", /* name of module */
29
+ NULL, /* module documentation, may be NULL */
30
+ -1, /* size of per-interpreter state of the module,
31
+ or -1 if the module keeps state in global variables. */
32
+ NULL, /* methods */
33
+ };
34
+ return PyModule_Create(&module_def);
35
+ }
36
+ }
37
+
38
+ #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
39
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
40
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
41
+
42
+ void set_params_fprop(Flash_fwd_params &params,
43
+ // sizes
44
+ const size_t b,
45
+ const size_t seqlen_q,
46
+ const size_t seqlen_k,
47
+ const size_t seqlen_q_rounded,
48
+ const size_t seqlen_k_rounded,
49
+ const size_t h,
50
+ const size_t h_k,
51
+ const size_t d,
52
+ const size_t d_rounded,
53
+ // device pointers
54
+ const at::Tensor q,
55
+ const at::Tensor k,
56
+ const at::Tensor v,
57
+ at::Tensor out,
58
+ void *cu_seqlens_q_d,
59
+ void *cu_seqlens_k_d,
60
+ void *seqused_q,
61
+ void *seqused_k,
62
+ void *softmax_lse_d,
63
+ float p_dropout,
64
+ float softmax_scale,
65
+ int window_size_left,
66
+ int window_size_right,
67
+ int attention_chunk,
68
+ const float softcap=0.f,
69
+ const int sm_margin=0) {
70
+
71
+ // Reset the parameters
72
+ params = {};
73
+
74
+ params.is_bf16 = q.dtype() == torch::kBFloat16;
75
+ params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;
76
+
77
+ // Set the pointers and strides.
78
+ params.q_ptr = q.data_ptr();
79
+ params.k_ptr = k.data_ptr();
80
+ params.v_ptr = v.data_ptr();
81
+ // All strides are in elements, not bytes.
82
+ params.q_row_stride = q.stride(-3);
83
+ params.k_row_stride = k.stride(-3);
84
+ params.v_row_stride = v.stride(-3);
85
+ params.q_head_stride = q.stride(-2);
86
+ params.k_head_stride = k.stride(-2);
87
+ params.v_head_stride = v.stride(-2);
88
+ params.v_dim_stride = v.stride(-1);
89
+ params.o_ptr = out.data_ptr();
90
+ params.o_row_stride = out.stride(-3);
91
+ params.o_head_stride = out.stride(-2);
92
+
93
+ if (cu_seqlens_q_d == nullptr) {
94
+ params.q_batch_stride = q.stride(0);
95
+ params.o_batch_stride = out.stride(0);
96
+ }
97
+ if (cu_seqlens_k_d == nullptr) {
98
+ params.k_batch_stride = k.stride(0);
99
+ params.v_batch_stride = v.stride(0);
100
+ }
101
+
102
+ params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
103
+ params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
104
+ params.seqused_q = static_cast<int *>(seqused_q);
105
+ params.seqused_k = static_cast<int *>(seqused_k);
106
+
107
+ // Softmax sum
108
+ params.softmax_lse_ptr = softmax_lse_d;
109
+
110
+ // Set the dimensions.
111
+ params.b = b;
112
+ params.h = h;
113
+ params.h_k = h_k;
114
+ params.seqlen_q = seqlen_q;
115
+ params.seqlen_k = seqlen_k;
116
+ params.seqlen_q_rounded = seqlen_q_rounded;
117
+ params.seqlen_k_rounded = seqlen_k_rounded;
118
+ params.d = d;
119
+ params.d_rounded = d_rounded;
120
+
121
+ // Set the different scale values.
122
+ params.scale_softmax = softmax_scale;
123
+ params.softcap = softcap;
124
+
125
+ // Set this to probability of keeping an element to simplify things.
126
+ params.p_dropout = 1.f - p_dropout;
127
+ // Convert p from float to int so we don't have to convert the random uint to float to compare.
128
+ // [Minor] We want to round down since when we do the comparison we use <= instead of <
129
+ // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
130
+ // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
131
+ params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
132
+ params.rp_dropout = 1.f / params.p_dropout;
133
+ TORCH_CHECK(p_dropout < 1.f);
134
+ #ifdef FLASHATTENTION_DISABLE_DROPOUT
135
+ TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
136
+ #endif
137
+
138
+ // Causal is the special case where window_size_right == 0 and window_size_left < 0.
139
+ // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
140
+ params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0;
141
+ params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal;
142
+
143
+ // TODO: check this
144
+ if (window_size_left < 0) { window_size_left = seqlen_k - 1; }
145
+ if (window_size_right < 0) { window_size_right = seqlen_q - 1; }
146
+ if (attention_chunk > 0) {
147
+ window_size_left = std::min(window_size_left, attention_chunk - 1);
148
+ window_size_right = std::min(window_size_right, attention_chunk - 1);
149
+ }
150
+ params.window_size_left = window_size_left;
151
+ params.window_size_right = window_size_right;
152
+ params.attention_chunk = attention_chunk;
153
+
154
+ params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
155
+ params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
156
+
157
+ #ifdef FLASHATTENTION_DISABLE_LOCAL
158
+ TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
159
+ #endif
160
+ }
161
+
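// Sketch (hypothetical free function, not part of the API): the window handling
// above, pulled out on its own to make the causal/local classification easier to
// follow. Uses std::min from <algorithm>.
struct WindowSketch { bool is_causal, is_local; int left, right; };

inline WindowSketch normalize_window_sketch(int left, int right, int chunk,
                                            int seqlen_q, int seqlen_k) {
    WindowSketch w;
    // Causal is the special case: unlimited left window, right window of 0, no chunking.
    w.is_causal = left < 0 && right == 0 && chunk == 0;
    w.is_local  = (left >= 0 || right >= 0 || chunk >= 1) && !w.is_causal;
    if (left < 0)  { left  = seqlen_k - 1; }  // negative means "unlimited"
    if (right < 0) { right = seqlen_q - 1; }
    if (chunk > 0) {                          // chunked attention caps both sides
        left  = std::min(left,  chunk - 1);
        right = std::min(right, chunk - 1);
    }
    w.left = left; w.right = right;
    return w;
}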
162
+ void set_params_dgrad(Flash_bwd_params &params,
163
+ // sizes
164
+ const size_t b,
165
+ const size_t seqlen_q,
166
+ const size_t seqlen_k,
167
+ const size_t seqlen_q_rounded,
168
+ const size_t seqlen_k_rounded,
169
+ const size_t h,
170
+ const size_t h_k,
171
+ const size_t d,
172
+ const size_t d_rounded,
173
+ // device pointers
174
+ const at::Tensor q,
175
+ const at::Tensor k,
176
+ const at::Tensor v,
177
+ const at::Tensor out,
178
+ const at::Tensor dout,
179
+ at::Tensor dq,
180
+ at::Tensor dk,
181
+ at::Tensor dv,
182
+ void *cu_seqlens_q_d,
183
+ void *cu_seqlens_k_d,
184
+ void *seqused_q,
185
+ void *seqused_k,
186
+ void *dq_accum_d,
187
+ void *dk_accum_d,
188
+ void *dv_accum_d,
189
+ void *softmax_lse_d,
190
+ void *dsoftmax_sum_d,
191
+ float p_dropout,
192
+ float softmax_scale,
193
+ int window_size_left,
194
+ int window_size_right,
195
+ int attention_chunk,
196
+ const float softcap=0.f,
197
+ bool deterministic=false,
198
+ int const sm_margin=0) {
199
+
200
+ set_params_fprop(params,
201
+ b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
202
+ q, k, v, out,
203
+ cu_seqlens_q_d,
204
+ cu_seqlens_k_d,
205
+ seqused_q,
206
+ seqused_k,
207
+ softmax_lse_d,
208
+ p_dropout,
209
+ softmax_scale,
210
+ window_size_left,
211
+ window_size_right,
212
+ attention_chunk,
213
+ softcap,
214
+ sm_margin);
215
+
216
+ // Set the pointers and strides.
217
+ params.do_ptr = dout.data_ptr();
218
+ params.do_row_stride = dout.stride(-3);
219
+ params.do_head_stride = dout.stride(-2);
220
+ params.dq_ptr = dq.data_ptr();
221
+ params.dk_ptr = dk.data_ptr();
222
+ params.dv_ptr = dv.data_ptr();
223
+ params.dq_row_stride = dq.stride(-3);
224
+ params.dk_row_stride = dk.stride(-3);
225
+ params.dv_row_stride = dv.stride(-3);
226
+ params.dq_head_stride = dq.stride(-2);
227
+ params.dk_head_stride = dk.stride(-2);
228
+ params.dv_head_stride = dv.stride(-2);
229
+
230
+ if (cu_seqlens_q_d == nullptr) {
231
+ params.do_batch_stride = dout.stride(0);
232
+ params.dq_batch_stride = dq.stride(0);
233
+ params.dk_batch_stride = dk.stride(0);
234
+ params.dv_batch_stride = dv.stride(0);
235
+ }
236
+
237
+ params.dq_accum_ptr = dq_accum_d;
238
+ params.dk_accum_ptr = dk_accum_d;
239
+ params.dv_accum_ptr = dv_accum_d;
240
+
241
+ // Softmax sum
242
+ params.dsoftmax_sum = dsoftmax_sum_d;
243
+
244
+ params.deterministic = deterministic;
245
+ }
246
+
247
+ template <int Arch, int Split, bool PagedKVNonTMA, bool PackGQA, bool Has_softcap>
248
+ void run_mha_fwd_constexpr(Flash_fwd_params &params, cudaStream_t stream) {
249
+ if (!params.is_e4m3) {
250
+ if (params.is_bf16) {
251
+ #ifndef FLASHATTENTION_DISABLE_HDIM64
252
+ if (params.d <= 64) {
253
+ if constexpr (Arch == 90) {
254
+ if (params.dv > 256) {
255
+ return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
256
+ } else if (params.dv > 64) {
257
+ return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
258
+ }
259
+ }
260
+ return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
261
+ }
262
+ #endif
263
+ #ifndef FLASHATTENTION_DISABLE_HDIM96
264
+ if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
265
+ #endif
266
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
267
+ if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
268
+ #endif
269
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
270
+ if (params.d <= 192) {
271
+ if constexpr (Arch == 90) {
272
+ if (params.dv <= 128) {
273
+ return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
274
+ }
275
+ }
276
+ return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
277
+ }
278
+ #endif
279
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
280
+ if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
281
+ #endif
282
+ } else {
283
+ #ifndef FLASHATTENTION_DISABLE_FP16
284
+ #ifndef FLASHATTENTION_DISABLE_HDIM64
285
+ if (params.d <= 64) {
286
+ if constexpr (Arch == 90) {
287
+ if (params.dv > 256) {
288
+ return run_mha_fwd_<Arch, cutlass::half_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
289
+ } else if (params.dv > 64) {
290
+ return run_mha_fwd_<Arch, cutlass::half_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
291
+ }
292
+ }
293
+ return run_mha_fwd_<Arch, cutlass::half_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
294
+ }
295
+ #endif
296
+ #ifndef FLASHATTENTION_DISABLE_HDIM96
297
+ if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::half_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
298
+ #endif
299
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
300
+ if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::half_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
301
+ #endif
302
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
303
+ if (params.d <= 192) {
304
+ if constexpr (Arch == 90) {
305
+ if (params.dv <= 128) {
306
+ return run_mha_fwd_<Arch, cutlass::half_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
307
+ }
308
+ }
309
+ return run_mha_fwd_<Arch, cutlass::half_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
310
+ }
311
+ #endif
312
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
313
+ if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::half_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
314
+ #endif
315
+ #else
316
+ TORCH_CHECK(false, "This flash attention build does not support FP16.");
317
+ #endif
318
+ }
319
+ } else {
320
+ #ifndef FLASHATTENTION_DISABLE_FP8
321
+ #ifndef FLASHATTENTION_DISABLE_HDIM64
322
+ if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
323
+ #endif
324
+ #ifndef FLASHATTENTION_DISABLE_HDIM96
325
+ if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
326
+ #endif
327
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
328
+ if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
329
+ #endif
330
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
331
+ if (params.d <= 192) {
332
+ if constexpr (Arch == 90) {
333
+ if (params.dv <= 128) {
334
+ return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
335
+ }
336
+ }
337
+ return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
338
+ }
339
+ #endif
340
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
341
+ if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
342
+ #endif
343
+ #else
344
+ TORCH_CHECK(false, "This flash attention build does not support FP8.");
345
+ #endif
346
+ }
347
+ }
348
+
349
+ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
350
+ // HEADDIM_SWITCH(params.d, [&] {
351
+ // run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
352
+ // });
353
+ TORCH_CHECK(params.num_splits >= 1);
354
+ ARCH_SWITCH(params.arch, Arch, [&] {
355
+ SPLIT_SWITCH(params.num_splits > 1, Split, [&] {
356
+ PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] {
357
+ PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] {
358
+ // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation
359
+ static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split;
360
+ SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] {
361
+ run_mha_fwd_constexpr<Arch, Split, PagedKVNonTMA, PackGQA, Has_softcap>(params, stream);
362
+ });
363
+ });
364
+ });
365
+ });
366
+ });
367
+ }
368
+
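// The *_SWITCH macros above are defined in static_switch.h (not shown in this
// diff). They follow the usual runtime-bool-to-constexpr dispatch pattern; a
// generic sketch of that pattern (not the repo's exact macro) looks like this:
#define BOOL_SWITCH_SKETCH(COND, CONST_NAME, ...)      \
    [&] {                                              \
        if (COND) {                                    \
            constexpr static bool CONST_NAME = true;   \
            return __VA_ARGS__();                      \
        } else {                                       \
            constexpr static bool CONST_NAME = false;  \
            return __VA_ARGS__();                      \
        }                                              \
    }()
// Each nested switch doubles the number of template instantiations of
// run_mha_fwd_constexpr, which is why PackGQA is forced on in several branches.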
369
+ void run_mha_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl=false) {
370
+ #ifndef FLASHATTENTION_DISABLE_SPLIT
371
+ // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively
372
+ // so that kBlockM is smaller and we have more parallelism.
373
+ if (params.is_fp32) {
374
+ if (params.dv <= 64) {
375
+ run_mha_fwd_combine_<float, float, 64>(params, stream, enable_pdl);
376
+ } else {
377
+ run_mha_fwd_combine_<float, float, 128>(params, stream, enable_pdl);
378
+ }
379
+ } else if (params.is_bf16) {
380
+ if (params.dv <= 64) {
381
+ run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(params, stream, enable_pdl);
382
+ } else {
383
+ run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(params, stream, enable_pdl);
384
+ }
385
+ } else {
386
+ if (params.dv <= 64) {
387
+ run_mha_fwd_combine_<cutlass::half_t, float, 64>(params, stream, enable_pdl);
388
+ } else {
389
+ run_mha_fwd_combine_<cutlass::half_t, float, 128>(params, stream, enable_pdl);
390
+ }
391
+ }
392
+ #else
393
+ TORCH_CHECK(false, "This flash attention build does not support combine kernels.");
394
+ #endif
395
+ }
396
+
397
+ inline bool get_pagedkv_tma(Flash_fwd_params const& params) {
398
+ if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; }
399
+ // This needs to match the kernel configs
400
+ auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f);
401
+ int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
402
+ int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90);
403
+ // Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower,
404
+ // at least for MLA.
405
+ return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM;
406
+ }
407
+
408
+ inline bool get_pack_gqa(Flash_fwd_params const& params) {
409
+ // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size.
410
+ // Has little effect on speed.
411
+ if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; }
412
+ #ifdef FLASHATTENTION_DISABLE_PACKGQA
413
+ return false;
414
+ #else
415
+ // params.page_table must already be set
416
+ if (params.h == params.h_k) { return false; }
417
+ // This needs to match the kernel configs
418
+ auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
419
+ int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
420
+ return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM);
421
+ #endif
422
+ }
423
+
424
+ inline int get_num_splits(Flash_fwd_params const& params) {
425
+ #ifdef FLASHATTENTION_DISABLE_SPLIT
426
+ return 1;
427
+ #else
428
+ // Always enable PackGQA for Split
429
+ // params.page_table must already be set
430
+ // This needs to match the kernel configs
431
+ bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k;
432
+ auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
433
+ // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits
434
+ // has not been set here. It's OK though because we might just underestimate kBlockN a bit
435
+ auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);
436
+ int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
437
+ int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
438
+ int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k);
439
+ // If is_local, we're not going to load all of seqlen_k
440
+ int const seqlen_k_loaded = !params.is_local
441
+ ? params.seqlen_k
442
+ : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM));
443
+ int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN;
444
+ int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM;
445
+ int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2);
446
+ // Always enable PackGQA for Split
447
+ // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits.
448
+ // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending
449
+ // that batch = 1.
450
+ int total_mblocks = (params.num_splits_dynamic_ptr ? 1 : params.b) * params.h_k * num_m_blocks;
451
+ return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128);
452
+ #endif
453
+ }
454
+
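// Worked example of the block arithmetic above, with hypothetical tile sizes
// (the real kBlockM/kBlockN come from tile_size_fwd_sm90 / tile_size_fwd_sm8x).
inline int ceil_div_sketch(int a, int b) { return (a + b - 1) / b; }
// For decode with seqlen_q = 1, GQA ratio h/h_k = 8, kBlockM = 128,
// seqlen_k = 8192, kBlockN = 176:
//   seqlen_q_packgqa = 1 * 8 = 8  ->  num_m_blocks = ceil_div_sketch(8, 128)   = 1
//   num_n_blocks                   =  ceil_div_sketch(8192, 176)               = 47
// With a single m-block per (batch, kv head), splitting those 47 n-blocks
// across otherwise idle SMs is what num_splits_heuristic decides.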
455
+ inline int get_max_headdim() {
456
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
457
+ return 256;
458
+ #endif
459
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
460
+ return 192;
461
+ #endif
462
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
463
+ return 128;
464
+ #endif
465
+ #ifndef FLASHATTENTION_DISABLE_HDIM96
466
+ return 96;
467
+ #endif
468
+ #ifndef FLASHATTENTION_DISABLE_HDIM64
469
+ return 64;
470
+ #endif
471
+ return 0;
472
+ }
473
+
474
+ inline int round_up_headdim(int head_size) {
475
+ #ifndef FLASHATTENTION_DISABLE_HDIM64
476
+ if (head_size <= 64) { return 64; }
477
+ #endif
478
+ #ifndef FLASHATTENTION_DISABLE_HDIM96
479
+ if (head_size <= 96) { return 96; }
480
+ #endif
481
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
482
+ if (head_size <= 128) { return 128; }
483
+ #endif
484
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
485
+ if (head_size <= 192) { return 192; }
486
+ #endif
487
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
488
+ if (head_size <= 256) { return 256; }
489
+ #endif
490
+ return 256;
491
+ }
492
+
493
+ inline int round_up_headdimv(int head_size) {
494
+ if (head_size <= 64) { return 64; }
495
+ if (head_size <= 96) { return 96; }
496
+ if (head_size <= 128) { return 128; }
497
+ if (head_size <= 192) { return 192; }
498
+ if (head_size <= 256) { return 256; }
499
+ return 512;
500
+ }
501
+
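// Quick sanity checks for the rounding above (a sketch, assuming a build with
// all head dimensions enabled, i.e. no FLASHATTENTION_DISABLE_HDIM* flags):
#include <cassert>
inline void headdim_rounding_examples() {
    assert(round_up_headdim(80)   == 96);
    assert(round_up_headdim(128)  == 128);
    assert(round_up_headdim(160)  == 192);
    assert(round_up_headdimv(256) == 256);
    assert(round_up_headdimv(320) == 512);  // only the V headdim rounds up to 512
}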
502
+ // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
503
+ at::Tensor
504
+ mha_fwd_get_scheduler_metadata(
505
+ int64_t batch_size,
506
+ int64_t max_seqlen_q,
507
+ int64_t max_seqlen_k,
508
+ int64_t num_heads,
509
+ int64_t num_heads_k,
510
+ int64_t headdim,
511
+ int64_t headdim_v,
512
+ at::ScalarType qkv_dtype,
513
+ at::Tensor seqused_k, // b
514
+ std::optional<at::Tensor> cu_seqlens_q_, // b+1
515
+ std::optional<at::Tensor> cu_seqlens_k_, // b+1
516
+ std::optional<at::Tensor> cu_seqlens_k_new_, // b+1
517
+ std::optional<at::Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
518
+ std::optional<at::Tensor> leftpad_k_, // b
519
+ std::optional<int64_t> page_size,
520
+ int64_t max_seqlen_k_new, // 0 means we're not appending new KV
521
+ bool is_causal,
522
+ int64_t window_size_left,
523
+ int64_t window_size_right,
524
+ int64_t attention_chunk,
525
+ bool has_softcap,
526
+ int64_t num_splits,
527
+ std::optional<bool> pack_gqa_,
528
+ int64_t sm_margin
529
+ ) {
530
+
531
+ TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn,
532
+ "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
533
+ TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
534
+
535
+ // Reset the parameters
536
+ Flash_fwd_params params{};
537
+ params.is_bf16 = qkv_dtype == at::ScalarType::BFloat16;
538
+ params.is_e4m3 = qkv_dtype == at::ScalarType::Float8_e4m3fn;
539
+ params.b = batch_size;
540
+ params.seqlen_q = max_seqlen_q;
541
+ params.seqlen_k = max_seqlen_k;
542
+ params.h = num_heads;
543
+ params.h_k = num_heads_k;
544
+ params.d = headdim;
545
+ params.dv = headdim_v;
546
+ params.d_rounded = round_up_headdim(headdim);
547
+ params.dv_rounded = headdim_v == headdim ? params.d_rounded : round_up_headdimv(headdim_v);
548
+ params.seqlen_knew = max_seqlen_k_new;
549
+
550
+ bool const is_varlen_q = cu_seqlens_q_.has_value();
551
+ params.cu_seqlens_q = is_varlen_q ? cu_seqlens_q_.value().data_ptr<int>() : nullptr;
552
+ bool const is_varlen_k = cu_seqlens_k_.has_value();
553
+ params.cu_seqlens_k = is_varlen_k ? cu_seqlens_k_.value().data_ptr<int>() : nullptr;
554
+ params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? cu_seqlens_k_new_.value().data_ptr<int>() : nullptr;
555
+ params.seqused_q = seqused_q_.has_value() ? seqused_q_.value().data_ptr<int>() : nullptr;
556
+ params.seqused_k = seqused_k.data_ptr<int>();
557
+ params.leftpad_k = leftpad_k_.has_value() ? leftpad_k_.value().data_ptr<int>() : nullptr;
558
+ params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast<int*>(1) : nullptr;
559
+ if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; }
560
+ if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; }
561
+ // causal=true is the same as causal=false in this case
562
+ if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) {
563
+ // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA
564
+ if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) {
565
+ is_causal = false;
566
+ }
567
+ }
568
+ if (is_causal) { window_size_right = 0; }
569
+
570
+ params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0;
571
+ params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal;
572
+ if (window_size_left < 0) { window_size_left = max_seqlen_k - 1; }
573
+ if (window_size_right < 0) { window_size_right = max_seqlen_q - 1; }
574
+ if (attention_chunk > 0) {
575
+ window_size_left = std::min(window_size_left, attention_chunk - 1);
576
+ window_size_right = std::min(window_size_right, attention_chunk - 1);
577
+ }
578
+ params.window_size_left = window_size_left;
579
+ params.window_size_right = window_size_right;
580
+ params.attention_chunk = attention_chunk;
581
+ params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
582
+ params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
583
+ params.softcap = has_softcap ? 1.0f : 0.0f;
584
+
585
+ params.page_size = page_size.has_value() ? page_size.value() : 1;
586
+ params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast<int*>(1);
587
+
588
+ bool const use_dynamic_split = params.b <= 992;
589
+ params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);
590
+
591
+ params.pagedkv_tma = get_pagedkv_tma(params);
592
+ params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
593
+ // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide
594
+ params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
595
+
596
+ bool is_varlen = true;
597
+
598
+ // Otherwise the kernel will be launched from cuda:0 device
599
+ // Cast to char to avoid compiler warning about narrowing
600
+ at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()};
601
+
602
+ auto opts = seqused_k.options();
603
+ // This needs to be set after get_num_splits
604
+ at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic
605
+ bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1;
606
+ if (scheduler_needs_semaphore || use_dynamic_split) {
607
+ tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32));
608
+ if (scheduler_needs_semaphore) {
609
+ if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen, we'll do the zeroing manually
610
+ params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
611
+ } else {
612
+ params.tile_count_semaphore = nullptr;
613
+ }
614
+ params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
615
+ }
616
+
617
+ if (params.num_splits_dynamic_ptr) {
618
+ auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
619
+ auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);
620
+ int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
621
+ int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
622
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
623
+ prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/);
624
+ CHECK_CUDA_KERNEL_LAUNCH();
625
+ }
626
+ return tile_count_semaphore;
627
+ }
628
+
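// Layout sketch of the int32 metadata tensor returned above (names illustrative):
//   metadata[0]       tile-count semaphore           (when scheduler_needs_semaphore)
//   metadata[1 .. b]  per-batch dynamic num_splits   (when use_dynamic_split)
// so its length is computed as:
inline int64_t scheduler_metadata_numel_sketch(bool needs_semaphore,
                                               bool dynamic_split, int64_t batch) {
    return int64_t(needs_semaphore) + int64_t(dynamic_split) * batch;
}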
629
+ // b: batch_size
630
+ // b_k: batch_size_k
631
+ // s_q: seqlen_q
632
+ // s_k: seqlen_k
633
+ // s_k_new: seqlen_k_new
634
+ // h: num_heads
635
+ // h_k: num_heads_k
636
+ // d: head_size
637
+ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
638
+ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
639
+ at::Tensor k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.
640
+ at::Tensor v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table.
641
+ std::optional<at::Tensor> k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
642
+ std::optional<at::Tensor> v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
643
+ std::optional<at::Tensor> q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
644
+ std::optional<at::Tensor> out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
645
+ std::optional<at::Tensor> cu_seqlens_q_, // b+1
646
+ std::optional<at::Tensor> cu_seqlens_k_, // b+1
647
+ std::optional<at::Tensor> cu_seqlens_k_new_, // b+1
648
+ std::optional<at::Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
649
+ std::optional<at::Tensor> seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
650
+ std::optional<int64_t> max_seqlen_q_,
651
+ // TODO: check if we need max_seqlen_k
652
+ std::optional<int64_t> max_seqlen_k_,
653
+ std::optional<at::Tensor> page_table_, // (b_k, max_num_pages_per_seq)
654
+ std::optional<at::Tensor> kv_batch_idx_, // b. indices to index into the KV cache
655
+ std::optional<at::Tensor> leftpad_k_, // b
656
+ std::optional<at::Tensor> rotary_cos_, // seqlen_ro x (rotary_dim / 2)
657
+ std::optional<at::Tensor> rotary_sin_, // seqlen_ro x (rotary_dim / 2)
658
+ std::optional<at::Tensor> seqlens_rotary_, // b
659
+ std::optional<at::Tensor> q_descale_, // (b, h_k), not (b, h)
660
+ std::optional<at::Tensor> k_descale_, // (b, h_k)
661
+ std::optional<at::Tensor> v_descale_, // (b, h_k)
662
+ std::optional<double> softmax_scale_,
663
+ bool is_causal,
664
+ int64_t window_size_left,
665
+ int64_t window_size_right,
666
+ int64_t attention_chunk,
667
+ double softcap,
668
+ bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
669
+ std::optional<at::Tensor> scheduler_metadata_, // (b + 1)
670
+ int64_t num_splits,
671
+ std::optional<bool> pack_gqa_,
672
+ int64_t sm_margin
673
+ ) {
674
+
675
+ auto dprops = at::cuda::getCurrentDeviceProperties();
676
+ bool is_sm8x = dprops->major >= 8;
677
+ TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
678
+
679
+ auto q_type = q.scalar_type();
680
+ TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn,
681
+ "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
682
+ if (dprops->major < 9) {
683
+ TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
684
+ "FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data type");
685
+ }
686
+ TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype");
687
+ TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype");
688
+
689
+ CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
690
+
691
+ TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
692
+ TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
693
+ TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
694
+
695
+ at::Tensor page_table;
696
+ const bool paged_KV = page_table_.has_value();
697
+ if (paged_KV) {
698
+ page_table = page_table_.value();
699
+ CHECK_DEVICE(page_table);
700
+ TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must have dtype torch.int32");
701
+ TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension");
702
+ }
703
+
704
+ at::Tensor cu_seqlens_q;
705
+ bool const is_varlen_q = cu_seqlens_q_.has_value();
706
+ if (is_varlen_q) {
707
+ cu_seqlens_q = cu_seqlens_q_.value();
708
+ CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
709
+ TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
710
+ TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
711
+ }
712
+ at::Tensor cu_seqlens_k;
713
+ bool const is_varlen_k = cu_seqlens_k_.has_value();
714
+ if (is_varlen_k) {
715
+ cu_seqlens_k = cu_seqlens_k_.value();
716
+ CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
717
+ TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
718
+ TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
719
+ TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported");
720
+ TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported");
721
+ }
722
+
723
+ auto const sizes = q.sizes();
724
+ const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
725
+ int seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
726
+ int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
727
+ int num_heads = q.size(-2);
728
+ int const head_size = q.size(-1);
729
+ int const head_size_v = v.size(-1);
730
+ int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1);
731
+ int const num_pages = !paged_KV ? 0 : k.size(0);
732
+ int const page_size = !paged_KV ? 1 : k.size(1);
733
+ int const seqlen_k = !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value();
734
+ int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
735
+ int const num_heads_k = k.size(-2);
736
+ int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0);
737
+ double softmax_scale = 1.0 / sqrt(double(head_size));
738
+ if (softmax_scale_.has_value()) {
739
+ softmax_scale = softmax_scale_.value();
740
+ }
741
+ if (!kv_batch_idx_.has_value()) {
742
+ TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k");
743
+ }
744
+ int const max_headdim = get_max_headdim();
745
+ TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
746
+ TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
747
+ if (head_size_v != head_size) {
748
+ TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) ||
749
+ (head_size <= 64 && head_size_v <= 512),
750
+ "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], "
751
+ "or (Q/K <= 64 and V <= 512).");
752
+ TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim");
753
+ if (head_size_v > 256) {
754
+ TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
755
+ "HeaddimV > 256 requires fp16 and bf16 data type");
756
+ }
757
+ }
758
+
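// Sketch of the mixed-headdim combinations accepted above (Hopper only); the
// lambda is illustrative and unused:
auto mixed_headdim_supported_sketch = [](int hd_qk, int hd_v) {
    return (hd_qk > 128 && hd_qk <= 192 && hd_v > 96 && hd_v <= 128)  // e.g. 192 / 128
        || (hd_qk <= 64 && hd_v <= 512);                              // e.g. 64 / 512
};
(void)mixed_headdim_supported_sketch;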
759
+ // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
760
+ // TODO: check this
761
+ if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
762
+ if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
763
+ // causal=true is the same as causal=false in this case
764
+ if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) {
765
+ // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA
766
+ if ((head_size <= 64 || head_size > 128) || !paged_KV) {
767
+ is_causal = false;
768
+ }
769
+ }
770
+ if (is_causal) { window_size_right = 0; }
771
+
772
+ if (!is_varlen_q) {
773
+ CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
774
+ } else {
775
+ CHECK_SHAPE(q, total_q, num_heads, head_size);
776
+ CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
777
+ }
778
+ if (!paged_KV) {
779
+ if (!is_varlen_k) {
780
+ CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size);
781
+ CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v);
782
+ } else {
783
+ CHECK_SHAPE(k, total_k, num_heads_k, head_size);
784
+ CHECK_SHAPE(v, total_k, num_heads_k, head_size_v);
785
+ CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
786
+ }
787
+ } else {
788
+ CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size);
789
+ CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v);
790
+ CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq);
791
+ }
792
+
793
+ if (seqused_q_.has_value()){
794
+ auto seqused_q = seqused_q_.value();
795
+ TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
796
+ CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
797
+ CHECK_SHAPE(seqused_q, batch_size);
798
+ }
799
+ if (seqused_k_.has_value()) {
800
+ auto seqused_k = seqused_k_.value();
801
+ TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
802
+ CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
803
+ CHECK_SHAPE(seqused_k, batch_size);
804
+ }
805
+
806
+ if (leftpad_k_.has_value()) {
807
+ auto leftpad_k = leftpad_k_.value();
808
+ TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
809
+ CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k);
810
+ CHECK_SHAPE(leftpad_k, batch_size);
811
+ }
812
+
813
+ // This is what we will template on
814
+ bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value();
815
+ #ifdef FLASHATTENTION_DISABLE_VARLEN
816
+ TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
817
+ #endif
818
+
819
+ int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8;
820
+ TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment));
821
+ TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment));
822
+
823
+ auto opts = q.options();
824
+ auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type;
825
+ at::Tensor out;
826
+ if (out_.has_value()) {
827
+ out = out_.value();
828
+ TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16");
829
+ CHECK_DEVICE(out);
830
+ TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
831
+ if (!is_varlen_q) {
832
+ CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v);
833
+ } else {
834
+ CHECK_SHAPE(out, total_q, num_heads, head_size_v);
835
+ }
836
+ } else {
837
+ out = !is_varlen_q
838
+ ? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type))
839
+ : torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type));
840
+ }
841
+
842
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
843
+ int const head_size_rounded = round_up_headdim(head_size);
844
+ int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdimv(head_size_v);
845
+ int const seqlen_q_rounded = round_multiple(seqlen_q, 128);
846
+ int const seqlen_k_rounded = round_multiple(seqlen_k, 128);
847
+
848
+ // Otherwise the kernel will be launched from cuda:0 device
849
+ // Cast to char to avoid compiler warning about narrowing
850
+ at::cuda::CUDAGuard device_guard{(char)q.get_device()};
851
+
852
+ at::Tensor softmax_lse;
853
+ if (!is_varlen_q) {
854
+ softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
855
+ } else {
856
+ softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
857
+ }
858
+
859
+ Flash_fwd_params params;
860
+ set_params_fprop(params,
861
+ batch_size,
862
+ seqlen_q, seqlen_k,
863
+ seqlen_q_rounded, seqlen_k_rounded,
864
+ num_heads, num_heads_k,
865
+ head_size, head_size_rounded,
866
+ q, k, v, out,
867
+ !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
868
+ !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
869
+ seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
870
+ seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
871
+ softmax_lse.data_ptr(),
872
+ /*p_dropout=*/0.f,
873
+ softmax_scale,
874
+ window_size_left,
875
+ window_size_right,
876
+ attention_chunk,
877
+ softcap,
878
+ sm_margin);
879
+ params.total_q = total_q;
880
+ params.total_k = total_k;
881
+ params.b_k = batch_size_k;
882
+ params.dv = head_size_v;
883
+ params.dv_rounded = head_size_v_rounded;
884
+ if (leftpad_k_.has_value()) { // This needs to be set before get_pagedkv_tma
885
+ params.leftpad_k = static_cast<int *>(leftpad_k_.value().data_ptr());
886
+ }
887
+ if (paged_KV) {
888
+ params.page_table = page_table.data_ptr<int>();
889
+ params.page_table_batch_stride = page_table.stride(0);
890
+ }
891
+ params.page_size = page_size;
892
+ params.num_pages = num_pages;
893
+
894
+ if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma
895
+ at::Tensor k_new, v_new;
896
+ TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in");
897
+ TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in");
898
+ TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache");
899
+ at::Tensor cu_seqlens_k_new;
900
+ bool const is_varlen_k_new = cu_seqlens_k_new_.has_value();
901
+ if (is_varlen_k_new) {
902
+ cu_seqlens_k_new = cu_seqlens_k_new_.value();
903
+ CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new);
904
+ TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, "cu_seqlens_k_new must have dtype torch.int32");
905
+ }
906
+ k_new = k_new_.value();
907
+ v_new = v_new_.value();
908
+ TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query");
909
+ TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query");
910
+ CHECK_DEVICE(k_new); CHECK_DEVICE(v_new);
911
+ TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension");
912
+ TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension");
913
+ // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new
914
+ int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0;
915
+ int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0);
916
+ if (!is_varlen_k_new) {
917
+ CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size);
918
+ CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v);
919
+ } else {
920
+ CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size);
921
+ CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v);
922
+ CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1);
923
+ }
924
+ params.seqlen_knew = seqlen_k_new;
925
+ params.total_knew = total_k_new;
926
+ params.knew_ptr = k_new.data_ptr();
927
+ params.vnew_ptr = v_new.data_ptr();
928
+ // All strides are in elements, not bytes.
929
+ params.knew_row_stride = k_new.stride(-3);
930
+ params.vnew_row_stride = v_new.stride(-3);
931
+ params.knew_head_stride = k_new.stride(-2);
932
+ params.vnew_head_stride = v_new.stride(-2);
933
+ if (!is_varlen_k_new) {
934
+ params.knew_batch_stride = k_new.stride(0);
935
+ params.vnew_batch_stride = v_new.stride(0);
936
+ }
937
+ if (is_varlen_k_new) {
938
+ params.cu_seqlens_knew = static_cast<int*>(cu_seqlens_k_new.data_ptr());
939
+ }
940
+ }
941
+
942
+ // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel
943
+ bool const use_dynamic_split = is_varlen && params.b <= 992;
944
+ // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it
945
+ params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);
946
+
947
+ params.pagedkv_tma = get_pagedkv_tma(params);
948
+ params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
949
+ // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide
950
+ params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
951
+
952
+ // This needs to be set after get_num_splits
953
+ at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic
954
+ // We don't use the persistent scheduler if Split and not Varlen
955
+ bool const scheduler_needs_semaphore = params.arch >= 90
956
+ ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen)
957
+ : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1));
958
+ if (scheduler_needs_semaphore || use_dynamic_split) {
959
+ int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b;
960
+ params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value();
961
+ if (scheduler_metadata_.has_value()) {
962
+ at::Tensor scheduler_metadata = scheduler_metadata_.value();
963
+ CHECK_DEVICE(scheduler_metadata);
964
+ CHECK_SHAPE(scheduler_metadata, metadata_size);
965
+ CHECK_CONTIGUOUS(scheduler_metadata);
966
+ TORCH_CHECK(scheduler_metadata.dtype() == torch::kInt32, "scheduler_metadata must have dtype int32");
967
+ tile_count_semaphore = scheduler_metadata;
968
+ } else {
969
+ tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32));
970
+ }
971
+ if (scheduler_needs_semaphore && !use_dynamic_split) {
972
+ tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing
973
+ }
974
+ params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr<int>() : nullptr;
975
+ params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
976
+ }
977
+
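A small sketch of the scheduler metadata sizing above, under the assumption (consistent with the pointer arithmetic in the code) that element 0 holds the tile-count semaphore and the following params.b elements hold per-batch dynamic split counts.

#include <cstdio>

int main() {
    bool const scheduler_needs_semaphore = true;
    bool const use_dynamic_split = true;
    int const b = 4;  // batch size
    int const metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * b;
    std::printf("metadata holds %d int32 values (1 semaphore + %d split counters)\n",
                metadata_size, use_dynamic_split ? b : 0);  // 5 = 1 + 4
}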
978
+ if (q_v_.has_value()) {
979
+ TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64");
980
+ TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
981
+ "q_v is only supported for fp16 and bf16 data type");
982
+ TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs");
983
+ at::Tensor q_v = q_v_.value();
984
+ TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query");
985
+ CHECK_DEVICE(q_v);
986
+ TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension");
987
+ if (!is_varlen_q) {
988
+ CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v);
989
+ } else {
990
+ CHECK_SHAPE(q_v, total_q, num_heads, head_size_v);
991
+ }
992
+ params.qv_ptr = q_v.data_ptr();
993
+ // All strides are in elements, not bytes.
994
+ params.qv_row_stride = q_v.stride(-3);
995
+ params.qv_head_stride = q_v.stride(-2);
996
+ if (!is_varlen_q) {
997
+ params.qv_batch_stride = q_v.stride(0);
998
+ }
999
+ }
1000
+
1001
+ if (rotary_cos_.has_value()) {
1002
+ TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
1003
+ auto rotary_cos = rotary_cos_.value();
1004
+ CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos);
1005
+ params.rotary_dim = rotary_cos.size(1) * 2;
1006
+ TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
1007
+ TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
1008
+ const int seqlen_ro = rotary_cos.size(0);
1009
+ if (paged_KV) {
1010
+ TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
1011
+ }
1012
+ CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
1013
+ TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query");
1014
+
1015
+ TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
1016
+ auto rotary_sin = rotary_sin_.value();
1017
+ CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin);
1018
+ CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
1019
+ TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_sin must have the same dtype as query");
1020
+ params.rotary_cos_ptr = rotary_cos.data_ptr();
1021
+ params.rotary_sin_ptr = rotary_sin.data_ptr();
1022
+ params.is_rotary_interleaved = is_rotary_interleaved;
1023
+ if (seqlens_rotary_.has_value()) {
1024
+ at::Tensor seqlens_rotary = seqlens_rotary_.value();
1025
+ CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary);
1026
+ TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32");
1027
+ CHECK_SHAPE(seqlens_rotary, batch_size);
1028
+ params.seqlens_rotary = seqlens_rotary.data_ptr<int>();
1029
+ }
1030
+ } else {
1031
+ params.rotary_dim = 0;
1032
+ }
1033
+
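A minimal CPU reference for the two rotary layouts that is_rotary_interleaved chooses between. This is a sketch of the standard rotary formula with hypothetical helper names, not the CUDA code path; as checked above, cos/sin carry rotary_dim / 2 values per position.

#include <vector>

// x holds rotary_dim elements for one position; cos_v/sin_v hold rotary_dim / 2 each.
void apply_rotary_sketch(std::vector<float>& x,
                         std::vector<float> const& cos_v,
                         std::vector<float> const& sin_v,
                         bool interleaved) {
    int const half = static_cast<int>(cos_v.size());
    for (int i = 0; i < half; ++i) {
        // Interleaved pairs (x0,x1),(x2,x3),...; otherwise element i pairs with i + half.
        int const i0 = interleaved ? 2 * i : i;
        int const i1 = interleaved ? 2 * i + 1 : i + half;
        float const a = x[i0], b = x[i1];
        x[i0] = a * cos_v[i] - b * sin_v[i];
        x[i1] = a * sin_v[i] + b * cos_v[i];
    }
}

int main() {
    std::vector<float> x = {1.f, 0.f, 0.f, 1.f};  // rotary_dim = 4
    std::vector<float> c = {0.8f, 0.6f}, s = {0.6f, 0.8f};
    apply_rotary_sketch(x, c, s, /*interleaved=*/false);
}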
1034
+ if (kv_batch_idx_.has_value()) {
1035
+ auto kv_batch_idx = kv_batch_idx_.value();
1036
+ CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx);
1037
+ TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32");
1038
+ params.kv_batch_idx = reinterpret_cast<int *>(kv_batch_idx.data_ptr());
1039
+ }
1040
+
1041
+ at::Tensor out_accum, softmax_lse_accum;
1042
+ auto outaccum_type = at::ScalarType::Float;
1043
+ if (params.num_splits > 1) {
1044
+ TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported");
1045
+ if (!is_varlen_q) {
1046
+ out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(outaccum_type));
1047
+ softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
1048
+ params.oaccum_batch_stride = out_accum.stride(1);
1049
+ params.lseaccum_batch_stride = softmax_lse_accum.stride(1);
1050
+ } else {
1051
+ out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type));
1052
+ softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat));
1053
+ }
1054
+ params.is_fp32 = false;
1055
+ params.oaccum_ptr = out_accum.data_ptr();
1056
+ params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
1057
+ params.oaccum_split_stride = out_accum.stride(0);
1058
+ params.oaccum_row_stride = out_accum.stride(-2);
1059
+ params.oaccum_head_stride = out_accum.stride(-3);
1060
+ params.lseaccum_split_stride = softmax_lse_accum.stride(0);
1061
+ params.lseaccum_head_stride = softmax_lse_accum.stride(-2);
1062
+ }
1063
+
1064
+ if (q_type == at::ScalarType::Float8_e4m3fn) {
1065
+ if (q_descale_.has_value()) {
1066
+ auto q_descale = q_descale_.value();
1067
+ CHECK_DEVICE(q_descale);
1068
+ CHECK_SHAPE(q_descale, batch_size, num_heads_k);
1069
+ params.q_descale_ptr = q_descale.data_ptr<float>();
1070
+ params.q_descale_batch_stride = q_descale.stride(0);
1071
+ params.q_descale_head_stride = q_descale.stride(1);
1072
+ } else {
1073
+ params.q_descale_ptr = nullptr;
1074
+ }
1075
+ if (k_descale_.has_value()) {
1076
+ auto k_descale = k_descale_.value();
1077
+ CHECK_DEVICE(k_descale);
1078
+ CHECK_SHAPE(k_descale, batch_size, num_heads_k);
1079
+ params.k_descale_ptr = k_descale.data_ptr<float>();
1080
+ params.k_descale_batch_stride = k_descale.stride(0);
1081
+ params.k_descale_head_stride = k_descale.stride(1);
1082
+ } else {
1083
+ params.k_descale_ptr = nullptr;
1084
+ }
1085
+ if (v_descale_.has_value()) {
1086
+ auto v_descale = v_descale_.value();
1087
+ CHECK_DEVICE(v_descale);
1088
+ CHECK_SHAPE(v_descale, batch_size, num_heads_k);
1089
+ params.v_descale_ptr = v_descale.data_ptr<float>();
1090
+ params.v_descale_batch_stride = v_descale.stride(0);
1091
+ params.v_descale_head_stride = v_descale.stride(1);
1092
+ } else {
1093
+ params.v_descale_ptr = nullptr;
1094
+ }
1095
+ }
1096
+
1097
+ #ifdef FLASHATTENTION_DISABLE_LOCAL
1098
+ TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
1099
+ #endif
1100
+ #ifdef FLASHATTENTION_DISABLE_SOFTCAP
1101
+ TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
1102
+ #endif
1103
+ #ifdef FLASHATTENTION_DISABLE_SPLIT
1104
+ TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits.");
1105
+ #endif
1106
+ #ifdef FLASHATTENTION_DISABLE_PACKGQA
1107
+ TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa.");
1108
+ #endif
1109
+ #ifdef FLASHATTENTION_DISABLE_PAGEDKV
1110
+ TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV.");
1111
+ #endif
1112
+ #ifdef FLASHATTENTION_DISABLE_APPENDKV
1113
+ TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV.");
1114
+ #endif
1115
+
1116
+ if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) {
1117
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
1118
+ run_mha_fwd(params, stream);
1119
+ if (params.num_splits > 1) {
1120
+ if (out_type == at::ScalarType::BFloat16) {
1121
+ // We want the output in BF16; otherwise fwd_combine would output FP16.
1122
+ params.is_bf16 = true;
1123
+ }
1124
+ // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1
1125
+ // and seqlen = total_q, and don't need to dispatch to Varlen there.
1126
+ // However, with dynamic split, each row needs to know which batch it belongs to
1127
+ // to read the number of splits, so we just use the varlen version of combine kernel.
1128
+ // if (is_varlen_q && !seqused_q_.has_value()) {
1129
+ // if (is_varlen_q) {
1130
+ // params.b = 1;
1131
+ // params.seqlen_q = total_q;
1132
+ // }
1133
+ // This will zero out the semaphore if needed
1134
+ run_mha_fwd_combine(params, stream, true /*enable_pdl*/);
1135
+ } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) {
1136
+ // need to zero out the semaphore in this case
1137
+ tile_count_semaphore.index({torch::indexing::Slice(0, 1)}).zero_();
1138
+ }
1139
+ } else if (total_q > 0 && num_heads_k > 0) {
1140
+ // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
1141
+ out.zero_();
1142
+ softmax_lse.fill_(std::numeric_limits<float>::infinity());
1143
+ }
1144
+
1145
+ // return {out, softmax_lse};
1146
+ return {out, softmax_lse, out_accum, softmax_lse_accum};
1147
+ }
1148
+
1149
+ #ifdef FLASHATTENTION_DISABLE_BACKWARD
1150
+ void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
1151
+ TORCH_CHECK(false, "Flash-Attention was built with backward disabled");
1152
+ }
1153
+ #else
1154
+ template <int Arch, bool Has_softcap>
1155
+ void run_mha_bwd_constexpr(Flash_bwd_params &params, cudaStream_t stream) {
1156
+ if (!params.is_bf16) {
1157
+ #ifndef FLASHATTENTION_DISABLE_FP16
1158
+ #ifndef FLASHATTENTION_DISABLE_HDIM64
1159
+ if (params.d_rounded == 64) { return run_mha_bwd_<Arch, cutlass::half_t, 64, Has_softcap>(params, stream); }
1160
+ #endif
1161
+ #ifndef FLASHATTENTION_DISABLE_HDIM96
1162
+ if (params.d_rounded == 96) { return run_mha_bwd_<Arch, cutlass::half_t, 96, Has_softcap>(params, stream); }
1163
+ #endif
1164
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
1165
+ if (params.d_rounded == 128) { return run_mha_bwd_<Arch, cutlass::half_t, 128, Has_softcap>(params, stream); }
1166
+ #endif
1167
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
1168
+ if (params.d_rounded == 192) { return run_mha_bwd_<Arch, cutlass::half_t, 192, Has_softcap>(params, stream); }
1169
+ #endif
1170
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
1171
+ if (params.d_rounded == 256) { return run_mha_bwd_<Arch, cutlass::half_t, 256, Has_softcap>(params, stream); }
1172
+ #endif
1173
+ #else
1174
+ TORCH_CHECK(false, "This flash attention build does not support FP16.");
1175
+ #endif
1176
+ } else {
1177
+ #ifndef FLASHATTENTION_DISABLE_HDIM64
1178
+ if (params.d_rounded == 64) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 64, Has_softcap>(params, stream); }
1179
+ #endif
1180
+ #ifndef FLASHATTENTION_DISABLE_HDIM96
1181
+ if (params.d_rounded == 96) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 96, Has_softcap>(params, stream); }
1182
+ #endif
1183
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
1184
+ if (params.d_rounded == 128) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 128, Has_softcap>(params, stream); }
1185
+ #endif
1186
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
1187
+ if (params.d_rounded == 192) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 192, Has_softcap>(params, stream); }
1188
+ #endif
1189
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
1190
+ if (params.d_rounded == 256) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 256, Has_softcap>(params, stream); }
1191
+ #endif
1192
+ }
1193
+ }
1194
+
1195
+ void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
1196
+ // FP16_SWITCH(!params.is_bf16, [&] {
1197
+ // HEADDIM_SWITCH(params.d, [&] {
1198
+ // run_mha_bwd_<elem_type, kHeadDim>(params, stream);
1199
+ // });
1200
+ // });
1201
+ ARCH_SWITCH(params.arch, Arch, [&] {
1202
+ SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
1203
+ run_mha_bwd_constexpr<Arch, Has_softcap>(params, stream);
1204
+ });
1205
+ });
1206
+ }
1207
+ #endif
1208
+
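The dispatcher above turns a runtime d_rounded into one of a fixed set of template instantiations. A generic sketch of that pattern follows; run_kernel_sketch and dispatch_headdim are hypothetical names standing in for run_mha_bwd_<Arch, T, kHeadDim, Has_softcap>.

#include <cstdio>

template <int kHeadDim>
void run_kernel_sketch() { std::printf("instantiation for head dim %d\n", kHeadDim); }

void dispatch_headdim(int d_rounded) {
    // Runtime value -> compile-time constant, one branch per instantiated head dim.
    if (d_rounded == 64)  { return run_kernel_sketch<64>(); }
    if (d_rounded == 96)  { return run_kernel_sketch<96>(); }
    if (d_rounded == 128) { return run_kernel_sketch<128>(); }
    if (d_rounded == 192) { return run_kernel_sketch<192>(); }
    if (d_rounded == 256) { return run_kernel_sketch<256>(); }
}

int main() { dispatch_headdim(128); }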
1209
+
1210
+ // b: batch_size
1211
+ // s_q: seqlen_q
1212
+ // s_k: seqlen_k
1213
+ // h: num_heads
1214
+ // h_k: num_heads_k
1215
+ // d: head_size
1216
+ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
1217
+ at::Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
1218
+ at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
1219
+ at::Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
1220
+ at::Tensor v, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k
1221
+ at::Tensor out, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
1222
+ at::Tensor softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q
1223
+ std::optional<at::Tensor> dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
1224
+ std::optional<at::Tensor> dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
1225
+ std::optional<at::Tensor> dv_, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k
1226
+ std::optional<at::Tensor> cu_seqlens_q_, // b+1
1227
+ std::optional<at::Tensor> cu_seqlens_k_, // b+1
1228
+ std::optional<at::Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
1229
+ std::optional<at::Tensor> seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
1230
+ std::optional<int64_t> max_seqlen_q_,
1231
+ std::optional<int64_t> max_seqlen_k_,
1232
+ std::optional<double> softmax_scale_,
1233
+ bool is_causal,
1234
+ int64_t window_size_left,
1235
+ int64_t window_size_right,
1236
+ double softcap,
1237
+ bool deterministic,
1238
+ int64_t sm_margin
1239
+ ) {
1240
+
1241
+ #ifdef FLASHATTENTION_DISABLE_BACKWARD
1242
+ TORCH_CHECK(false, "This flash attention build does not support backward.");
1243
+ #endif
1244
+
1245
+ auto dprops = at::cuda::getCurrentDeviceProperties();
1246
+ bool is_sm8x = dprops->major >= 8;
1247
+ TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
1248
+
1249
+ auto q_type = q.dtype();
1250
+ TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16,
1251
+ "FlashAttention only support fp16 and bf16 data type");
1252
+ TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype");
1253
+ TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype");
1254
+ TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype");
1255
+ TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype");
1256
+
1257
+ CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
1258
+ CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
1259
+
1260
+ TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1261
+ TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1262
+ TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1263
+ TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
1264
+ TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
1265
+
1266
+ at::Tensor cu_seqlens_q;
1267
+ bool const is_varlen_q = cu_seqlens_q_.has_value();
1268
+ if (is_varlen_q) {
1269
+ cu_seqlens_q = cu_seqlens_q_.value();
1270
+ CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
1271
+ TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
1272
+ TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
1273
+ }
1274
+ at::Tensor cu_seqlens_k;
1275
+ bool const is_varlen_k = cu_seqlens_k_.has_value();
1276
+ if (is_varlen_k) {
1277
+ cu_seqlens_k = cu_seqlens_k_.value();
1278
+ CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
1279
+ TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
1280
+ TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
1281
+ }
1282
+ // This is what we will template on
1283
+ bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value();
1284
+ #ifdef FLASHATTENTION_DISABLE_VARLEN
1285
+ TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
1286
+ #endif
1287
+
1288
+ auto const sizes = q.sizes();
1289
+ int const batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
1290
+ int const seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
1291
+ int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
1292
+ int const num_heads = q.size(-2);
1293
+ int const head_size = q.size(-1);
1294
+ int const head_size_v = v.size(-1);
1295
+ int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value();
1296
+ int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
1297
+ int const num_heads_k = k.size(-2);
1298
+ TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
1299
+ TORCH_CHECK(head_size_v % 8 == 0, "head_size_v should be a multiple of 8");
1300
+ int const max_headdim = get_max_headdim();
1301
+ TORCH_CHECK(std::max(head_size, head_size_v) <= max_headdim, "FlashAttention backward only supports head dimension at most " + std::to_string(max_headdim));
1302
+ TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
1303
+ double softmax_scale = 1.0 / sqrt(double(head_size));
1304
+ if (softmax_scale_.has_value()) {
1305
+ softmax_scale = softmax_scale_.value();
1306
+ }
1307
+
1308
+ // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
1309
+ if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
1310
+ if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
1311
+ if (is_causal) { window_size_right = 0; }
1312
+ // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_bprop will set params.is_causal=true.
1313
+ // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA).
1314
+ is_causal = window_size_left < 0 && window_size_right == 0;
1315
+
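A worked example of the window normalization above, with assumed values: a window that already covers the whole sequence collapses to -1, and is_causal is re-derived as (left < 0 && right == 0) so it matches what set_params_dgrad will compute.

#include <cstdio>

int main() {
    int const seqlen_q = 512, seqlen_k = 512;
    int window_size_left = 1024, window_size_right = 16;
    bool is_causal = false;
    if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }    // covers everything
    if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }  // stays 16 here
    if (is_causal) { window_size_right = 0; }
    is_causal = window_size_left < 0 && window_size_right == 0;         // still false: local attention
    std::printf("left=%d right=%d is_causal=%d\n", window_size_left, window_size_right, int(is_causal));
}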
1316
+ int const arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
1317
+ int const head_size_rounded = round_up_headdim(std::max(head_size, head_size_v));
1318
+ int const head_size_v_rounded = head_size_rounded;
1319
+ // Very important that these match the kernel configs
1320
+ bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal;
1321
+ int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128)
1322
+ : (head_size_rounded <= 96 ? 64
1323
+ : (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80)
1324
+ : 64));
1325
+ int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64;
1326
+ int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32;
1327
+ int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);
1328
+ int const kBlockN_sm90 = head_size_rounded <= 128
1329
+ ? 128
1330
+ : (head_size_rounded <= 192 ? 96 : 80);
1331
+ int const kBlockN_sm80 = head_size_rounded <= 128
1332
+ ? 128
1333
+ : (head_size_rounded <= 192 ? 80 : 64);
1334
+ int const kBlockN_sm86 = head_size_rounded <= 64 ? 128
1335
+ : (head_size_rounded <= 96 ? 128
1336
+ : (head_size_rounded <= 128 ? 96
1337
+ : (head_size_rounded <= 192 ? 64 : 64)));
1338
+ int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80);
1339
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
1340
+ int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
1341
+ int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN);
1342
+ int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM);
1343
+ int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN);
1344
+
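A worked example (values assumed) of the varlen padding above: each batch element can add at most one partially filled kBlockM tile, hence the batch_size * kBlockM pad before rounding.

#include <cstdio>

int main() {
    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    int const total_q = 1000, batch_size = 3, kBlockM = 64;
    int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM);
    std::printf("total_q_padded_rounded = %d\n", total_q_padded_rounded);  // 1216
}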
1345
+ if (!is_varlen_q) {
1346
+ CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
1347
+ CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v);
1348
+ CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_v);
1349
+ } else {
1350
+ CHECK_SHAPE(q, total_q, num_heads, head_size);
1351
+ CHECK_SHAPE(out, total_q, num_heads, head_size_v);
1352
+ CHECK_SHAPE(dout, total_q, num_heads, head_size_v);
1353
+ CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
1354
+ }
1355
+ if (!is_varlen_k) {
1356
+ CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
1357
+ CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_v);
1358
+ } else {
1359
+ CHECK_SHAPE(k, total_k, num_heads_k, head_size);
1360
+ CHECK_SHAPE(v, total_k, num_heads_k, head_size_v);
1361
+ CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
1362
+ }
1363
+
1364
+ if (seqused_q_.has_value()){
1365
+ auto seqused_q = seqused_q_.value();
1366
+ TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
1367
+ CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
1368
+ CHECK_SHAPE(seqused_q, batch_size);
1369
+ }
1370
+ if (seqused_k_.has_value()){
1371
+ auto seqused_k = seqused_k_.value();
1372
+ TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
1373
+ CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
1374
+ CHECK_SHAPE(seqused_k, batch_size);
1375
+ }
1376
+
1377
+ at::Tensor dq, dk, dv;
1378
+ if (dq_.has_value()) {
1379
+ dq = dq_.value();
1380
+ TORCH_CHECK(dq.dtype() == q_type, "dq must have the same dtype as q");
1381
+ CHECK_DEVICE(dq);
1382
+ TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
1383
+ if (!is_varlen_q) {
1384
+ CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
1385
+ } else {
1386
+ CHECK_SHAPE(dq, total_q, num_heads, head_size);
1387
+ }
1388
+ } else {
1389
+ dq = torch::empty_like(q);
1390
+ }
1391
+ if (dk_.has_value()) {
1392
+ dk = dk_.value();
1393
+ TORCH_CHECK(dk.dtype() == q_type, "dk must have the same dtype as q");
1394
+ CHECK_DEVICE(dk);
1395
+ TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
1396
+ if (!is_varlen_k) {
1397
+ CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
1398
+ } else {
1399
+ CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
1400
+ }
1401
+ } else {
1402
+ dk = torch::empty_like(k);
1403
+ }
1404
+ if (dv_.has_value()) {
1405
+ dv = dv_.value();
1406
+ TORCH_CHECK(dv.dtype() == q_type, "dv must have the same dtype as q");
1407
+ CHECK_DEVICE(dv);
1408
+ TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
1409
+ if (!is_varlen_k) {
1410
+ CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_v);
1411
+ } else {
1412
+ CHECK_SHAPE(dv, total_k, num_heads_k, head_size_v);
1413
+ }
1414
+ } else {
1415
+ dv = torch::empty_like(v);
1416
+ }
1417
+
1418
+ // Otherwise the kernel will be launched from cuda:0 device
1419
+ // Cast to char to avoid compiler warning about narrowing
1420
+ at::cuda::CUDAGuard device_guard{(char)q.get_device()};
1421
+
1422
+ auto opts = q.options();
1423
+ // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
1424
+ at::Tensor softmax_d, softmax_lse_log2;
1425
+ if (!is_varlen) {
1426
+ // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
1427
+ softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
1428
+ softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
1429
+ } else {
1430
+ softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
1431
+ softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
1432
+ }
1433
+ at::Tensor dq_accum, dk_accum, dv_accum;
1434
+ if (!is_varlen) {
1435
+ dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, opts.dtype(at::kFloat));
1436
+ } else {
1437
+ dq_accum = torch::empty({num_heads, total_q_padded_rounded * head_size_rounded}, opts.dtype(at::kFloat));
1438
+ }
1439
+ if (num_heads_k != num_heads) { // MQA / GQA
1440
+ if (!is_varlen) {
1441
+ dk_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat));
1442
+ dv_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_v_rounded}, opts.dtype(at::kFloat));
1443
+ } else {
1444
+ dk_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
1445
+ dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_v_rounded}, opts.dtype(at::kFloat));
1446
+ }
1447
+ }
1448
+
1449
+ Flash_bwd_params params;
1450
+ set_params_dgrad(params,
1451
+ batch_size,
1452
+ seqlen_q, seqlen_k,
1453
+ seqlen_q_rounded, seqlen_k_rounded,
1454
+ num_heads, num_heads_k,
1455
+ head_size, head_size_rounded,
1456
+ q, k, v, out,
1457
+ dout, dq, dk, dv,
1458
+ !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
1459
+ !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
1460
+ seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
1461
+ seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
1462
+ dq_accum.data_ptr(),
1463
+ num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr,
1464
+ num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr,
1465
+ softmax_lse.data_ptr(),
1466
+ softmax_d.data_ptr(),
1467
+ /*p_dropout=*/0.f,
1468
+ softmax_scale,
1469
+ window_size_left,
1470
+ window_size_right,
1471
+ 0, // attention_chunk
1472
+ softcap,
1473
+ deterministic,
1474
+ sm_margin);
1475
+ params.total_q = total_q;
1476
+ params.total_k = total_k;
1477
+ params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
1478
+ params.dv = head_size_v;
1479
+ params.dv_rounded = head_size_v_rounded;
1480
+
1481
+ // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
1482
+ // params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
1483
+ // Will be zero'ed out in the backward preprocess kernel
1484
+ at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
1485
+ params.dq_semaphore = dq_semaphore.data_ptr<int>();
1486
+ if (num_heads_k != num_heads && params.deterministic) {
1487
+ // TODO: do we need to zero them out?
1488
+ at::Tensor dk_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
1489
+ at::Tensor dv_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
1490
+ params.dk_semaphore = dk_semaphore.data_ptr<int>();
1491
+ params.dv_semaphore = dv_semaphore.data_ptr<int>();
1492
+ }
1493
+
1494
+ #ifdef FLASHATTENTION_DISABLE_LOCAL
1495
+ TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
1496
+ #endif
1497
+ #ifdef FLASHATTENTION_DISABLE_SOFTCAP
1498
+ TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
1499
+ #endif
1500
+
1501
+ if (total_q > 0 && total_k > 0 && num_heads_k > 0) {
1502
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
1503
+ run_mha_bwd(params, stream);
1504
+ } else if (total_k > 0 && num_heads_k > 0) {
1505
+ // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
1506
+ dk.zero_();
1507
+ dv.zero_();
1508
+ softmax_d.zero_();
1509
+ } else if (total_q > 0 && num_heads_k > 0) {
1510
+ dq.zero_();
1511
+ softmax_d.zero_();
1512
+ }
1513
+
1514
+ return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
1515
+ }
1516
+
1517
+ std::tuple<at::Tensor, at::Tensor>
1518
+ mha_combine(at::Tensor out_partial, // num_splits x batch_size x seqlen x num_heads x head_size
1519
+ at::Tensor lse_partial, // num_splits x batch_size x seqlen x num_heads
1520
+ std::optional<at::Tensor> out_, // batch_size x seqlen x num_heads x head_size
1521
+ std::optional<at::ScalarType> out_dtype_
1522
+ ) {
1523
+
1524
+ auto dprops = at::cuda::getCurrentDeviceProperties();
1525
+ bool is_sm8x = dprops->major >= 8;
1526
+ TORCH_CHECK(is_sm8x, "Attention combine function only supports Ampere GPUs or newer.");
1527
+
1528
+ auto out_partial_type = out_partial.scalar_type();
1529
+ TORCH_CHECK(out_partial_type == at::ScalarType::Float, "Attention combine function only supports fp32 data type");
1530
+ TORCH_CHECK(lse_partial.scalar_type() == at::ScalarType::Float, "Attention combine function only supports fp32 data type");
1531
+
1532
+ CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial);
1533
+
1534
+ TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1535
+ TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor must be contiguous in the seqlen dimension");
1536
+
1537
+ const auto sizes = out_partial.sizes();
1538
+
1539
+ const int num_splits = sizes[0];
1540
+ const int batch_size = sizes[1];
1541
+ const int seqlen = sizes[2];
1542
+ const int num_heads = sizes[3];
1543
+ const int head_size_og = sizes[4];
1544
+ TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256");
1545
+
1546
+ CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og);
1547
+ CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads);
1548
+
1549
+ int const alignment = 4;
1550
+ at::Tensor out_partial_padded;
1551
+ auto pad = [](at::Tensor x, int alignment) {
1552
+ return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment}));
1553
+ };
1554
+ out_partial_padded = pad(out_partial, alignment);
1555
+
1556
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
1557
+ const int head_size = round_multiple(head_size_og, alignment);
1558
+
1559
+ auto opts = out_partial.options();
1560
+ at::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type());
1561
+ TORCH_CHECK(out_type == at::ScalarType::Float || out_type == at::ScalarType::BFloat16 || out_type == at::ScalarType::Half, "Output type must be FP32, FP16 or BF16");
1562
+ at::Tensor out;
1563
+ if (out_.has_value()) {
1564
+ out = out_.value();
1565
+ TORCH_CHECK(out.scalar_type() == out_type);
1566
+ CHECK_DEVICE(out);
1567
+ TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
1568
+ CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og);
1569
+ if (head_size_og % alignment != 0) {
1570
+ out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
1571
+ }
1572
+ } else {
1573
+ out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
1574
+ }
1575
+
1576
+ // Otherwise the kernel will be launched from cuda:0 device
1577
+ // Cast to char to avoid compiler warning about narrowing
1578
+ at::cuda::CUDAGuard device_guard{(char)out_partial.get_device()};
1579
+
1580
+ auto softmax_lse = torch::empty({batch_size, num_heads, seqlen}, opts.dtype(at::kFloat)).transpose(1, 2);
1581
+
1582
+ Flash_fwd_params params {}; // Need to reset the params to set everything to zero
1583
+ params.is_fp32 = out_type == at::ScalarType::Float;
1584
+ params.is_bf16 = out_type == at::ScalarType::BFloat16;
1585
+ params.oaccum_ptr = out_partial_padded.data_ptr();
1586
+ params.softmax_lseaccum_ptr = lse_partial.data_ptr();
1587
+ params.o_ptr = out.data_ptr();
1588
+ params.softmax_lse_ptr = softmax_lse.data_ptr();
1589
+ params.b = batch_size;
1590
+ params.h = num_heads;
1591
+ params.seqlen_q = seqlen;
1592
+ params.dv = head_size;
1593
+ params.num_splits = num_splits;
1594
+ params.oaccum_split_stride = out_partial_padded.stride(0);
1595
+ params.oaccum_row_stride = out_partial_padded.stride(2);
1596
+ params.oaccum_head_stride = out_partial_padded.stride(3);
1597
+ params.oaccum_batch_stride = out_partial_padded.stride(1);
1598
+ params.lseaccum_split_stride = lse_partial.stride(0);
1599
+ params.lseaccum_head_stride = lse_partial.stride(3);
1600
+ params.lseaccum_batch_stride = lse_partial.stride(1);
1601
+ params.o_row_stride = out.stride(1);
1602
+ params.o_head_stride = out.stride(2);
1603
+ params.o_batch_stride = out.stride(0);
1604
+ params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
1605
+
1606
+ if (seqlen > 0 && batch_size > 0) {
1607
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
1608
+ run_mha_fwd_combine(params, stream, false /*enable_pdl*/);
1609
+ }
1610
+
1611
+ at::Tensor out_padded = out;
1612
+ if (head_size_og % alignment != 0) {
1613
+ out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
1614
+ // if (out_.has_value()) { out_.value().copy_(out); }
1615
+ }
1616
+
1617
+ return {out, softmax_lse};
1618
+ }
1619
+
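Per output row, the combine step reduces num_splits partial (out, lse) pairs with log-sum-exp weighting. The sketch below is a CPU reference of that formula under the usual conventions, not the fwd_combine CUDA kernel.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // Partial results for one (batch, head, row) from num_splits = 3 splits.
    std::vector<float> lse = {1.2f, 0.3f, 2.0f};           // per-split log-sum-exp
    std::vector<std::vector<float>> out = {{0.1f, 0.2f},   // per-split partial outputs (dv = 2)
                                           {0.4f, 0.1f},
                                           {0.3f, 0.5f}};
    float lse_max = lse[0];
    for (float l : lse) lse_max = std::max(lse_max, l);
    float sum_exp = 0.f;
    for (float l : lse) sum_exp += std::exp(l - lse_max);
    float const lse_final = lse_max + std::log(sum_exp);   // combined log-sum-exp
    std::vector<float> out_final(out[0].size(), 0.f);
    for (size_t s = 0; s < lse.size(); ++s) {
        float const w = std::exp(lse[s] - lse_final);      // weight of split s
        for (size_t d = 0; d < out_final.size(); ++d) { out_final[d] += w * out[s][d]; }
    }
    std::printf("lse_final = %.4f, out = (%.4f, %.4f)\n", lse_final, out_final[0], out_final[1]);
}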
1620
+ #if 0
1621
+
1622
+ TORCH_LIBRARY(flash_attn_3, m) {
1623
+ m.def("fwd("
1624
+ "Tensor q,"
1625
+ "Tensor k,"
1626
+ "Tensor v,"
1627
+ "Tensor(k_new!)? k_new = None,"
1628
+ "Tensor(v_new!)? v_new = None,"
1629
+ "Tensor? q_v = None,"
1630
+ "Tensor(out!)? out = None,"
1631
+ "Tensor? cu_seqlens_q = None,"
1632
+ "Tensor? cu_seqlens_k = None,"
1633
+ "Tensor? cu_seqlens_k_new = None,"
1634
+ "Tensor? seqused_q = None,"
1635
+ "Tensor? seqused_k = None,"
1636
+ "int? max_seqlen_q = None,"
1637
+ "int? max_seqlen_k = None,"
1638
+ "Tensor? page_table = None,"
1639
+ "Tensor? kv_batch_idx = None,"
1640
+ "Tensor? leftpad_k = None,"
1641
+ "Tensor? rotary_cos = None,"
1642
+ "Tensor? rotary_sin = None,"
1643
+ "Tensor? seqlens_rotary = None,"
1644
+ "Tensor? q_descale = None,"
1645
+ "Tensor? k_descale = None,"
1646
+ "Tensor? v_descale = None,"
1647
+ "float? softmax_scale = None,"
1648
+ "bool is_causal = False,"
1649
+ "int window_size_left = -1,"
1650
+ "int window_size_right = -1,"
1651
+ "int attention_chunk = 0,"
1652
+ "float softcap = 0.0,"
1653
+ "bool is_rotary_interleaved = False,"
1654
+ "Tensor? scheduler_metadata = None,"
1655
+ "int num_splits = 0,"
1656
+ "bool? pack_gqa = None,"
1657
+ "int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)");
1658
+ m.def("bwd("
1659
+ "Tensor dout,"
1660
+ "Tensor q,"
1661
+ "Tensor k,"
1662
+ "Tensor v,"
1663
+ "Tensor out,"
1664
+ "Tensor softmax_lse,"
1665
+ "Tensor(dq!)? dq = None,"
1666
+ "Tensor(dk!)? dk = None,"
1667
+ "Tensor(dv!)? dv = None,"
1668
+ "Tensor? cu_seqlens_q = None,"
1669
+ "Tensor? cu_seqlens_k = None,"
1670
+ "Tensor? seqused_q = None,"
1671
+ "Tensor? seqused_k = None,"
1672
+ "int? max_seqlen_q = None,"
1673
+ "int? max_seqlen_k = None,"
1674
+ "float? softmax_scale = None,"
1675
+ "bool is_causal = False,"
1676
+ "int window_size_left = -1,"
1677
+ "int window_size_right = -1,"
1678
+ "float softcap = 0.0,"
1679
+ "bool deterministic = False,"
1680
+ "int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)");
1681
+ m.def("fwd_combine("
1682
+ "Tensor out_partial,"
1683
+ "Tensor lse_partial,"
1684
+ "Tensor(out!)? out = None,"
1685
+ "ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)");
1686
+ m.def("get_scheduler_metadata("
1687
+ "int batch_size,"
1688
+ "int max_seqlen_q,"
1689
+ "int max_seqlen_k,"
1690
+ "int num_heads,"
1691
+ "int num_heads_k,"
1692
+ "int headdim,"
1693
+ "int headdim_v,"
1694
+ "ScalarType qkv_dtype,"
1695
+ "Tensor seqused_k,"
1696
+ "Tensor? cu_seqlens_q = None,"
1697
+ "Tensor? cu_seqlens_k = None,"
1698
+ "Tensor? cu_seqlens_k_new = None,"
1699
+ "Tensor? seqused_q = None,"
1700
+ "Tensor? leftpad_k = None,"
1701
+ "int? page_size = None,"
1702
+ "int max_seqlen_k_new = 0,"
1703
+ "bool is_causal = False,"
1704
+ "int window_size_left = -1,"
1705
+ "int window_size_right = -1,"
1706
+ "int attention_chunk = 0,"
1707
+ "bool has_softcap = False,"
1708
+ "int num_splits = 0,"
1709
+ "bool? pack_gqa = None,"
1710
+ "int sm_margin = 0) -> Tensor");
1711
+ }
1712
+
1713
+ TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) {
1714
+ m.impl("fwd", &mha_fwd);
1715
+ m.impl("bwd", &mha_bwd);
1716
+ m.impl("fwd_combine", &mha_combine);
1717
+ m.impl("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata);
1718
+ }
1719
+
1720
+ #endif
flash-attn/flash_bwd_kernel_sm80.h ADDED
@@ -0,0 +1,173 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cute/tensor.hpp"
8
+
9
+ #include <cutlass/cutlass.h>
10
+ #include <cutlass/array.h>
11
+ #include <cutlass/numeric_types.h>
12
+ #include <cutlass/kernel_hardware_info.h>
13
+
14
+ #include "utils.h"
15
+
16
+ namespace flash {
17
+
18
+ using namespace cute;
19
+
20
+ template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
21
+ class FlashAttnBwdSm80 {
22
+
23
+ public:
24
+
25
+ // Type Aliases
26
+ static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
27
+ static constexpr bool Is_local = CollectiveMainloop_::Is_local;
28
+ static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
29
+ static constexpr bool Varlen = CollectiveMainloop_::Varlen;
30
+
31
+ // Mainloop derived types
32
+ using CollectiveMainloop = CollectiveMainloop_;
33
+ using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
34
+ using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;
35
+ using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;
36
+ using ArchTag = typename CollectiveMainloop::ArchTag;
37
+ using MainloopArguments = typename CollectiveMainloop::Arguments;
38
+ using MainloopParams = typename CollectiveMainloop::Params;
39
+ static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;
40
+
41
+ // Epilogue derived types
42
+ using CollectiveEpilogue = CollectiveEpilogue_;
43
+ using EpilogueArguments = typename CollectiveEpilogue::Arguments;
44
+ using EpilogueParams = typename CollectiveEpilogue::Params;
45
+
46
+ static_assert(ArchTag::kMinComputeCapability >= 80);
47
+
48
+ using TileScheduler = TileScheduler_;
49
+ using TileSchedulerArguments = typename flash::TileSchedulerArguments;
50
+ using TileSchedulerParams = typename TileScheduler::Params;
51
+
52
+ static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMmaSdP{}));
53
+ static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{}));
54
+ static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
55
+
56
+ // Kernel level shared memory storage
57
+ struct SharedStorage {
58
+ struct TensorStorage : cute::aligned_struct<128> {
59
+ union {
60
+ typename CollectiveMainloop::TensorStorage mainloop;
61
+ typename CollectiveEpilogue::TensorStorage epilogue;
62
+ };
63
+ } tensors;
64
+
65
+ alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
66
+
67
+ };
68
+
69
+ static constexpr int SharedStorageSize = sizeof(SharedStorage);
70
+
71
+ // Device side arguments
72
+ struct Arguments {
73
+ MainloopArguments mainloop{};
74
+ EpilogueArguments epilogue{};
75
+ cutlass::KernelHardwareInfo hw_info{};
76
+ TileSchedulerArguments scheduler{};
77
+ };
78
+
79
+ // Kernel entry point API
80
+ struct Params {
81
+ MainloopParams mainloop{};
82
+ EpilogueParams epilogue{};
83
+ cutlass::KernelHardwareInfo hw_info{};
84
+ TileSchedulerParams scheduler{};
85
+ };
86
+
87
+ //
88
+ // Methods
89
+ //
90
+
91
+ // Convert to underlying arguments. In this case, a simple copy for the aliased type.
92
+ static
93
+ Params
94
+ to_underlying_arguments(Arguments const& args) {
95
+ CUTLASS_TRACE_HOST("to_underlying_arguments():");
96
+
97
+ // Get SM count if needed, otherwise use user supplied SM count
98
+ int sm_count = args.hw_info.sm_count;
99
+ if (sm_count <= 0) {
100
+ CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
101
+ " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
102
+ sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
103
+ }
104
+
105
+ CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
106
+
107
+ cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
108
+ return {
109
+ CollectiveMainloop::to_underlying_arguments(args.mainloop),
110
+ CollectiveEpilogue::to_underlying_arguments(args.epilogue),
111
+ hw_info,
112
+ TileScheduler::to_underlying_arguments(args.scheduler)
113
+ };
114
+ }
115
+
116
+ // Computes the kernel launch grid shape based on runtime parameters
117
+ static dim3
118
+ get_grid_shape(Params const& params) {
119
+ return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
120
+ }
121
+
122
+ static dim3
123
+ get_block_shape() {
124
+ return dim3(MaxThreadsPerBlock, 1, 1);
125
+ }
126
+
127
+ CUTLASS_DEVICE
128
+ void
129
+ operator()(Params const& params, char* smem_buf) {
130
+
131
+ static constexpr int kBlockM = get<0>(TileShape_MNK{});
132
+ static constexpr int kBlockN = get<1>(TileShape_MNK{});
133
+
134
+ SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
135
+
136
+ CollectiveMainloop mainloop;
137
+ CollectiveEpilogue epilogue;
138
+
139
+ TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));
140
+ // Initialize matmul objects.
141
+ TiledMmadKV tiled_mma_dKV;
142
+
143
+ scheduler.init_consumer();
144
+
145
+ int warp_idx = cutlass::canonical_warp_idx_sync();
146
+ CUTLASS_PRAGMA_NO_UNROLL
147
+ for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
148
+ work_tile_info.is_valid(params.scheduler);
149
+ work_tile_info = warp_idx == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
150
+
151
+ auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
152
+ auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
153
+ cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
154
+
155
+ // dK and dV output accumulator.
156
+ Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
157
+ Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
158
+ bool tile_valid = mainloop.mma(params.mainloop, tdKrdK, tdVrdV, threadIdx.x,
159
+ block_coord, shared_storage);
160
+ scheduler.prefetch_next_work(params.scheduler, work_tile_info);
161
+ if (tile_valid) {
162
+ epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,
163
+ threadIdx.x, block_coord);
164
+ } else {
165
+ epilogue.store_zero(params.epilogue, threadIdx.x, block_coord);
166
+ }
167
+ }
168
+
169
+ }
170
+
171
+ };
172
+
173
+ } // namespace flash
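For orientation, a hedged sketch of how a kernel class with this interface is typically driven. The real logic lives in flash_bwd_launch_template.h; launch_sketch below is a simplified stand-in compiled as CUDA, and it assumes cutlass::device_kernel from cutlass/device_kernel.h.

#include <cutlass/cutlass.h>
#include <cutlass/device_kernel.h>  // cutlass::device_kernel<Kernel>

template <class Kernel>
void launch_sketch(typename Kernel::Arguments const& args, cudaStream_t stream) {
    typename Kernel::Params params = Kernel::to_underlying_arguments(args);
    dim3 const grid  = Kernel::get_grid_shape(params);
    dim3 const block = Kernel::get_block_shape();
    int const smem_size = Kernel::SharedStorageSize;
    if (smem_size >= 48 * 1024) {
        // Opt in to large dynamic shared memory before launching.
        cudaFuncSetAttribute(cutlass::device_kernel<Kernel>,
                             cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    }
    cutlass::device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params);
}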
flash-attn/flash_bwd_kernel_sm90.h ADDED
@@ -0,0 +1,282 @@
1
+
2
+ /******************************************************************************
3
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
4
+ ******************************************************************************/
5
+
6
+ #pragma once
7
+
8
+ #include "cute/tensor.hpp"
9
+
10
+ #include <cutlass/cutlass.h>
11
+ #include <cutlass/arch/reg_reconfig.h>
12
+ #include <cutlass/array.h>
13
+ #include <cutlass/numeric_types.h>
14
+ #include <cutlass/numeric_conversion.h>
15
+ #include <cutlass/kernel_hardware_info.h>
16
+ #include "cutlass/pipeline/pipeline.hpp"
17
+
18
+ #include "utils.h"
19
+
20
+ namespace flash {
21
+
22
+ using namespace cute;
23
+
24
+ template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
25
+ class FlashAttnBwdSm90 {
26
+
27
+ public:
28
+
29
+ // Type Aliases
30
+ static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
31
+ static constexpr bool Is_local = CollectiveMainloop_::Is_local;
32
+ static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
33
+ static constexpr bool Varlen = CollectiveMainloop_::Varlen;
34
+
35
+ // Mainloop derived types
36
+ using CollectiveMainloop = CollectiveMainloop_;
37
+ using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
38
+ using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;
39
+ using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;
40
+ using ArchTag = typename CollectiveMainloop::ArchTag;
41
+ using ClusterShape = typename CollectiveMainloop::ClusterShape;
42
+ using MainloopArguments = typename CollectiveMainloop::Arguments;
43
+ using MainloopParams = typename CollectiveMainloop::Params;
44
+ static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;
45
+
46
+ // Epilogue derived types
47
+ using CollectiveEpilogue = CollectiveEpilogue_;
48
+ using EpilogueArguments = typename CollectiveEpilogue::Arguments;
49
+ using EpilogueParams = typename CollectiveEpilogue::Params;
50
+
51
+ static_assert(ArchTag::kMinComputeCapability >= 90);
52
+
53
+ using TileScheduler = TileScheduler_;
54
+ using TileSchedulerArguments = typename flash::TileSchedulerArguments;
55
+ using TileSchedulerParams = typename TileScheduler::Params;
56
+
57
+ static constexpr uint32_t NumLoadWarpGroups = 1;
58
+ static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaSdP{})) / cutlass::NumThreadsPerWarpGroup;
59
+ static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
60
+ static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
61
+ static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
62
+
63
+ /// Register requirement for Load and Math WGs
64
+ static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 2 ? 24 : 32;
65
+ static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 240 : 160;
66
+ // If you want to print from the producer warp, you'd need to increase the number of registers
67
+ // Otherwise you'll get CUDA error.
68
+ // static constexpr uint32_t LoadRegisterRequirement = 40;
69
+ // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;
70
+
71
+ // Kernel level shared memory storage
72
+ struct SharedStorage {
73
+ struct TensorStorage : cute::aligned_struct<128> {
74
+ union {
75
+ typename CollectiveMainloop::TensorStorage mainloop;
76
+ typename CollectiveEpilogue::TensorStorage epilogue;
77
+ };
78
+ } tensors;
79
+
80
+ struct PipelineStorage : cute::aligned_struct<16> {
81
+ alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_KV;
82
+ alignas(16) typename CollectiveMainloop::MainloopPipeline::SharedStorage pipeline_q;
83
+ alignas(16) typename CollectiveMainloop::MainloopPipeline_dO::SharedStorage pipeline_do;
84
+ alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
85
+ } pipelines;
86
+
87
+ };
88
+
89
+ static constexpr int SharedStorageSize = sizeof(SharedStorage);
90
+
91
+ // Device side arguments
92
+ struct Arguments {
93
+ MainloopArguments mainloop{};
94
+ EpilogueArguments epilogue{};
95
+ cutlass::KernelHardwareInfo hw_info{};
96
+ TileSchedulerArguments scheduler{};
97
+ };
98
+
99
+ // Kernel entry point API
100
+ struct Params {
101
+ MainloopParams mainloop{};
102
+ EpilogueParams epilogue{};
103
+ cutlass::KernelHardwareInfo hw_info{};
104
+ TileSchedulerParams scheduler{};
105
+ };
106
+
107
+ //
108
+ // Methods
109
+ //
110
+
111
+ // Convert to underlying arguments. In this case, a simple copy for the aliased type.
112
+ static
113
+ Params
114
+ to_underlying_arguments(Arguments const& args) {
115
+ CUTLASS_TRACE_HOST("to_underlying_arguments():");
116
+
117
+ // Get SM count if needed, otherwise use user supplied SM count
118
+ int sm_count = args.hw_info.sm_count;
119
+ if (sm_count <= 0) {
120
+ CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
121
+ " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
122
+ sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
123
+ }
124
+
125
+ CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
126
+
127
+ cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
128
+ return {
129
+ CollectiveMainloop::to_underlying_arguments(args.mainloop),
130
+ CollectiveEpilogue::to_underlying_arguments(args.epilogue),
131
+ hw_info,
132
+ TileScheduler::to_underlying_arguments(args.scheduler)
133
+ };
134
+ }
135
+
136
+ // Computes the kernel launch grid shape based on runtime parameters
137
+ static dim3
138
+ get_grid_shape(Params const& params) {
139
+ return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
140
+ }
141
+
142
+ static dim3
143
+ get_block_shape() {
144
+ return dim3(MaxThreadsPerBlock, 1, 1);
145
+ }
146
+
147
+ CUTLASS_DEVICE
148
+ void
149
+ operator()(Params const& params, char* smem_buf) {
150
+
151
+ static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
152
+ static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
153
+ static constexpr int kBlockM = get<0>(TileShape_MNK{});
154
+ static constexpr int kBlockN = get<1>(TileShape_MNK{});
155
+
156
+ using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
157
+ using PipelineParams = typename MainloopPipeline::Params;
158
+ using PipelineState = typename MainloopPipeline::PipelineState;
159
+ using MainloopPipeline_dO = typename CollectiveMainloop::MainloopPipeline_dO;
160
+ using PipelineParams_dO = typename MainloopPipeline_dO::Params;
161
+ using PipelineState_dO = typename MainloopPipeline_dO::PipelineState;
162
+ static constexpr bool Q_dO_same_stages = std::is_same_v<MainloopPipeline, MainloopPipeline_dO>;
163
+
164
+ SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
165
+
166
+ int const lane_predicate = cute::elect_one_sync();
167
+ int const warp_idx = cutlass::canonical_warp_idx_sync();
168
+
169
+ // Issue Tma Descriptor Prefetch from a single thread
170
+ if (warp_idx == 0 && lane_predicate) {
171
+ CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
172
+ CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
173
+ }
174
+
175
+ // Obtain warp index
176
+ int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
177
+
178
+ PipelineParams pipeline_params;
179
+ pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesQ + CollectiveMainloop::TmaTransactionBytesLSE;
180
+ int warp_group_idx = cutlass::canonical_warp_group_idx();
181
+ pipeline_params.role = warp_group_idx == 0
182
+ ? MainloopPipeline::ThreadCategory::Producer
183
+ : MainloopPipeline::ThreadCategory::Consumer;
184
+ pipeline_params.is_leader = warp_group_thread_idx == 0;
185
+ pipeline_params.num_consumers = NumMmaThreads;
186
+
187
+ if (warp_idx == 0 && lane_predicate) {
188
+ shared_storage.pipelines.barrier_KV.init(1 /*numThreads*/);
189
+ }
190
+ // We're counting on pipeline_q to call cutlass::arch::fence_barrier_init();
191
+ MainloopPipeline pipeline_q(shared_storage.pipelines.pipeline_q, pipeline_params, ClusterShape{});
192
+ auto role_dO = warp_group_idx == 0
193
+ ? MainloopPipeline_dO::ThreadCategory::Producer
194
+ : MainloopPipeline_dO::ThreadCategory::Consumer;
195
+ PipelineParams_dO pipeline_params_dO {pipeline_params.transaction_bytes, role_dO, pipeline_params.is_leader, pipeline_params.num_consumers};
196
+ MainloopPipeline_dO pipeline_do(shared_storage.pipelines.pipeline_do, cute::conditional_return<Q_dO_same_stages>(pipeline_params, pipeline_params_dO), ClusterShape{});
197
+
198
+ CollectiveMainloop mainloop;
199
+ CollectiveEpilogue epilogue;
200
+
201
+ // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
202
+ if constexpr (size(ClusterShape{}) > 1) {
203
+ cute::cluster_arrive_relaxed();
204
+ cute::cluster_wait();
205
+ } else {
206
+ __syncthreads();
207
+ }
208
+
209
+ TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
210
+
211
+ if (warp_group_idx == 0) { // Producer
212
+ cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
213
+
214
+ int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
215
+ if (warp_idx_in_warpgroup == 0) { // Load K, V, and do TMA on Q and dO
216
+ PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
217
+ PipelineState_dO smem_pipe_write_do = cutlass::make_producer_start_state<MainloopPipeline_dO>();
218
+ for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler);
219
+ work_tile_info.is_valid(params.scheduler);
220
+ work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info)) {
221
+ auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
222
+ auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
223
+ cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
224
+ auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() {
225
+ scheduler.prefetch_next_work(params.scheduler, work_tile_info);
226
+ };
227
+ mainloop.load(params.mainloop, pipeline_q, pipeline_do, smem_pipe_write,
228
+ smem_pipe_write_do, shared_storage, scheduler_prefetch, block_coord);
229
+ }
230
+ mainloop.load_tail(pipeline_q, pipeline_do, smem_pipe_write, smem_pipe_write_do);
231
+ } else if (warp_idx_in_warpgroup == 1) {
232
+ for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
233
+ work_tile_info.is_valid(params.scheduler);
234
+ work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
235
+ auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
236
+ auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
237
+ cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
238
+ mainloop.store_dq(params.mainloop, shared_storage, block_coord);
239
+ }
240
+ }
241
+ } else { // Consumer
242
+ cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
243
+ // Initialize matmul objects.
244
+ TiledMmadKV tiled_mma_dKV;
245
+
246
+ PipelineState smem_pipe_read;
247
+ PipelineState_dO smem_pipe_read_do;
248
+
249
+ mainloop.mma_init();
250
+ scheduler.init_consumer();
251
+
252
+ int work_idx = 0;
253
+ CUTLASS_PRAGMA_NO_UNROLL
254
+ for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
255
+ work_tile_info.is_valid(params.scheduler);
256
+ work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
257
+ auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
258
+ auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
259
+ cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
260
+
261
+ // dK and dV output accumulator.
262
+ Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
263
+ Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
264
+ bool tile_valid = mainloop.mma(
265
+ params.mainloop, pipeline_q, pipeline_do, smem_pipe_read, smem_pipe_read_do,
266
+ tdKrdK, tdVrdV, threadIdx.x - NumCopyThreads, work_idx, block_coord, shared_storage);
267
+ if (tile_valid) {
268
+ epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,
269
+ threadIdx.x - NumCopyThreads, block_coord);
270
+ } else {
271
+ epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord);
272
+ }
273
+
274
+ }
275
+ epilogue.store_tail();
276
+ }
277
+
278
+ }
279
+
280
+ };
281
+
282
+ } // namespace flash
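Editorial note: the kernel above is warp-specialized — warp group 0 deallocates most of its registers and only issues loads (and the dQ store), while the remaining warp groups take the extra registers and run the MMAs. The self-contained sketch below is a greatly simplified illustration of that role split, not the kernel itself: it uses __syncthreads() as a stand-in for the CUTLASS MainloopPipeline/mbarrier handshake, and all names in it are illustrative.

// Editorial sketch only -- not the kernel above. Warp group 0 stages tiles into
// shared memory; the other warp groups consume them. The real kernel replaces the
// __syncthreads() handshakes with pipeline barriers and TMA loads.
#include <cuda_runtime.h>

constexpr int kThreadsPerWarpGroup = 128;
constexpr int kTile = 128;

__global__ void producer_consumer_sketch(const float* __restrict__ in,
                                         float* __restrict__ out, int n) {
    __shared__ float tile[kTile];
    int const warp_group = threadIdx.x / kThreadsPerWarpGroup;
    for (int base = blockIdx.x * kTile; base < n; base += gridDim.x * kTile) {
        if (warp_group == 0) {
            // Producer warp group: load one tile into shared memory.
            for (int i = threadIdx.x; i < kTile; i += kThreadsPerWarpGroup) {
                tile[i] = base + i < n ? in[base + i] : 0.f;
            }
        }
        __syncthreads();  // stand-in for the producer_commit / consumer_wait handshake
        if (warp_group != 0) {
            // Consumer warp groups: do the "math" (here just a scale) and store.
            int const tid = threadIdx.x - kThreadsPerWarpGroup;
            int const num_consumers = blockDim.x - kThreadsPerWarpGroup;
            for (int i = tid; i < kTile; i += num_consumers) {
                if (base + i < n) { out[base + i] = 2.f * tile[i]; }
            }
        }
        __syncthreads();  // make the tile safe to overwrite in the next iteration
    }
}

Such a sketch would be launched with one producer warp group plus the consumer warp groups (e.g. 384 threads per block), mirroring the NumLoadWarpGroups / NumMmaWarpGroups split in the kernel above.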
flash-attn/flash_bwd_launch_template.h ADDED
@@ -0,0 +1,390 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cute/tensor.hpp"
8
+
9
+ #include "cutlass/device_kernel.h" // For device_kernel
10
+ #include "cutlass/kernel_launch.h" // For kernel_launch
11
+ #include "cutlass/cluster_launch.hpp" // For ClusterLauncher
12
+
13
+ #include "static_switch.h"
14
+ #include "flash.h"
15
+ #include "flash_bwd_preprocess_kernel.h"
16
+ #include "flash_bwd_postprocess_kernel.h"
17
+ #include "tile_scheduler.hpp"
18
+ #include "mainloop_bwd_sm90_tma_gmma_ws.hpp"
19
+ #include "mainloop_bwd_sm80.hpp"
20
+ #include "epilogue_bwd.hpp"
21
+ #include "flash_bwd_kernel_sm90.h"
22
+ #include "flash_bwd_kernel_sm80.h"
23
+
24
+ using namespace cute;
25
+
26
+ template <int Arch, int kHeadDim, int kBlockM, int kBlockN, typename Element,
27
+ bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool Deterministic, bool GQA,
28
+ int Stages_dO=2, int Stages_dS_or_QSm80=2,
29
+ bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
30
+ int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
31
+ bool V_in_regs=false>
32
+ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
33
+ static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
34
+ using ElementAccum = float;
35
+ using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
36
+
37
+ int const total_q_padded_rounded = cute::round_up(params.total_q + params.b * kBlockM, kBlockM);
38
+ int const total_k_padded_rounded = cute::round_up(params.total_k + params.b * kBlockN, kBlockN);
39
+ bool const is_varlen_q = params.cu_seqlens_q;
40
+ bool const is_varlen_k = params.cu_seqlens_k;
41
+ int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
42
+ int seqlen_k = !is_varlen_k ? params.seqlen_k : params.total_k;
43
+ int seqlen_q_rounded = !is_varlen_q ? params.seqlen_q_rounded : total_q_padded_rounded;
44
+ int seqlen_k_rounded = !is_varlen_k ? params.seqlen_k_rounded : total_k_padded_rounded;
45
+ int batch_q = !is_varlen_q ? params.b : 1;
46
+ int batch_k = !is_varlen_k ? params.b : 1;
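Editorial note: in the variable-length path the batch dimension is folded away — seqlen becomes the total token count across the batch and the batch extent becomes 1 — and the rounded totals add one kBlockM (or kBlockN) block per sequence before rounding, which appears to leave room for each sequence's padded offset to be aligned to a block boundary. A small arithmetic sketch (round_up here is a stand-in for cute::round_up; the numbers are made up):

// Editorial sketch of total_q_padded_rounded for total_q = 1000 tokens, b = 4, kBlockM = 128.
constexpr int round_up(int x, int m) { return (x + m - 1) / m * m; }
static_assert(round_up(1000 + 4 * 128, 128) == 1536);  // room for per-sequence padding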
47
+
48
+ using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
49
+ using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, ArchTag, /*Clear_dQaccum=*/true, Varlen>;
50
+ typename PreprocessKernel::Arguments preprocess_args {
51
+ static_cast<Element const*>(params.o_ptr),
52
+ {seqlen_q, params.dv, params.h, batch_q}, // shape_O
53
+ {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0}, // stride_O
54
+ static_cast<Element const*>(params.do_ptr),
55
+ {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO
56
+ static_cast<float*>(params.dsoftmax_sum),
57
+ {seqlen_q_rounded, params.h, batch_q}, // shape_dPsum
58
+ {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum
59
+ static_cast<float*>(params.softmax_lse_ptr),
60
+ {_1{}, seqlen_q, !is_varlen_q ? params.h * params.seqlen_q : 0}, // stride_LSE
61
+ static_cast<float*>(params.softmax_lse_log2_ptr),
62
+ {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2
63
+ static_cast<ElementAccum*>(params.dq_accum_ptr),
64
+ {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
65
+ {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * seqlen_q_rounded * params.h : 0}, // stride_dQaccum
66
+ params.b,
67
+ params.dq_semaphore,
68
+ params.cu_seqlens_q,
69
+ params.seqused_q
70
+ };
71
+ typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args);
72
+ int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM);
73
+ dim3 grid_m(num_m_block, params.h, params.b);
74
+ cutlass::kernel_launch<PreprocessKernel>(grid_m, PreprocessKernel::MaxThreadsPerBlock, PreprocessKernel::SharedStorageSize, stream, preprocess_params, false /*launch_with_pdl*/);
75
+ CHECK_CUDA_KERNEL_LAUNCH();
76
+
77
+ using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
78
+ using ClusterShape = cute::Shape<_1, Int<1>, _1>; // Clusters are not currently supported
79
+ // Stages_dS_or_QSm80 is Stages_dS if Sm90 and Stages if Sm80
80
+ static constexpr int Stages = Arch >= 90 ? 2 : Stages_dS_or_QSm80;
81
+ static constexpr int Stages_dS = Arch >= 90 ? Stages_dS_or_QSm80 : 1;
82
+ using CollectiveMainloop = std::conditional_t<
83
+ Arch >= 90,
84
+ flash::CollectiveMainloopBwdSm90<Stages, Stages_dO, Stages_dS, ClusterShape, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm90,
85
+ Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
86
+ SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>,
87
+ flash::CollectiveMainloopBwdSm80<Stages, Stages_dO, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm80,
88
+ Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
89
+ SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>
90
+ >;
91
+ using CollectiveEpilogue = std::conditional_t<
92
+ !GQA,
93
+ flash::CollectiveEpilogueBwd<TileShape_MNK, Element, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, dKV_swapAB, NumMmaWarpGroups * (Arch >= 90 ? 1 : cutlass::NumWarpsPerWarpGroup) / AtomLayoutNdKV>,
94
+ flash::CollectiveEpilogueBwdGQA<TileShape_MNK, ElementAccum, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, Deterministic>
95
+ >;
96
+ using Scheduler = std::conditional_t<
97
+ Is_causal && !Varlen,
98
+ flash::SingleTileBwdLPTScheduler,
99
+ flash::SingleTileScheduler<Varlen, false /*Split*/, false /*PackGQA*/, kBlockN>
100
+ >;
101
+ using AttnKernel = std::conditional_t<
102
+ Arch >= 90,
103
+ flash::enable_sm90_or_later<flash::FlashAttnBwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,
104
+ flash::enable_sm80_to_sm89<flash::FlashAttnBwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
105
+ >;
106
+
107
+ typename CollectiveMainloop::Arguments mainloop_args {
108
+ static_cast<Element const*>(params.q_ptr),
109
+ {seqlen_q, params.d, params.h, batch_q}, // shape_Q
110
+ {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q
111
+ static_cast<Element const*>(params.k_ptr),
112
+ {seqlen_k, params.d, params.h_k, batch_k}, // shape_K
113
+ {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K
114
+ static_cast<Element const*>(params.v_ptr),
115
+ {seqlen_k, params.dv, params.h_k, batch_k}, // shape_V
116
+ {params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0}, // stride_V
117
+ static_cast<Element const*>(params.do_ptr),
118
+ {seqlen_q, params.dv, params.h, batch_q}, // shape_dO
119
+ {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO
120
+ static_cast<ElementAccum*>(params.dq_accum_ptr),
121
+ {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
122
+ {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum
123
+ static_cast<float*>(params.softmax_lse_log2_ptr),
124
+ {seqlen_q_rounded, params.h, batch_q}, // shape_LSE
125
+ {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2
126
+ static_cast<float*>(params.dsoftmax_sum),
127
+ {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum
128
+ params.scale_softmax,
129
+ params.window_size_left, params.window_size_right, 0 /*attention_chunk*/,
130
+ params.softcap,
131
+ params.b,
132
+ params.dq_semaphore,
133
+ params.cu_seqlens_q, params.cu_seqlens_k,
134
+ params.seqused_q, params.seqused_k
135
+ };
136
+ // The GQA case is handled somewhat awkwardly here, but it is not clear how to simplify it.
137
+ typename CollectiveEpilogue::Arguments epilogue_args {
138
+ static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dk_ptr : params.dk_accum_ptr),
139
+ [&] {
140
+ if constexpr (!GQA) {
141
+ return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.d, params.h, batch_k}; // shape_dK
142
+ } else {
143
+ return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}; // shape_dKaccum
144
+ }
145
+ }(),
146
+ [&] {
147
+ if constexpr (!GQA) {
148
+ return typename CollectiveEpilogue::StridedKV {params.dk_row_stride, _1{}, params.dk_head_stride, !is_varlen_k ? params.dk_batch_stride : 0}; // stride_dK
149
+ } else {
150
+ return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dKaccum
151
+ }
152
+ }(),
153
+ static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dv_ptr : params.dv_accum_ptr),
154
+ [&] {
155
+ if constexpr (!GQA) {
156
+ return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.dv, params.h, batch_k}; // shape_dV
157
+ } else {
158
+ return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}; // shape_dVaccum
159
+ }
160
+ }(),
161
+ [&] {
162
+ if constexpr (!GQA) {
163
+ return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0}; // stride_dV
164
+ } else {
165
+ return typename CollectiveEpilogue::StridedKV {_1{}, params.dv_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.dv_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum
166
+ }
167
+ }(),
168
+ params.h,
169
+ params.dk_semaphore,
170
+ params.dv_semaphore,
171
+ params.cu_seqlens_k,
172
+ params.seqused_k,
173
+ };
174
+
175
+ int num_blocks_n = cutlass::ceil_div(params.seqlen_k, get<1>(TileShape_MNK{}));
176
+ num_blocks_n = cutlass::round_up(num_blocks_n, size<1>(ClusterShape{}));
177
+ typename flash::TileSchedulerArguments scheduler_args {
178
+ num_blocks_n, params.h, params.b, 1 /*num_splits*/,
179
+ params.h / params.h_k,
180
+ params.seqlen_k,
181
+ params.seqlen_q, params.d, params.dv, sizeof(Element),
182
+ params.tile_count_semaphore, params.cu_seqlens_k, params.seqused_k
183
+ };
184
+
185
+ int device;
186
+ cudaGetDevice(&device);
187
+ typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
188
+ mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args
189
+ });
190
+
191
+ dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
192
+ dim3 block_dims = AttnKernel::get_block_shape();
193
+ int smem_size = AttnKernel::SharedStorageSize;
194
+ // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q));
195
+ // int smem_size_do = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_do));
196
+ // int smem_size_ds = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_ds));
197
+ // int smem_size_dqacc = [&] {
198
+ // if constexpr (Arch >= 90) {
199
+ // return sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dqacc));
200
+ // } else {
201
+ // return 0;
202
+ // }
203
+ // }();
204
+ // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));
205
+ // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
206
+ // int smem_size_lse = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_lse));
207
+ // int smem_size_dpsum = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dpsum));
208
+ // printf("smem_size = %d, q = %d, k = %d, v = %d, do = %d, ds = %d, dqacc = %d, lse = %d, dpsum = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_do, smem_size_ds, smem_size_dqacc, smem_size_lse, smem_size_dpsum);
209
+ if constexpr (size(ClusterShape{}) > 1) {
210
+ void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
211
+ if (smem_size >= 48 * 1024) {
212
+ CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
213
+ }
214
+ dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
215
+ cutlass::ClusterLauncher::launch(
216
+ grid_dims, cluster_dims, block_dims, smem_size, stream, kernel, kernel_params, false /*launch_with_pdl*/);
217
+ } else {
218
+ if (smem_size >= 48 * 1024) {
219
+ CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<AttnKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
220
+ }
221
+ cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params, false /*launch_with_pdl*/);
222
+ }
223
+ CHECK_CUDA_KERNEL_LAUNCH();
224
+
225
+ using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_MK, Element, ElementAccum, ArchTag,
226
+ AttnKernel::CollectiveMainloop::NumMmaThreads,
227
+ typename AttnKernel::CollectiveMainloop::TiledMmadQ,
228
+ AttnKernel::CollectiveMainloop::dQ_swapAB
229
+ >;
230
+ typename PostprocessKernel::Arguments postprocess_args {
231
+ static_cast<ElementAccum const*>(params.dq_accum_ptr),
232
+ {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
233
+ {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum
234
+ static_cast<Element*>(params.dq_ptr),
235
+ {seqlen_q, params.d, params.h, batch_q}, // shape_dQ
236
+ {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride}, // stride_dQ
237
+ params.scale_softmax,
238
+ params.cu_seqlens_q,
239
+ params.seqused_q
240
+ };
241
+ typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args);
242
+ int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{}));
243
+ dim3 grid_m_postprocess(num_m_block_postprocess, params.h, params.b);
244
+ int smem_size_postprocess = PostprocessKernel::SharedStorageSize;
245
+ if (smem_size_postprocess >= 48 * 1024) {
246
+ CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
247
+ }
248
+ cutlass::kernel_launch<PostprocessKernel>(grid_m_postprocess, PostprocessKernel::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_params, false /*launch_with_pdl*/);
249
+ CHECK_CUDA_KERNEL_LAUNCH();
250
+
251
+ if constexpr (GQA) {
252
+ using TileShape_NK = cute::Shape<Int<kBlockN>, Int<kHeadDim>>;
253
+ using PostprocessKerneldKV = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_NK, Element, ElementAccum, ArchTag,
254
+ AttnKernel::CollectiveEpilogue::NumEpilogueThreads,
255
+ typename AttnKernel::CollectiveMainloop::TiledMmadKV,
256
+ AttnKernel::CollectiveMainloop::dKV_swapAB
257
+ >;
258
+ typename PostprocessKerneldKV::Arguments postprocess_dK_args {
259
+ static_cast<ElementAccum const*>(params.dk_accum_ptr),
260
+ {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dKaccum
261
+ {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dKaccum
262
+ static_cast<Element*>(params.dk_ptr),
263
+ {seqlen_k, params.d, params.h_k, batch_k}, // shape_dK
264
+ {params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride}, // stride_dK
265
+ 1.f,
266
+ params.cu_seqlens_k,
267
+ params.seqused_k
268
+ };
269
+ typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args);
270
+ typename PostprocessKerneldKV::Arguments postprocess_dV_args {
271
+ static_cast<ElementAccum const*>(params.dv_accum_ptr),
272
+ {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}, // shape_dVaccum
273
+ {_1{}, seqlen_k_rounded * params.dv_rounded, !is_varlen_k ? params.dv_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum
274
+ static_cast<Element*>(params.dv_ptr),
275
+ {seqlen_k, params.dv, params.h_k, batch_k}, // shape_dV
276
+ {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride}, // stride_dV
277
+ 1.f,
278
+ params.cu_seqlens_k,
279
+ params.seqused_k
280
+ };
281
+ typename PostprocessKerneldKV::Params postprocess_dV_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dV_args);
282
+ int num_n_block_postprocess = cute::ceil_div(params.seqlen_k, get<0>(TileShape_NK{}));
283
+ dim3 grid_n_postprocess(num_n_block_postprocess, params.h_k, params.b);
284
+ int smem_size_postprocess = PostprocessKerneldKV::SharedStorageSize;
285
+ if (smem_size_postprocess >= 48 * 1024) {
286
+ CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKerneldKV>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
287
+ }
288
+ cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dK_params, false /*launch_with_pdl*/);
289
+ CHECK_CUDA_KERNEL_LAUNCH();
290
+ cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dV_params, false /*launch_with_pdl*/);
291
+ CHECK_CUDA_KERNEL_LAUNCH();
292
+ }
293
+
294
+ }
295
+
296
+ template<int Arch, typename T, int kBlockM, int kBlockN, int kHeadDim, bool Is_causal, bool Is_local, bool Has_softcap,
297
+ int Stages_dO=2, int Stages_dS_or_QSm80=2,
298
+ bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
299
+ int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
300
+ bool V_in_regs=false>
301
+ void run_mha_bwd_dispatch(Flash_bwd_params &params, cudaStream_t stream) {
302
+ VARLEN_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
303
+ BOOL_SWITCH(params.h != params.h_k, GQA, [&] {
304
+ // BOOL_SWITCH(params.deterministic, Deterministic, [&] {
305
+ // run_flash_bwd<kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen, false, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>(params, stream);
306
+ run_flash_bwd<Arch, kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen /*Varlen*/, false /*Deterministic*/, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>(params, stream);
307
+ // });
308
+ });
309
+ });
310
+ }
311
+
312
+
313
+ template<int Arch, typename T, bool Has_softcap>
314
+ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
315
+ CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
316
+ if constexpr (Arch >= 90) {
317
+ if constexpr (Is_causal && Has_softcap) {
318
+ // register spill with 128 x 128
319
+ run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 2, false>(params, stream);
320
+ } else {
321
+ // With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block.
322
+ run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 2, false>(params, stream);
323
+ }
324
+ } else if constexpr (Arch == 86 || Arch == 89) {
325
+ run_mha_bwd_dispatch<Arch, T, 64, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
326
+ // run_mha_bwd_dispatch<Arch, T, 96, 96, 64, Is_causal, Is_local, Has_softcap, 1, 2, false, true, true, 2, 2, 4, 4, false>(params, stream);
327
+ // run_mha_bwd_dispatch<Arch, T, 80, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 2, 4, 2, true>(params, stream);
328
+ // run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 1, 8, 4, false>(params, stream);
329
+ } else {
330
+ run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 4, 4, 4, false>(params, stream);
331
+ }
332
+ });
333
+ }
334
+
335
+ template<int Arch, typename T, bool Has_softcap>
336
+ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
337
+ CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
338
+ if constexpr (Arch >= 90) {
339
+ run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, true>(params, stream);
340
+ } else if constexpr (Arch == 86 || Arch == 89) {
341
+ run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
342
+ } else {
343
+ run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, false>(params, stream);
344
+ }
345
+ });
346
+ }
347
+
348
+ template<int Arch, typename T, bool Has_softcap>
349
+ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
350
+ CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
351
+ if constexpr (Arch >= 90) {
352
+ if constexpr (Is_causal || Is_local || Has_softcap) {
353
+ run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, false>(params, stream);
354
+ } else {
355
+ run_mha_bwd_dispatch<Arch, T, 80, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 1, false>(params, stream);
356
+ }
357
+ } else if constexpr (Arch == 86 || Arch == 89) {
358
+ run_mha_bwd_dispatch<Arch, T, 64, 96, 128, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 2, 2, true>(params, stream);
359
+ } else {
360
+ run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 2, 2, false>(params, stream);
361
+ }
362
+ });
363
+ }
364
+
365
+ template<int Arch, typename T, bool Has_softcap>
366
+ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
367
+ CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
368
+ if constexpr (Arch >= 90) {
369
+ run_mha_bwd_dispatch<Arch, T, 64, 96, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, true, false, 3, 1, 1, 1, false>(params, stream);
370
+ } else if constexpr (Arch == 86 || Arch == 89) {
371
+ run_mha_bwd_dispatch<Arch, T, 64, 64, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 2, true>(params, stream);
372
+ } else {
373
+ run_mha_bwd_dispatch<Arch, T, 64, 80, 192, Is_causal, Is_local, Has_softcap, 1, 2, false, true, false, 2, 4, 2, 2, false>(params, stream);
374
+ }
375
+ });
376
+ }
377
+
378
+ template<int Arch, typename T, bool Has_softcap>
379
+ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
380
+ CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
381
+ if constexpr (Arch >= 90) {
382
+ run_mha_bwd_dispatch<Arch, T, 64, 80, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, true, true, 2, 1, 1, 1, false>(params, stream);
383
+ } else if constexpr (Arch == 86 || Arch == 89) {
384
+ run_mha_bwd_dispatch<Arch, T, 32, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 1, true>(params, stream);
385
+ // run_mha_bwd_dispatch<Arch, T, 64, 32, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 1, 2, true>(params, stream);
386
+ } else {
387
+ run_mha_bwd_dispatch<Arch, T, 64, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 2, 2, false>(params, stream);
388
+ }
389
+ });
390
+ }
flash-attn/flash_bwd_postprocess_kernel.h ADDED
@@ -0,0 +1,256 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cute/tensor.hpp"
8
+
9
+ #include <cutlass/cutlass.h>
10
+ #include <cutlass/array.h>
11
+ #include <cutlass/numeric_types.h>
12
+ #include <cutlass/numeric_conversion.h>
13
+ #include "cutlass/arch/barrier.h"
14
+
15
+ #include "seqlen.h"
16
+ #include "utils.h"
17
+
18
+ namespace flash {
19
+
20
+ using namespace cute;
21
+
22
+ template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, int kNThreads, class TiledMma, bool dQ_swapAB>
23
+ class FlashAttnBwdPostprocessConvertdQ {
24
+
25
+ public:
26
+
27
+ // Type Aliases
28
+ using TileShape_MK = TileShape_MK_;
29
+ using ArchTag = ArchTag_;
30
+
31
+ static_assert(ArchTag::kMinComputeCapability >= 75);
32
+ static constexpr bool IsSm90 = ArchTag::kMinComputeCapability >= 90;
33
+
34
+ static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
35
+ static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
36
+
37
+ static constexpr int kBlockM = get<0>(TileShape_MK{});
38
+ static constexpr int kHeadDim = get<1>(TileShape_MK{});
39
+ static_assert(!IsSm90 || kNThreads % cutlass::NumThreadsPerWarpGroup == 0, "kNThreads must be a multiple of NumThreadsPerWarpGroup");
40
+ static constexpr int NumdQWarpGroups = kNThreads / cutlass::NumThreadsPerWarpGroup;
41
+ using R2SLayoutAtomdQaccum = std::conditional_t<
42
+ IsSm90,
43
+ Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumdQWarpGroups>>>,
44
+ Layout<Shape<Int<kNThreads>>>
45
+ >;
46
+ using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdQaccum{},
47
+ Layout<Shape<Int<IsSm90 ? 4 : 1>>>{})); // Val layout, 1 or 4 vals per read
48
+ using G2SLayoutAtomdQaccum = Layout<Shape<Int<kNThreads>>>;
49
+ // UniversalCopy instead of AutoVectorizingCopyWithAssumedAlignment as the latter generates cp.async instructions
50
+ using G2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, ElementAccum>{}, G2SLayoutAtomdQaccum{},
51
+ Layout<Shape<_4>>{})); // Val layout, 4 vals per read
52
+ // We don't do bound checking for the gmem -> smem load so we just assert here.
53
+ static_assert(IsSm90 || (kBlockM * kHeadDim) % (kNThreads * 4) == 0);
54
+ static constexpr int SmemdQaccumSize = size(TileShape_MK{});
55
+ using SmemLayoutdQaccumFlat = Layout<Shape<Int<SmemdQaccumSize>>>;
56
+ using SmemLayoutdQaccum = std::conditional_t<
57
+ IsSm90,
58
+ Layout<Shape<Int<kBlockM * kHeadDim / NumdQWarpGroups>, Int<NumdQWarpGroups>>>,
59
+ Layout<Shape<Int<kBlockM * kHeadDim>>>
60
+ >;
61
+
62
+ // We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs,
63
+ // then setting kBlockKSmem to 32 will cause "Static shape_div failure".
64
+ // We want to treat it as 64 x 48, so kBlockKSmem should be 16.
65
+ static constexpr int MmaShapeN = get<1>(typename TiledMma::AtomShape_MNK{});
66
+ static constexpr int kBlockKSmem = MmaShapeN % 64 == 0 ? 64 : (MmaShapeN % 32 == 0 ? 32 : 16);
67
+ static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
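Editorial sketch of the comment above (pick_kBlockKSmem / pick_kSwizzle are illustrative helpers, not part of the kernel): for the 64 x 96 MMA split across two warp groups, each group effectively sees N = 48, so the shared-memory atom width falls back to 16 and the swizzle to 1.

constexpr int pick_kBlockKSmem(int mma_shape_n) {
    return mma_shape_n % 64 == 0 ? 64 : (mma_shape_n % 32 == 0 ? 32 : 16);
}
constexpr int pick_kSwizzle(int k_block_k_smem) {
    return k_block_k_smem == 64 ? 3 : (k_block_k_smem == 32 ? 2 : 1);
}
static_assert(pick_kBlockKSmem(64) == 64 && pick_kSwizzle(64) == 3);
static_assert(pick_kBlockKSmem(96) == 32 && pick_kSwizzle(32) == 2);
static_assert(pick_kBlockKSmem(48) == 16 && pick_kSwizzle(16) == 1);  // 96 split over 2 WGs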
68
+ using SmemLayoutAtomdQ =
69
+ decltype(composition(Swizzle<kSwizzle, 3, 3>{},
70
+ Layout<Shape<Int<8>, Int<kBlockKSmem>>,
71
+ Stride<Int<kBlockKSmem>, _1>>{}));
72
+ using SmemLayoutdQ = decltype(tile_to_shape(SmemLayoutAtomdQ{}, TileShape_MK{}));
73
+ using SmemLayoutdQt =
74
+ decltype(cute::composition(SmemLayoutdQ{},
75
+ make_layout(make_shape(get<1>(TileShape_MK{}), get<0>(TileShape_MK{})),
76
+ make_stride(Int<get<0>(TileShape_MK{})>{}, _1{}))));
77
+
78
+ using SmemCopyAtomdQ = Copy_Atom<
79
+ std::conditional_t<
80
+ IsSm90,
81
+ std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
82
+ AutoVectorizingCopyWithAssumedAlignment<128>
83
+ >,
84
+ Element>;
85
+
86
+ static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
87
+ static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
88
+ static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, int(MaxThreadsPerBlock));
89
+ static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
90
+ using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
91
+ Stride<Int<kGmemThreadsPerRow>, _1>>;
92
+ using GmemTiledCopy = decltype(
93
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
94
+ GmemLayoutAtom{},
95
+ Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per load
96
+
97
+ struct SharedStorage : cute::aligned_struct<128> {
98
+ cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQaccum>> smem_dqacc;
99
+ cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
100
+ alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_dQaccum;
101
+ };
102
+
103
+ static constexpr int SharedStorageSize = sizeof(SharedStorage);
104
+
105
+ using ShapedQ = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
106
+ using StridedQ = cute::Stride<int64_t, _1, int64_t, int64_t>;
107
+ using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q * d, head, batch)
108
+ using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
109
+
110
+ // Device side arguments
111
+ struct Arguments {
112
+ ElementAccum const* ptr_dQaccum;
113
+ ShapedQaccum const shape_dQaccum;
114
+ StridedQaccum const stride_dQaccum;
115
+ Element* ptr_dQ;
116
+ ShapedQ const shape_dQ;
117
+ StridedQ const stride_dQ;
118
+ float const softmax_scale;
119
+ int const* cu_seqlens = nullptr;
120
+ int const* seqused = nullptr;
121
+ };
122
+
123
+ // Kernel entry point API
124
+ struct Params {
125
+ ElementAccum const* ptr_dQaccum;
126
+ ShapedQaccum const shape_dQaccum;
127
+ StridedQaccum const stride_dQaccum;
128
+ Element* ptr_dQ;
129
+ ShapedQ const shape_dQ;
130
+ StridedQ const stride_dQ;
131
+ float const softmax_scale;
132
+ int const* cu_seqlens = nullptr;
133
+ int const* seqused = nullptr;
134
+ };
135
+
136
+ // Convert to underlying arguments. In this case, a simple copy for the aliased type.
137
+ static
138
+ Params
139
+ to_underlying_arguments(Arguments const& args) {
140
+ return {
141
+ args.ptr_dQaccum,
142
+ args.shape_dQaccum,
143
+ args.stride_dQaccum,
144
+ args.ptr_dQ,
145
+ args.shape_dQ,
146
+ args.stride_dQ,
147
+ args.softmax_scale,
148
+ args.cu_seqlens,
149
+ args.seqused
150
+ };
151
+ }
152
+
153
+ CUTLASS_DEVICE
154
+ void
155
+ operator()(Params const& params, char* smem_buf) {
156
+
157
+ static constexpr int kBlockM = get<0>(TileShape_MK{});
158
+ SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
159
+
160
+ Tensor sdQaccum = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccum{});
161
+ Tensor sdQaccum_flat = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumFlat{});
162
+ Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQ{});
163
+ Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQt{});
164
+
165
+ int const thread_idx = threadIdx.x;
166
+ int const m_block = blockIdx.x;
167
+ int const bidh = blockIdx.y;
168
+ int const bidb = blockIdx.z;
169
+
170
+ flash::SeqlenInfo<true /*Varlen*/, kBlockM> seqlen_info(bidb, size<0>(params.shape_dQ), params.cu_seqlens, params.seqused);
171
+ bool const is_varlen = params.cu_seqlens;
172
+ if (is_varlen && m_block * kBlockM >= seqlen_info.seqlen) { return; }
173
+
174
+ // Step 1: load dQaccum from gmem to smem
175
+ Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum const*>(params.ptr_dQaccum)),
176
+ params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
177
+ Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block)); // (M * K)
178
+ if constexpr (IsSm90) { // Use BulkCopy
179
+ static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast<uint32_t>(size(SmemLayoutdQaccumFlat{}) * cute::sizeof_bits_v<ElementAccum> / 8);
180
+ auto bulk_copy = Copy_Traits<SM90_BULK_COPY_AUTO>{};
181
+ // if (thread0()) { print(gdQaccum); printf("\n"); print(sdQaccum_flat); printf("\n"); }
182
+ if (thread_idx == 0) {
183
+ shared_storage.barrier_dQaccum.init(1 /*numThreads*/);
184
+ shared_storage.barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum);
185
+ copy(bulk_copy.with(*reinterpret_cast<uint64_t*>(&shared_storage.barrier_dQaccum)), gdQaccum, sdQaccum_flat);
186
+ }
187
+ __syncthreads();
188
+ shared_storage.barrier_dQaccum.wait(0);
189
+ } else {
190
+ G2STiledCopydQaccum g2s_tiled_copy_dQaccum;
191
+ auto g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);
192
+ Tensor tdQgdQaccumg2s = g2s_thr_copy_dQaccum.partition_S(gdQaccum);
193
+ Tensor tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum);
194
+ cute::copy(g2s_tiled_copy_dQaccum, tdQgdQaccumg2s, tdQsdQaccumg2s);
195
+ __syncthreads();
196
+ }
197
+
198
+ // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccum); }
199
+
200
+ // Step 2: Load dQaccum from smem to register, then convert fp32 -> fp16/bf16
201
+ R2STiledCopydQaccum s2r_tiled_copy_dQaccum;
202
+ auto s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_thread_slice(thread_idx);
203
+ Tensor tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum);
204
+ TiledMma tiled_mma_dQ;
205
+ Tensor taccdQrdQaccum = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 1, !dQ_swapAB ? 1 : 0>(TileShape_MK{}));
206
+ // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tiled_mma_dQ); printf("\n"); }
207
+ // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tdQsdQaccum); }
208
+ // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(taccdQrdQaccum); }
209
+ CUTE_STATIC_ASSERT_V(size(taccdQrdQaccum) == size(tdQsdQaccum));
210
+ Tensor tdQrdQaccum = s2r_thr_copy_dQaccum.retile_D(taccdQrdQaccum);
211
+ cute::copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum);
212
+ #pragma unroll
213
+ for (int i = 0; i < size(taccdQrdQaccum); ++i) { taccdQrdQaccum(i) *= params.softmax_scale; }
214
+ // Convert tdQrdQ from fp32 to fp16
215
+ Tensor rdQ = make_tensor_like<Element>(taccdQrdQaccum);
216
+ flash::convert_type_out(taccdQrdQaccum, rdQ);
217
+
218
+ // Step 3: Copy dQ from register to smem
219
+ auto smem_tiled_copy_dQ = make_tiled_copy_C(SmemCopyAtomdQ{}, tiled_mma_dQ);
220
+ auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(thread_idx);
221
+ Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
222
+ // if (cute::thread0()) { print(smem_tiled_copy_dQ); }
223
+ // if (cute::thread0()) { print(smem_thr_copy_dQ); }
224
+ // if (cute::thread0()) { print(sdQ); }
225
+ Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(cute::conditional_return<!dQ_swapAB>(sdQ, sdQt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
226
+ cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
227
+ __syncthreads();
228
+
229
+ // Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem
230
+ Tensor mdQ = make_tensor(make_gmem_ptr(params.ptr_dQ), params.shape_dQ, params.stride_dQ)(_, _, bidh, !is_varlen ? bidb : 0);
231
+ Tensor gdQ = local_tile(domain_offset(make_coord(seqlen_info.offset, _0{}), mdQ), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
232
+ GmemTiledCopy gmem_tiled_copy_dQ;
233
+ auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(thread_idx);
234
+ Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
235
+ Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
236
+
237
+ Tensor tdQrdQ = make_fragment_like(tdQsdQ);
238
+ Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cute::make_identity_tensor(TileShape_MK{}));
239
+ Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
240
+ #pragma unroll
241
+ for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(_0{}, _0{}, k)) < get<1>(params.shape_dQ); }
242
+ // Need to check OOB when reading from smem if kBlockM isn't evenly tiled
243
+ static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;
244
+ flash::copy</*Is_even_MN=*/EvenM, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
245
+ gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ, tdQcdQ, tdQpdQ, kBlockM);
246
+
247
+ // Step 5: Copy dQ from register to gmem
248
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
249
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
250
+ gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, std::min(seqlen_info.seqlen - m_block * kBlockM, kBlockM)
251
+ );
252
+ }
253
+
254
+ };
255
+
256
+ } // namespace flash
flash-attn/flash_bwd_preprocess_kernel.h ADDED
@@ -0,0 +1,252 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cute/tensor.hpp"
8
+
9
+ #include <cutlass/cutlass.h>
10
+ #include <cutlass/array.h>
11
+ #include <cutlass/numeric_types.h>
12
+ #include <cutlass/numeric_conversion.h>
13
+
14
+ #include "seqlen.h"
15
+ #include "utils.h"
16
+
17
+ namespace flash {
18
+
19
+ using namespace cute;
20
+
21
+ template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, bool Clear_dQaccum, bool Varlen>
22
+ class FlashAttnBwdPreprocess {
23
+
24
+ public:
25
+
26
+ // Type Aliases
27
+ using TileShape_MK = TileShape_MK_;
28
+ using ArchTag = ArchTag_;
29
+
30
+ static_assert(std::is_same_v<Element, cutlass::half_t> && ArchTag::kMinComputeCapability >= 75 ||
31
+ std::is_same_v<Element, cutlass::bfloat16_t> && ArchTag::kMinComputeCapability >= 80 ||
32
+ std::is_same_v<Element, cutlass::float_e4m3_t> && ArchTag::kMinComputeCapability >= 89);
33
+
34
+ static constexpr uint32_t MaxThreadsPerBlock = 256;
35
+ static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
36
+ static constexpr int SharedStorageSize = 0;
37
+
38
+ static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
39
+ static_assert(get<1>(TileShape_MK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
40
+ static constexpr int kBlockM = get<0>(TileShape_MK{});
41
+ static constexpr int kHeadDim = get<1>(TileShape_MK{});
42
+ // We want kBlockKGmem to be a power of 2 so that when we do the summing,
43
+ // it's just between threads in the same warp
44
+ static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
45
+ static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
46
+ static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
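Editorial sketch of the arithmetic behind the comment above, assuming 16-byte (uint128_t) vectorized loads of 2-byte fp16/bf16 elements (threads_per_row is an illustrative helper): the number of threads cooperating on one row is always a power of two no larger than 32, so the dP_sum row reduction later in this kernel never has to cross a warp boundary.

constexpr int threads_per_row(int head_dim) {
    constexpr int elems_per_load = 16 / 2;  // 8 elements per 128-bit load
    int const k_block_k_gmem = head_dim % 128 == 0 ? 128 : (head_dim % 64 == 0 ? 64 : 32);
    return k_block_k_gmem / elems_per_load;
}
static_assert(threads_per_row(64) == 8);
static_assert(threads_per_row(96) == 4);    // 96 -> kBlockKGmem = 32
static_assert(threads_per_row(128) == 16);  // still within a single warp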
47
+ using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
48
+ Stride<Int<kGmemThreadsPerRow>, _1>>;
49
+ using GmemTiledCopy = decltype(
50
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
51
+ GmemLayoutAtom{},
52
+ Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per load
53
+
54
+ static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
55
+ static_assert((kBlockM * kHeadDim / kGmemElemsPerLoadAccum) % MaxThreadsPerBlock == 0, "MaxThreadsPerBlock must divide kBlockM * kHeadDim / kGmemElemsPerLoadAccum");
56
+ using GmemLayoutAtomAccum = Layout<Shape<Int<MaxThreadsPerBlock>>>;
57
+ using GmemTiledCopyAccum = decltype(
58
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
59
+ GmemLayoutAtomAccum{},
60
+ Layout<Shape<Int<kGmemElemsPerLoadAccum>>>{})); // Val layout, 4 vals per store
61
+
62
+ using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
63
+ using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
64
+ using ShapedPsum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q, head, batch)
65
+ using StridedPsum = cute::Stride<_1, int64_t, int64_t>;
66
+ using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q * d, head, batch)
67
+ using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
68
+
69
+ // Device side arguments
70
+ struct Arguments {
71
+ Element const* ptr_O;
72
+ ShapeO const shape_O;
73
+ StrideO const stride_O;
74
+ Element const* ptr_dO;
75
+ StrideO const stride_dO;
76
+ float* ptr_dPsum;
77
+ ShapedPsum const shape_dPsum;
78
+ StridedPsum const stride_dPsum;
79
+ float const* ptr_LSE;
80
+ StridedPsum const stride_LSE;
81
+ float *ptr_LSE_log2;
82
+ StridedPsum const stride_LSE_log2;
83
+ ElementAccum* ptr_dQaccum;
84
+ ShapedQaccum const shape_dQaccum;
85
+ StridedQaccum const stride_dQaccum;
86
+ int num_batch; // We need this to know the size of dq_semaphore in case of varlen
87
+ int* dq_semaphore;
88
+ int const* cu_seqlens = nullptr;
89
+ int const* seqused = nullptr;
90
+ };
91
+
92
+ // Kernel entry point API
93
+ struct Params {
94
+ Element const* ptr_O;
95
+ ShapeO const shape_O;
96
+ StrideO const stride_O;
97
+ Element const* ptr_dO;
98
+ StrideO const stride_dO;
99
+ float* ptr_dPsum;
100
+ ShapedPsum const shape_dPsum;
101
+ StridedPsum const stride_dPsum;
102
+ float const* ptr_LSE;
103
+ StridedPsum const stride_LSE;
104
+ float* ptr_LSE_log2;
105
+ StridedPsum const stride_LSE_log2;
106
+ ElementAccum* ptr_dQaccum;
107
+ ShapedQaccum const shape_dQaccum;
108
+ StridedQaccum const stride_dQaccum;
109
+ int num_batch;
110
+ int* dq_semaphore;
111
+ int const* cu_seqlens = nullptr;
112
+ int const* seqused = nullptr;
113
+ };
114
+
115
+ // Convert to underlying arguments. In this case, a simple copy for the aliased type.
116
+ static
117
+ Params
118
+ to_underlying_arguments(Arguments const& args) {
119
+ return {
120
+ args.ptr_O,
121
+ args.shape_O,
122
+ args.stride_O,
123
+ args.ptr_dO,
124
+ args.stride_dO,
125
+ args.ptr_dPsum,
126
+ args.shape_dPsum,
127
+ args.stride_dPsum,
128
+ args.ptr_LSE,
129
+ args.stride_LSE,
130
+ args.ptr_LSE_log2,
131
+ args.stride_LSE_log2,
132
+ args.ptr_dQaccum,
133
+ args.shape_dQaccum,
134
+ args.stride_dQaccum,
135
+ args.num_batch,
136
+ args.dq_semaphore,
137
+ args.cu_seqlens,
138
+ args.seqused
139
+ };
140
+ }
141
+
142
+ CUTLASS_DEVICE
143
+ void
144
+ operator()(Params const& params, [[maybe_unused]] char* smem_buf) {
145
+
146
+ static constexpr int kBlockM = get<0>(TileShape_MK{});
147
+
148
+ int const thread_idx = threadIdx.x;
149
+ int const m_block = blockIdx.x;
150
+ int const bidh = blockIdx.y;
151
+ int const bidb = blockIdx.z;
152
+
153
+ flash::SeqlenInfo<Varlen, kBlockM> seqlen_info(bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused);
154
+ bool const is_varlen = Varlen && params.cu_seqlens;
155
+ int const seqlen_o = seqlen_info.seqlen;
156
+ if (is_varlen && m_block * kBlockM >= seqlen_o) { return; }
157
+
158
+ Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0);
159
+ Tensor gO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
160
+ Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_O, params.stride_dO)(_, _, bidh, !is_varlen ? bidb : 0);
161
+ Tensor gdO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
162
+
163
+ auto shape_LSE = select<0, 2, 3>(params.shape_O);
164
+ Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), shape_LSE, params.stride_LSE)(_, bidh, !is_varlen ? bidb : 0);
165
+ Tensor gLSE = local_tile(cute::domain_offset(make_coord(seqlen_info.offset), mLSE), Shape<Int<kBlockM>>{}, make_coord(m_block));
166
+ static_assert(kBlockM <= MaxThreadsPerBlock);
167
+ float lse = thread_idx < seqlen_o - m_block * kBlockM && thread_idx < kBlockM ? gLSE(thread_idx) : INFINITY;
168
+
169
+ GmemTiledCopy gmem_tiled_copy_O;
170
+ auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
171
+
172
+ Tensor tOgO = gmem_thr_copy_O.partition_S(gO);
173
+ Tensor tOgdO = gmem_thr_copy_O.partition_S(gdO);
174
+ // Construct identity layout for gO
175
+ Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
176
+ // Repeat the partitioning with identity layouts
177
+ Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
178
+ Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
179
+ #pragma unroll
180
+ for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
181
+
182
+ // (8, kBlockM / 32, kHeadDim / 64) or (8, kBlockM / 16, kHeadDim / 128)
183
+ Tensor tOrO = make_fragment_like(tOgO);
184
+ Tensor tOrdO = make_fragment_like(tOgdO);
185
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/true>(
186
+ gmem_tiled_copy_O, tOgO, tOrO, tOcO, tOpO, seqlen_o - m_block * kBlockM
187
+ );
188
+ flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/true>(
189
+ gmem_tiled_copy_O, tOgdO, tOrdO, tOcO, tOpO, seqlen_o - m_block * kBlockM
190
+ );
191
+ // if (threadIdx.x == 222) { printf("bidx = %d, bidy = %d, bidz = %d, seqlen_o = %d, m_block = %d, seqlen_o - m_block * kBlockM = %d, tOgO addr = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_o, m_block, seqlen_o - m_block * kBlockM, &tOgO(0));}
192
+
193
+ // Reshape from e.g. (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, (8, kHeadDim / 64))
194
+ Layout l = make_layout(get<1>(tOrO.layout()), make_layout(get<0>(tOrO.layout()), get<2>(tOrO.layout())));
195
+ Tensor tOrO_l = make_tensor(tOrO.data(), l);
196
+ Tensor o_fp32 = make_tensor_like<float>(tOrO_l);
197
+ flash::convert_type_out(tOrO_l, o_fp32);
198
+ Tensor tOrdO_l = make_tensor(tOrdO.data(), l);
199
+ Tensor do_fp32 = make_tensor_like<float>(tOrdO_l);
200
+ flash::convert_type_out(tOrdO_l, do_fp32);
201
+ // Sum across the last dimension
202
+ Tensor dP_sum = make_tensor<float>(make_shape(size<0>(o_fp32)));
203
+ #pragma unroll
204
+ for (int mi = 0; mi < size<0>(o_fp32); ++mi) {
205
+ float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
206
+ #pragma unroll
207
+ for (int ni = 1; ni < size<1>(o_fp32); ni++) {
208
+ dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
209
+ }
210
+ flash::SumOp<float> sum_op;
211
+ dP_sum(mi) = flash::Allreduce<kGmemThreadsPerRow>::run(dP_sum_cur, sum_op);
212
+ }
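Editorial sketch: flash::Allreduce<kGmemThreadsPerRow>::run is assumed to be the butterfly __shfl_xor_sync reduction from the repository's utils.h. The self-contained equivalent below (the function name is illustrative) shows why the reduction stays inside one warp whenever n is a power of two at most 32; it assumes all 32 lanes of the warp are active, as they are here with 256 threads per block.

__device__ inline float group_allreduce_sum(float x, int n /* power of 2, <= 32 */) {
    for (int offset = n / 2; offset > 0; offset /= 2) {
        x += __shfl_xor_sync(0xffffffffu, x, offset);
    }
    return x;  // every lane in an n-lane group now holds the same row sum
}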
213
+
214
+ Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_dPsum, params.stride_dPsum)(_, bidh, !is_varlen ? bidb : 0);
215
+ Tensor gdPsum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mdPsum), Shape<Int<kBlockM>>{}, make_coord(m_block));
216
+ if (get<1>(tOcO(_0{}, _0{}, _0{})) == 0) {
217
+ #pragma unroll
218
+ for (int mi = 0; mi < size(dP_sum); ++mi) {
219
+ int const row = get<0>(tOcO(_0{}, mi, _0{}));
220
+ gdPsum(row) = row < seqlen_o - m_block * kBlockM ? dP_sum(mi) : 0;
221
+ }
222
+ }
223
+
224
+ int const seqlen_rounded = cute::round_up(seqlen_o, kBlockM);
225
+ Tensor mLSElog2 = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_dPsum, params.stride_LSE_log2)(_, bidh, !is_varlen ? bidb : 0);
226
+ Tensor gLSElog2 = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mLSElog2), Shape<Int<kBlockM>>{}, make_coord(m_block));
227
+ if (thread_idx < seqlen_rounded - m_block * kBlockM && thread_idx < kBlockM) {
228
+ gLSElog2(thread_idx) = lse == -INFINITY ? 0.f : lse * float(M_LOG2E);
229
+ }
230
+
231
+ if constexpr (Clear_dQaccum) {
232
+ Tensor mdQaccum = make_tensor(make_gmem_ptr(params.ptr_dQaccum), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
233
+ Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block));
234
+ GmemTiledCopyAccum gmem_tiled_copy_dQaccum;
235
+ auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(thread_idx);
236
+ Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
237
+ Tensor zero = make_fragment_like(tdQgdQaccum);
238
+ clear(zero);
239
+ cute::copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, zero, tdQgdQaccum);
240
+ }
241
+
242
+ if (params.dq_semaphore != nullptr && thread_idx == 0) {
243
+ int const num_batch = params.num_batch;
244
+ int const num_head = get<2>(params.shape_O);
245
+ params.dq_semaphore[bidh + bidb * num_head + m_block * num_head * num_batch] = 0;
246
+ }
247
+
248
+ }
249
+
250
+ };
251
+
252
+ } // namespace flash
flash-attn/flash_fwd_combine.cu ADDED
@@ -0,0 +1,13 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+
4
+ #include "flash_fwd_combine_launch_template.h"
5
+
6
+ template void run_mha_fwd_combine_<float, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
7
+ template void run_mha_fwd_combine_<float, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
8
+
9
+ template void run_mha_fwd_combine_<cutlass::half_t, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
10
+ template void run_mha_fwd_combine_<cutlass::half_t, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
11
+
12
+ template void run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
13
+ template void run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
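Editorial sketch of the pattern this file and the flash-attn/instantiations files rely on: the template definition lives in a launch-template header, and each .cu file explicitly instantiates one parameter combination so the combinations can be compiled in parallel. All names below are illustrative, not the repository's.

// sketch_kernel.cuh -- template definition, included by every instantiation file
#pragma once
#include <cuda_runtime.h>
struct SketchParams { const float* in; float* out; int n; };

template <typename T, int kHeadDim>
__global__ void sketch_kernel(SketchParams p) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < p.n) { p.out[i] = p.in[i]; }  // placeholder body
}

template <typename T, int kHeadDim>
void run_sketch(SketchParams& p, cudaStream_t stream) {
    sketch_kernel<T, kHeadDim><<<(p.n + 255) / 256, 256, 0, stream>>>(p);
}

// sketch_kernel_hdim64_fp32.cu -- one combination per translation unit:
//   #include "sketch_kernel.cuh"
//   template void run_sketch<float, 64>(SketchParams&, cudaStream_t);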
flash-attn/flash_fwd_combine_kernel.h ADDED
@@ -0,0 +1,482 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cute/tensor.hpp"
8
+
9
+ #include <cutlass/cutlass.h>
10
+ #include <cutlass/arch/memory.h>
11
+ #include <cutlass/array.h>
12
+ #include <cutlass/numeric_types.h>
13
+ #include <cutlass/numeric_conversion.h>
14
+
15
+ #include "cutlass/arch/grid_dependency_control.h"
16
+
17
+ #include "seqlen.h"
18
+ #include "utils.h"
19
+
20
+ namespace flash {
21
+
22
+ using namespace cute;
23
+
24
+ template <class TileShape_MK_, int kLogMaxSplits_, int kNThreads, int AlignmentLSE_,
25
+ bool Is_even_K, bool Varlen, class Element, class ElementPartial, class ArchTag_>
26
+ class FlashAttnFwdCombine {
27
+
28
+ public:
29
+
30
+ // Type Aliases
31
+ using TileShape_MK = TileShape_MK_;
32
+ using ArchTag = ArchTag_;
33
+ static constexpr int kMaxSplits = 1 << kLogMaxSplits_;
34
+ static constexpr int AlignmentLSE = std::min(AlignmentLSE_, int(128 / 8 / sizeof(float)));
35
+ static_assert(AlignmentLSE >= 1);
36
+ static constexpr int kStages = 4;
37
+
38
+ static_assert(ArchTag::kMinComputeCapability >= 75);
39
+ static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80;
40
+
41
+ static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
42
+ static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
43
+
44
+ static constexpr int kBlockM = get<0>(TileShape_MK{});
45
+ static constexpr int kBlockK = get<1>(TileShape_MK{});
46
+
47
+ static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementPartial);
48
+ static_assert(kBlockK % kGmemElemsPerLoad == 0, "kBlockK must be a multiple of kGmemElemsPerLoad");
49
+ static constexpr int kBlockKGmem = kBlockK % 128 == 0 ? 128 : (kBlockK % 64 == 0 ? 64 : 32);
50
+ static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
51
+ static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
52
+ using GmemCopyAtom = std::conditional_t<
53
+ Has_cp_async,
54
+ cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>, ElementPartial>,
55
+ cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial>
56
+ >;
57
+ using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
58
+ Stride<Int<kGmemThreadsPerRow>, _1>>;
59
+ static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0);
60
+ using GmemTiledCopyAccum = decltype(
61
+ make_tiled_copy(GmemCopyAtom{},
62
+ GmemLayoutAtom{},
63
+ Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 4 vals per load
64
+ using GmemTiledCopy = decltype(
65
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
66
+ GmemLayoutAtom{},
67
+ Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 4 vals per load
68
+
69
+ using AlignmentTypeLSE = cute::uint_byte_t<static_cast<int>(sizeof(float)) * AlignmentLSE>;
70
+ static constexpr int kGmemElemsPerLoadLSE = sizeof(AlignmentTypeLSE) / sizeof(float);
71
+ static_assert(kBlockM % kGmemElemsPerLoadLSE == 0, "kBlockM must be a multiple of kGmemElemsPerLoadLSE");
72
+ static_assert(kBlockM % 8 == 0, "kBlockM must be a multiple of 8");
73
+ static constexpr int kBlockMSmem = kBlockM % 128 == 0 ? 128 : (kBlockM % 64 == 0 ? 64 : (kBlockM % 32 == 0 ? 32 : (kBlockM % 16 == 0 ? 16 : 8)));
74
+ static constexpr int kGmemThreadsPerRowLSE = kBlockMSmem / kGmemElemsPerLoadLSE;
75
+ static_assert(MaxThreadsPerBlock % kGmemThreadsPerRowLSE == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRowLSE");
76
+ using GmemLayoutAtomLSE = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRowLSE>, Int<kGmemThreadsPerRowLSE>>,
77
+ Stride<Int<kGmemThreadsPerRowLSE>, _1>>;
78
+ static_assert(kMaxSplits % CUTE_STATIC_V(shape<0>(GmemLayoutAtomLSE{})) == 0);
79
+ using GmemCopyAtomLSE = std::conditional_t<
80
+ Has_cp_async,
81
+ cute::Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<AlignmentTypeLSE>, float>,
82
+ cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<AlignmentLSE * sizeof(float) * 8>, float>
83
+ >;
84
+ using GmemTiledCopyLSE = decltype(
85
+ make_tiled_copy(GmemCopyAtomLSE{},
86
+ GmemLayoutAtomLSE{},
87
+ Layout<Shape<_1, Int<kGmemElemsPerLoadLSE>>>{})); // Val layout, 4 vals per load
88
+
89
+ // Otherwise we get an illegal memory access (IMA) when some threads access sLSE, as we're not doing any masking
90
+ static_assert((kBlockM * kMaxSplits * AlignmentLSE) % kNThreads == 0, "kNThreads must divide kBlockM * kMaxSplits * AlignmentLSE");
91
+ // This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts
92
+ using SmemLSESwizzle = std::conditional_t<
93
+ kBlockMSmem == 8,
94
+ Swizzle<5, 0, 5>,
95
+ std::conditional_t<kBlockMSmem == 16, Swizzle<4, 0, 4>, Swizzle<3, 2, 3>>
96
+ >;
97
+ using SmemLayoutAtomLSE =
98
+ decltype(composition(SmemLSESwizzle{},
99
+ Layout<Shape<Int<8>, Int<kBlockMSmem>>,
100
+ Stride<Int<kBlockMSmem>, _1>>{}));
101
+ using SmemLayoutLSE = decltype(tile_to_shape(SmemLayoutAtomLSE{}, Shape<Int<kMaxSplits>, Int<kBlockM>>{}));
102
+
103
+ using SmemLayoutO = Layout<Shape<Int<kBlockM>, Int<kBlockK>, Int<kStages>>,
104
+ Stride<Int<kBlockK>, _1, Int<kBlockM * kBlockK>>>;
105
+
106
+ // We want each column (kMaxSplits) to be processed by threads in the same warp.
107
+ // To reduce the number of shuffles, we want as few threads on the same column as possible.
108
+ // E.g., if kBlockM is divisible by 64, and there are 256 threads, we want 4 threads (0, 1, 2, 3) per column
110
+ // and have 64 such quads.
110
+ static_assert(MaxThreadsPerBlock % kBlockMSmem == 0, "MaxThreadsPerBlock must be a multiple of kBlockMSmem");
111
+ static constexpr int kSmemThreadsPerColLSEt = MaxThreadsPerBlock / kBlockMSmem;
112
+ static_assert(cutlass::NumThreadsPerWarp % kSmemThreadsPerColLSEt == 0, "kSmemThreadsPerColLSEt must divide NumThreadsPerWarp");
113
+ using S2RLayoutAtomLSE = Layout<Shape<Int<kSmemThreadsPerColLSEt>, Int<MaxThreadsPerBlock / kSmemThreadsPerColLSEt>>>;
114
+ using S2RTiledCopyLSE = decltype(make_tiled_copy(cute::Copy_Atom<cute::DefaultCopy, float>{}, S2RLayoutAtomLSE{}, Layout<_1>{}));
115
+
116
+ using ShapeOPartial = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, num_splits, head, batch)
117
+ using StrideOPartial = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
118
+ using ShapeLSEPartial = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, num_splits, head, batch)
119
+ using StrideLSEPartial = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen, num_splits, head, batch)
120
+ using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, head, batch)
121
+ using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
122
+ using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen, head, batch)
123
+ using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch)
124
+
125
+ struct SharedStorage : cute::aligned_struct<128> {
126
+ cute::array_aligned<float, cute::cosize_v<SmemLayoutLSE>> smem_lse_partial;
127
+ cute::array_aligned<int, kBlockM> smem_max_valid_split;
128
+ cute::array_aligned<ElementPartial, cute::cosize_v<SmemLayoutO>> smem_o_partial;
129
+ };
130
+
131
+ static constexpr int SharedStorageSize = sizeof(SharedStorage);
132
+
133
+ // Device side arguments
134
+ struct Arguments {
135
+ ElementPartial const* const ptr_O_partial;
136
+ ShapeOPartial const shape_O_partial;
137
+ StrideOPartial const stride_O_partial;
138
+ float const* const ptr_LSE_partial;
139
+ ShapeLSEPartial const shape_LSE_partial;
140
+ StrideLSEPartial const stride_LSE_partial;
141
+ Element* const ptr_O;
142
+ StrideO const stride_O;
143
+ float* const ptr_LSE;
144
+ StrideLSE const stride_LSE;
145
+ int const* const cu_seqlens = nullptr;
146
+ int const* const seqused = nullptr;
147
+ int const* const num_splits_dynamic_ptr = nullptr;
148
+ int* const semaphore_to_reset = nullptr;
149
+ };
150
+
151
+ // Kernel entry point API
152
+ struct Params {
153
+ ElementPartial const* const ptr_O_partial;
154
+ ShapeOPartial const shape_O_partial;
155
+ StrideOPartial const stride_O_partial;
156
+ float const* const ptr_LSE_partial;
157
+ ShapeLSEPartial const shape_LSE_partial;
158
+ StrideLSEPartial const stride_LSE_partial;
159
+ Element* const ptr_O;
160
+ StrideO const stride_O;
161
+ float* const ptr_LSE;
162
+ StrideLSE const stride_LSE;
163
+ cutlass::FastDivmod seqlen_divmod, head_divmod;
164
+ int const* const cu_seqlens = nullptr;
165
+ int const* const seqused = nullptr;
166
+ int const* const num_splits_dynamic_ptr = nullptr;
167
+ int* const semaphore_to_reset = nullptr;
168
+ };
169
+
170
+ // Convert to underlying arguments. In this case, a simple copy for the aliased type.
171
+ static
172
+ Params
173
+ to_underlying_arguments(Arguments const& args) {
174
+ assert(get<1>(args.shape_LSE_partial) <= kMaxSplits);
175
+ return {
176
+ args.ptr_O_partial,
177
+ args.shape_O_partial,
178
+ args.stride_O_partial,
179
+ args.ptr_LSE_partial,
180
+ args.shape_LSE_partial,
181
+ args.stride_LSE_partial,
182
+ args.ptr_O,
183
+ args.stride_O,
184
+ args.ptr_LSE,
185
+ args.stride_LSE,
186
+ cutlass::FastDivmod(get<0>(args.shape_LSE_partial)), cutlass::FastDivmod(get<2>(args.shape_LSE_partial)),
187
+ args.cu_seqlens,
188
+ args.seqused,
189
+ args.num_splits_dynamic_ptr,
190
+ args.semaphore_to_reset
191
+ };
192
+ }
193
+
194
+ CUTLASS_DEVICE
195
+ void
196
+ operator()(Params const& params, char* smem_buf) {
197
+
198
+ SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
199
+ Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse_partial.data()), SmemLayoutLSE{});
200
+ Tensor sMaxValidSplit = make_tensor(make_smem_ptr(shared_storage.smem_max_valid_split.data()), Shape<Int<kBlockM>>{});
201
+ Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{});
202
+
203
+ int const thread_idx = threadIdx.x;
204
+ int const m_block = blockIdx.x;
205
+ int const k_block = blockIdx.y;
206
+ int const batch = blockIdx.z;
207
+ int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial);
208
+
209
+ if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) {
210
+ cutlass::arch::wait_on_dependent_grids();
211
+ *params.semaphore_to_reset = 0;
212
+ }
213
+ if (num_splits <= 1) { return; }
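+ // Nothing to combine when only one split was used for this batch; O and LSE are expected to already hold the final values.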
214
+ flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused};
215
+ int const offset = seqlen_info.offset;
216
+ int const seqlen = seqlen_info.seqlen;
217
+ int max_idx = seqlen * get<2>(params.shape_LSE_partial);
218
+ if constexpr (Varlen) {
219
+ if (m_block * kBlockM >= max_idx) { return; }
220
+ }
221
+
222
+ cutlass::FastDivmod seqlen_divmod_dynamic(seqlen);
223
+
224
+ // Step 1: load LSE_partial from gmem -> smem
225
+ Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)),
226
+ select<1, 0, 2, 3>(params.shape_LSE_partial),
227
+ select<1, 0, 2, 3>(params.stride_LSE_partial))(_, _, _, !Varlen ? batch : 0); // (num_splits, seqlen, head)
228
+ Tensor mLSEpartial_copy = cute::tiled_divide(mLSEpartial, Shape<_1, Int<kGmemElemsPerLoadLSE>>{});
229
+ GmemTiledCopyLSE gmem_tiled_copy_LSE;
230
+ auto gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_thread_slice(thread_idx);
231
+ Tensor tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE);
232
+
233
+ // Construct identity layout for sLSE
234
+ Tensor cLSE = make_identity_tensor(make_shape(size<0>(sLSE), size<1>(sLSE))); // (NUM_SPLITS, BLK_M) -> (num_splits, blk_m)
235
+ // Repeat the partitioning with identity layouts
236
+ Tensor tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE);
237
+
238
+ cutlass::arch::wait_on_dependent_grids();
239
+
240
+ #pragma unroll
241
+ for (int m = 0; m < size<2>(tLSEcLSE); ++m) {
242
+ int mi = int(get<1>(tLSEcLSE(_0{}, _0{}, m)));
243
+ int idx = m_block * kBlockM + mi;
244
+ if (idx < max_idx) {
245
+ int m_idx, bidh;
246
+ if constexpr (!Varlen) {
247
+ bidh = params.seqlen_divmod.divmod(m_idx, idx);
248
+ } else {
249
+ bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);
250
+ }
251
+ Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh);
252
+ #pragma unroll
253
+ for (int s = 0; s < size<1>(tLSEcLSE); ++s) {
254
+ int si = get<0>(tLSEcLSE(_0{}, s, _0{}));
255
+ // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast<float *>(&(tLSEsLSE(_0{}, s, m))), reinterpret_cast<int>(&(tLSEsLSE(_0{}, s, m))) / 4 % 32);}
256
+ if (si < num_splits) {
257
+ cute::copy(gmem_tiled_copy_LSE, mLSEpartial_cur_copy(_, si), tLSEsLSE(_, s, m));
258
+ } else {
259
+ cute::fill(tLSEsLSE(_, s, m), -INFINITY);
260
+ }
261
+ }
262
+ } else {
263
+ // We don't need to zero out the rest of the LSEs, as we will not write the output to gmem
264
+ // cute::fill(tLSEsLSE(_, _, m), -INFINITY);
265
+ }
266
+ }
267
+ if constexpr (Has_cp_async) { cute::cp_async_fence(); }
268
+
269
+ // Step 2: Load O_partial from gmem -> smem for split = 0, 1, ..., kStages - 2.
270
+ // We want these async loads to be in flight as we compute the LSE.
271
+ GmemTiledCopyAccum gmem_tiled_copy_O_partial;
272
+ auto gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_thread_slice(thread_idx);
273
+ // Construct identity layout for gO
274
+ Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
275
+ // Repeat the partitioning with identity layouts
276
+ Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO);
277
+ Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)),
278
+ params.shape_O_partial, params.stride_O_partial)(_, _, _, _, !Varlen ? batch : 0); // (seqlen, d, num_splits, head)
279
+
280
+ // Precompute these values to avoid recomputing them in the loop
281
+ Tensor tOmidx = make_tensor<int>(make_shape(size<1>(tOcO)));
282
+ Tensor tObidh = make_tensor<int>(make_shape(size<1>(tOcO)));
283
+ Tensor tOrOptr = make_tensor<ElementPartial const*>(make_shape(size<1>(tOcO)));
284
+ #pragma unroll
285
+ for (int m = 0; m < size<1>(tOcO); ++m) {
286
+ int mi = get<0>(tOcO(_0{}, m, _0{}));
287
+ int idx = m_block * kBlockM + mi;
288
+ if constexpr (!Varlen) {
289
+ tObidh(m) = params.seqlen_divmod.divmod(tOmidx(m), idx);
290
+ } else {
291
+ tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx);
292
+ }
293
+ tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m));
294
+ if (idx >= max_idx) {
295
+ tObidh[m] = -1;
296
+ }
297
+ }
298
+
299
+ Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
300
+ if constexpr (!(Is_even_K)) {
301
+ #pragma unroll
302
+ for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial) - k_block * kBlockK; }
303
+ }
304
+
305
+ Tensor tOsOpartial = gmem_thr_copy_O_partial.partition_D(sO);
306
+
307
+ auto load_O_partial = [&] (int split, int stage) {
308
+ Tensor tOsOpartial_cur = tOsOpartial(_, _, _, stage);
309
+ #pragma unroll
310
+ for (int m = 0; m < size<1>(tOcO); ++m) {
311
+ if (tObidh(m) >= 0) {
312
+ Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}).layout());
313
+ Tensor mOpartial_cur_copy = cute::tiled_divide(mOpartial_cur, Shape<Int<kGmemElemsPerLoad>>{});
314
+ #pragma unroll
315
+ for (int k = 0; k < size<2>(tOcO); ++k) {
316
+ int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;
317
+ if (Is_even_K || tOpO(k)) {
318
+ cute::copy(gmem_tiled_copy_O_partial, mOpartial_cur_copy(_, k_idx, split), tOsOpartial_cur(_, m, k));
319
+ }
320
+ }
321
+ }
322
+ }
323
+ };
324
+
325
+ for (int s = 0; s < kStages - 1; ++s) {
326
+ if (s < num_splits) { load_O_partial(s, s); }
327
+ if constexpr (Has_cp_async) { cute::cp_async_fence(); }
328
+ }
329
+
330
+ // Step 3: load and transpose LSE_partial from smem -> rmem
331
+ if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait<kStages - 1>(); }
332
+ __syncthreads();
333
+
334
+ S2RTiledCopyLSE s2r_tiled_copy_LSE;
335
+ auto s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_thread_slice(thread_idx);
336
+ Tensor ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE);
337
+ Tensor ts2rrLSE = make_fragment_like(ts2rsLSE);
338
+ cute::copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE);
339
+
340
+ // Step 4: compute the final LSE along the split dimension
341
+ Tensor lse_sum = make_tensor<float>(make_shape(size<2>(ts2rrLSE)));
342
+ Tensor ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE);
343
+ // We compute the max valid split for each row to short-circuit the computation later
344
+ Tensor max_valid_split = make_tensor<int>(make_shape(size<2>(ts2rrLSE)));
345
+ static_assert(CUTE_STATIC_V(size<0>(ts2rrLSE)) == 1);
346
+ #pragma unroll
347
+ for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
348
+ float lse_max = ts2rrLSE(_0{}, _0{}, m);
349
+ #pragma unroll
350
+ for (int s = 1; s < size<1>(ts2rrLSE); ++s) { lse_max = max(lse_max, ts2rrLSE(_0{}, s, m)); }
351
+ MaxOp<float> max_op;
352
+ lse_max = Allreduce<kSmemThreadsPerColLSEt>::run(lse_max, max_op);
353
+ int max_valid_idx = -1;
354
+ #pragma unroll
355
+ for (int s = 0; s < size<1>(ts2rrLSE); ++s) {
356
+ if (ts2rrLSE(_0{}, s, m) != -INFINITY) { max_valid_idx = get<0>(ts2rcLSE(_0{}, s, _0{})); }
357
+ }
358
+ MaxOp<int> max_int_op;
359
+ max_valid_split[m] = Allreduce<kSmemThreadsPerColLSEt>::run(max_valid_idx, max_int_op);
360
+ float lse_max_cur = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf
361
+ float lse_sum_cur = 0.f;
362
+ #pragma unroll
363
+ for (int s = 0; s < size<1>(ts2rrLSE); ++s) {
364
+ float scale = expf(ts2rrLSE(_0{}, s, m) - lse_max_cur);
365
+ lse_sum_cur += scale;
366
+ // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast<float *>(&(ts2rsLSE(_0{}, s, m))), reinterpret_cast<int>(&(ts2rsLSE(_0{}, s, m))) / 4 % 32);}
367
+ // ts2rsLSE(_0{}, m, s) = scale;
368
+ ts2rrLSE(_0{}, s, m) = scale;
369
+ }
370
+ SumOp<float> sum_op;
371
+ lse_sum_cur = Allreduce<kSmemThreadsPerColLSEt>::run(lse_sum_cur, sum_op);
372
+ lse_sum(m) = logf(lse_sum_cur) + lse_max;
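+ // lse_sum(m) is the combined LSE over splits, computed stably as log(sum_s exp(lse_s - lse_max)) + lse_max.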
373
+ float inv_sum = (lse_sum_cur == 0.f || lse_sum_cur != lse_sum_cur) ? 0.f : 1.f / lse_sum_cur;
374
+ #pragma unroll
375
+ for (int s = 0; s < size<1>(ts2rrLSE); ++s) { ts2rrLSE(_0{}, s, m) *= inv_sum; }
376
+ }
377
+ // Store the scales exp(lse - lse_logsum) back to smem
378
+ cute::copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE);
379
+
380
+ // Store max_valid_split to smem
381
+ #pragma unroll
382
+ for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
383
+ if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to smem
384
+ int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m)));
385
+ if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; }
386
+ }
387
+ }
388
+
389
+ // Step 5: store final LSE back to gmem
390
+ if (k_block == 0) {
391
+ auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial);
392
+ Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE)(_, _, !Varlen ? batch : 0);
393
+ #pragma unroll
394
+ for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
395
+ if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem
396
+ int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m)));
397
+ int idx = m_block * kBlockM + mi;
398
+ if (idx < max_idx) {
399
+ int m_idx, bidh;
400
+ if constexpr (!Varlen) {
401
+ bidh = params.seqlen_divmod.divmod(m_idx, idx);
402
+ } else {
403
+ bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);
404
+ }
405
+ // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m));
406
+ mLSE(m_idx, bidh) = lse_sum(m);
407
+ }
408
+ }
409
+ }
410
+ }
411
+
412
+ // Step 6: read O_partial from gmem -> smem -> rmem and accumulate the final O
413
+ __syncthreads();
414
+ int thr_max_valid_split = sMaxValidSplit[get<0>(tOcO(_0{}, _0{}, _0{}))];
415
+ #pragma unroll
416
+ for (int m = 1; m < size<1>(tOcO); ++m) { thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[get<0>(tOcO(_0{}, m, _0{}))]); }
417
+ Layout tOrOpartial_layout = gmem_thr_copy_O_partial.partition_S(make_tensor<ElementPartial>(TileShape_MK{})).layout();
418
+ Tensor tOrOpartial = make_fragment_like<ElementPartial>(tOrOpartial_layout);
419
+ Tensor tOrO = make_fragment_like<float>(tOrOpartial);
420
+ clear(tOrO);
421
+ int stage_load = kStages - 1, stage_compute = 0;
422
+ #pragma unroll 4 // Already tuned for speed
423
+ for (int s = 0; s <= thr_max_valid_split; ++s) {
424
+ Tensor scale = make_tensor<float>(make_shape(size<1>(tOrOpartial)));
425
+ #pragma unroll
426
+ for (int m = 0; m < size<1>(tOrOpartial); ++m) { scale(m) = sLSE(s, get<0>(tOcO(_0{}, m, _0{}))); }
427
+
428
+ if (s + kStages - 1 <= thr_max_valid_split) { load_O_partial(s + kStages - 1, stage_load); }
429
+ if constexpr (Has_cp_async) { cute::cp_async_fence(); }
430
+ stage_load = stage_load < kStages - 1 ? stage_load + 1 : 0;
431
+ if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait<kStages - 1>(); }
432
+ // We don't need __syncthreads() because each thread is just reading its own data from smem
433
+ cute::copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial>{},
434
+ tOsOpartial(_, _, _, stage_compute), tOrOpartial);
435
+ stage_compute = stage_compute < kStages - 1 ? stage_compute + 1 : 0;
436
+
437
+ #pragma unroll
438
+ for (int m = 0; m < size<1>(tOrOpartial); ++m) {
439
+ if (tObidh(m) >= 0 && scale(m) > 0.f) {
440
+ #pragma unroll
441
+ for (int k = 0; k < size<2>(tOrOpartial); ++k) {
442
+ if (Is_even_K || tOpO(k)) {
443
+ Tensor rOpartial = make_tensor_like<float>(tOrOpartial(_, m, k));
444
+ flash::convert_type_out(tOrOpartial(_, m, k), rOpartial);
445
+ #pragma unroll
446
+ for (int i = 0; i < size<0>(tOrOpartial); ++i) {
447
+ tOrO(i, m, k) += scale(m) * rOpartial[i];
448
+ }
449
+ }
450
+ }
451
+ }
452
+ }
453
+ }
454
+
455
+ // Step 7: Write the final O to gmem
456
+ Tensor rO = make_tensor_like<Element>(tOrO);
457
+ flash::convert_type_out(tOrO, rO);
458
+ auto shape_O = make_shape(get<0>(params.shape_O_partial), get<1>(params.shape_O_partial) - k_block * kBlockK, get<3>(params.shape_O_partial), get<4>(params.shape_O_partial));
459
+ Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O) + k_block * kBlockK * get<1>(params.stride_O)),
460
+ shape_O, params.stride_O)(_, _, _, !Varlen ? batch : 0);
461
+ Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int<kGmemElemsPerLoad>>{});
462
+ GmemTiledCopy gmem_tiled_copy_O;
463
+ auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
464
+
465
+ #pragma unroll
466
+ for (int m = 0; m < size<1>(tOcO); ++m) {
467
+ if (tObidh(m) >= 0) {
468
+ #pragma unroll
469
+ for (int k = 0; k < size<2>(tOcO); ++k) {
470
+ int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;
471
+ if (Is_even_K || tOpO(k)) {
472
+ cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m)));
473
+ }
474
+ }
475
+ }
476
+ }
477
+
478
+ }
479
+
480
+ };
481
+
482
+ } // namespace flash
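For reference, the per-row math implemented by Steps 4 and 6 above is the standard split-KV reduction: LSE = log(sum_s exp(LSE_s)) and O = sum_s exp(LSE_s - LSE) * O_s. A minimal host-side C++ sketch of the same reduction (illustrative name and signature, not part of the kernel):

#include <cmath>
#include <vector>

// Combine per-split partial outputs o_partial[s] (each of width d) and their LSEs into the
// final output and LSE for one row, mirroring Steps 4 and 6 of the combine kernel.
void combine_row_sketch(std::vector<std::vector<float>> const &o_partial,
                        std::vector<float> const &lse_partial,
                        std::vector<float> &o_out, float &lse_out) {
    float lse_max = -INFINITY;
    for (float l : lse_partial) { lse_max = std::fmax(lse_max, l); }
    float lse_max_cur = lse_max == -INFINITY ? 0.f : lse_max;  // in case all splits are empty
    float sum = 0.f;
    for (float l : lse_partial) { sum += std::exp(l - lse_max_cur); }
    lse_out = std::log(sum) + lse_max;
    float inv_sum = (sum == 0.f || sum != sum) ? 0.f : 1.f / sum;
    o_out.assign(o_partial[0].size(), 0.f);
    for (size_t s = 0; s < o_partial.size(); ++s) {
        float scale = std::exp(lse_partial[s] - lse_max_cur) * inv_sum;  // == exp(LSE_s - LSE)
        for (size_t i = 0; i < o_out.size(); ++i) { o_out[i] += scale * o_partial[s][i]; }
    }
}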
flash-attn/flash_fwd_combine_launch_template.h ADDED
@@ -0,0 +1,80 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cute/tensor.hpp"
8
+
9
+ #include "cutlass/cutlass.h"
10
+ #include "cutlass/arch/arch.h" // For cutlass::arch::Sm80
11
+ #include "cutlass/device_kernel.h" // For device_kernel
12
+ #include "cutlass/kernel_launch.h" // For kernel_launch
13
+
14
+ #include "static_switch.h"
15
+ #include "flash.h"
16
+ #include "flash_fwd_combine_kernel.h"
17
+
18
+ using namespace cute;
19
+
20
+ template <int Arch, int kBlockM, int kBlockK, int kLogMaxSplits, bool IsEvenK, bool Varlen, typename Element, typename ElementPartial>
21
+ void run_flash_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl) {
22
+ using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
23
+ using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kBlockK>>;
24
+ using CombineKernel = flash::FlashAttnFwdCombine<TileShape_MK, kLogMaxSplits, 256 /*kNThreads*/, 1 /*AlignmentLSE*/,
25
+ IsEvenK, Varlen, Element, ElementPartial, ArchTag>;
26
+
27
+ typename CombineKernel::Arguments args {
28
+ static_cast<ElementPartial const*>(params.oaccum_ptr),
29
+ {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_O_partial
30
+ {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0}, // stride_O_partial
31
+ static_cast<float*>(params.softmax_lseaccum_ptr),
32
+ {!Varlen ? params.seqlen_q : params.total_q, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_LSE_partial
33
+ {_1{}, params.lseaccum_split_stride, params.lseaccum_head_stride, !Varlen ? params.lseaccum_batch_stride : 0}, // stride_LSE_partial
34
+ static_cast<Element*>(params.o_ptr),
35
+ {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O
36
+ static_cast<float*>(params.softmax_lse_ptr),
37
+ {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0}, // stride_LSE
38
+ params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore
39
+ };
40
+
41
+ typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args);
42
+ int num_blocks_k = cute::ceil_div(params.dv, kBlockK);
43
+ int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h, kBlockM);
44
+ dim3 grid_m(num_blocks_m, num_blocks_k, params.b);
45
+ auto kernel = cutlass::device_kernel<CombineKernel>;
46
+ int smem_size = CombineKernel::SharedStorageSize;
47
+ if (smem_size >= 48 * 1024) {
48
+ CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
49
+ }
50
+ // kernel<<<grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream>>>(kernel_params);
51
+ cutlass::kernel_launch<CombineKernel>(grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream, kernel_params, Arch >= 90 && enable_pdl /*launch_with_pdl*/);
52
+ CHECK_CUDA_KERNEL_LAUNCH();
53
+ }
54
+
55
+ template<typename T, typename Tpartial, int kBlockK>
56
+ void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl) {
57
+ // We want kBlockM to be as small as possible to maximize parallelism.
58
+ // E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
59
+ static_assert(kBlockK % 32 == 0, "kBlockK must be a multiple of 32");
60
+ static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32);
61
+ ARCH_SWITCH(params.arch, Arch, [&] {
62
+ BOOL_SWITCH(params.cu_seqlens_q || params.seqused_q, Varlen, [&] {
63
+ if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32.
64
+ if (params.num_splits <= 16) {
65
+ run_flash_fwd_combine<Arch, kBlockM, kBlockK, 4, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
66
+ return;
67
+ }
68
+ }
69
+ if (params.num_splits <= 32) {
70
+ run_flash_fwd_combine<Arch, kBlockM, kBlockK, 5, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
71
+ } else if (params.num_splits <= 64) {
72
+ run_flash_fwd_combine<Arch, kBlockM, kBlockK, 6, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
73
+ } else if (params.num_splits <= 128) {
74
+ run_flash_fwd_combine<Arch, kBlockM, kBlockK, 7, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
75
+ } else {
76
+ run_flash_fwd_combine<Arch, kBlockM, kBlockK, 8, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
77
+ }
78
+ });
79
+ });
80
+ }
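For reference, a small standalone sketch of the grid arithmetic used by run_flash_fwd_combine above, for an assumed configuration (hdim_v = 128, seqlen_q = 4096, h = 16, b = 2; these numbers are illustrative only):

#include <cstdio>

int main() {
    int const kBlockK = 128;
    // Same kBlockM rule as run_mha_fwd_combine_ above: smaller kBlockM for wider K tiles.
    int const kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32);
    int const dv = 128, seqlen_q = 4096, h = 16, b = 2;     // assumed sizes
    int const num_blocks_k = (dv + kBlockK - 1) / kBlockK;  // cute::ceil_div(dv, kBlockK)
    int const num_blocks_m = (seqlen_q * h + kBlockM - 1) / kBlockM;
    std::printf("grid = (%d, %d, %d), 256 threads per block\n", num_blocks_m, num_blocks_k, b);
    // Prints: grid = (8192, 1, 2), 256 threads per block
    return 0;
}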
flash-attn/flash_fwd_kernel_sm80.h ADDED
@@ -0,0 +1,215 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cute/tensor.hpp"
8
+
9
+ #include <cutlass/cutlass.h>
10
+ #include <cutlass/array.h>
11
+ #include <cutlass/numeric_types.h>
12
+ #include <cutlass/kernel_hardware_info.h>
13
+
14
+ #include "seqlen.h"
15
+ #include "utils.h"
16
+ #include "softmax.h"
17
+
18
+ namespace flash {
19
+
20
+ using namespace cute;
21
+
22
+ template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
23
+ class FlashAttnFwdSm80 {
24
+
25
+ public:
26
+
27
+ // Type Aliases
28
+ using CollectiveMainloop = CollectiveMainloop_;
29
+ using CollectiveEpilogue = CollectiveEpilogue_;
30
+ static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
31
+ static constexpr bool Is_local = CollectiveMainloop::Is_local;
32
+ static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
33
+ static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
34
+ static constexpr bool Varlen = CollectiveMainloop::Varlen;
35
+ static constexpr bool PagedKV = CollectiveMainloop::PagedKV;
36
+ static constexpr bool Split = CollectiveMainloop::Split;
37
+ static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
38
+ static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
39
+ static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
40
+ static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
41
+ static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
42
+ using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
43
+
44
+ // Mainloop derived types
45
+ using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
46
+ using TiledMma = typename CollectiveMainloop::TiledMma;
47
+ using ArchTag = typename CollectiveMainloop::ArchTag;
48
+ using MainloopArguments = typename CollectiveMainloop::Arguments;
49
+ using MainloopParams = typename CollectiveMainloop::Params;
50
+
51
+ // Epilogue derived types
52
+ using EpilogueArguments = typename CollectiveEpilogue::Arguments;
53
+ using EpilogueParams = typename CollectiveEpilogue::Params;
54
+
55
+ static_assert(ArchTag::kMinComputeCapability >= 80);
56
+
57
+ using TileScheduler = TileScheduler_;
58
+ using TileSchedulerArguments = typename flash::TileSchedulerArguments;
59
+ using TileSchedulerParams = typename TileScheduler::Params;
60
+
61
+ static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMma{}));
62
+ static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{}));
63
+ static constexpr uint32_t MinBlocksPerMultiprocessor = NumThreads == 128 ? 2 : 1;
64
+
65
+ // Kernel level shared memory storage
66
+ // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v + smem_k (and not smem_q
67
+ // or anything else), so we'll pad in case sizeof(smem_o) > sizeof(smem_v) + sizeof(smem_k).
68
+ static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage))
69
+ - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)))
70
+ - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k)));
71
+ static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
72
+ struct SharedStorage {
73
+ struct TensorStorage : cute::aligned_struct<128> {
74
+ union {
75
+ struct {
76
+ cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;
77
+ typename CollectiveMainloop::TensorStorage mainloop;
78
+ };
79
+ // We want smem_o to line up with the start of smem_v
80
+ typename CollectiveEpilogue::TensorStorage epilogue;
81
+ };
82
+ } tensors;
83
+
84
+ alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
85
+
86
+ };
87
+
88
+ static constexpr int SharedStorageSize = sizeof(SharedStorage);
89
+
90
+ // Device side arguments
91
+ struct Arguments {
92
+ MainloopArguments mainloop{};
93
+ EpilogueArguments epilogue{};
94
+ cutlass::KernelHardwareInfo hw_info{};
95
+ TileSchedulerArguments scheduler{};
96
+ };
97
+
98
+ // Kernel entry point API
99
+ struct Params {
100
+ MainloopParams mainloop{};
101
+ EpilogueParams epilogue{};
102
+ cutlass::KernelHardwareInfo hw_info{};
103
+ TileSchedulerParams scheduler{};
104
+ };
105
+
106
+ //
107
+ // Methods
108
+ //
109
+
110
+ // Convert to underlying arguments. In this case, a simple copy for the aliased type.
111
+ static
112
+ Params
113
+ to_underlying_arguments(Arguments const& args) {
114
+ CUTLASS_TRACE_HOST("to_underlying_arguments():");
115
+
116
+ // Get SM count if needed, otherwise use user supplied SM count
117
+ int sm_count = args.hw_info.sm_count;
118
+ if (sm_count <= 0) {
119
+ CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
120
+ " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
121
+ sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
122
+ }
123
+
124
+ CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
125
+
126
+ cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
127
+ return {
128
+ CollectiveMainloop::to_underlying_arguments(args.mainloop),
129
+ CollectiveEpilogue::to_underlying_arguments(args.epilogue),
130
+ hw_info,
131
+ TileScheduler::to_underlying_arguments(args.scheduler)
132
+ };
133
+ }
134
+
135
+ // Computes the kernel launch grid shape based on runtime parameters
136
+ static dim3
137
+ get_grid_shape(Params const& params) {
138
+ return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count * MinBlocksPerMultiprocessor);
139
+ }
140
+
141
+ static dim3
142
+ get_block_shape() {
143
+ return dim3(MaxThreadsPerBlock, 1, 1);
144
+ }
145
+
146
+ CUTLASS_DEVICE
147
+ void
148
+ operator()(Params const& params, char* smem_buf) {
149
+
150
+ static constexpr int kBlockM = get<0>(TileShape_MNK{});
151
+
152
+ SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
153
+
154
+ CollectiveMainloop mainloop;
155
+ CollectiveEpilogue epilogue;
156
+
157
+ TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));
158
+ // Initialize matmul objects.
159
+ TiledMma tiled_mma;
160
+
161
+ scheduler.init_consumer();
162
+
163
+ int warp_idx = cutlass::canonical_warp_idx_sync();
164
+ CUTLASS_PRAGMA_NO_UNROLL
165
+ for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
166
+ work_tile_info.is_valid(params.scheduler);
167
+ work_tile_info = warp_idx == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
168
+ // Attention output (GEMM-II) accumulator.
169
+ Tensor tOrO = partition_fragment_C(tiled_mma, select<0, 2>(TileShape_MNK{}));
170
+ float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
171
+ // If there's tanh softcap, the scaling will be done before tanh.
172
+ auto block_coord = work_tile_info.get_block_coord(params.scheduler);
173
+ int const bidb = get<2>(block_coord);
174
+ if constexpr (Is_FP8 && !Has_softcap) {
175
+ int const bidh = get<1>(block_coord);
176
+ int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
177
+ float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
178
+ float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
179
+ softmax_scale_log2 *= q_descale * k_descale;
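+ // Folding the Q/K descale factors into the softmax scale avoids a separate per-element rescale of the FP8 scores.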
180
+ }
181
+ flash::Softmax<2 * (2 * kBlockM / NumThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);
182
+
183
+ SeqlenInfo_t seqlen_info{
184
+ bidb,
185
+ get<0>(params.mainloop.shape_Q),
186
+ !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
187
+ get<0>(params.mainloop.shape_K_new),
188
+ params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
189
+ params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
190
+ params.mainloop.seqlens_rotary
191
+ };
192
+ if constexpr (AppendKV) {
193
+ bool tile_new_valid = mainloop.store_kv_new(
194
+ params.mainloop, threadIdx.x, shared_storage, seqlen_info, block_coord);
195
+ if (tile_new_valid) { __syncthreads(); }
196
+ }
197
+ bool tile_valid = mainloop.mma(
198
+ params.mainloop, tOrO, softmax, threadIdx.x, seqlen_info, block_coord,
199
+ shared_storage);
200
+ scheduler.prefetch_next_work(params.scheduler, work_tile_info);
201
+ if (tile_valid) {
202
+ // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }
203
+ epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma,
204
+ threadIdx.x, block_coord);
205
+ } else {
206
+ // Write 0 to gO and -inf to gLSE.
207
+ epilogue.store_zero(params.epilogue, threadIdx.x, block_coord);
208
+ }
209
+ }
210
+
211
+ }
212
+
213
+ };
214
+
215
+ } // namespace flash
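For reference, the kernel above is persistent: each thread block keeps pulling tiles from the TileScheduler until no work remains, rather than mapping one block to one tile. A minimal CUDA sketch of the dynamic-scheduling flavor of that pattern, using a plain atomic tile counter; this is an illustration, not the actual TileScheduler:

#include <cuda_runtime.h>

// Simplified persistent work loop: thread blocks repeatedly grab the next tile index
// from a global counter and process it, until all tiles have been claimed.
__global__ void persistent_loop_sketch(int num_tiles, int *tile_counter, float *out) {
    __shared__ int tile;
    while (true) {
        if (threadIdx.x == 0) { tile = atomicAdd(tile_counter, 1); }
        __syncthreads();                    // make the claimed tile index visible to all threads
        if (tile >= num_tiles) { return; }  // no work left: the whole block exits together
        if (threadIdx.x == 0) { out[tile] = static_cast<float>(tile); }  // stand-in for mainloop + epilogue
        __syncthreads();                    // don't overwrite `tile` while other threads still use it
    }
}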
flash-attn/flash_fwd_kernel_sm90.h ADDED
@@ -0,0 +1,458 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cute/tensor.hpp"
8
+
9
+ #include <cutlass/cutlass.h>
10
+ #include <cutlass/arch/reg_reconfig.h>
11
+ #include <cutlass/array.h>
12
+ #include <cutlass/numeric_types.h>
13
+ #include <cutlass/numeric_conversion.h>
14
+ #include <cutlass/kernel_hardware_info.h>
15
+ #include "cutlass/pipeline/pipeline.hpp"
16
+
17
+ #include "cutlass/arch/grid_dependency_control.h"
18
+
19
+ #include "seqlen.h"
20
+ #include "utils.h"
21
+ #include "softmax.h"
22
+
23
+ namespace flash {
24
+
25
+ using namespace cute;
26
+
27
+ template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
28
+ class FlashAttnFwdSm90 {
29
+
30
+ public:
31
+
32
+ // Type Aliases
33
+ using CollectiveMainloop = CollectiveMainloop_;
34
+ using CollectiveEpilogue = CollectiveEpilogue_;
35
+ static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
36
+ static constexpr bool Is_local = CollectiveMainloop::Is_local;
37
+ static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
38
+ static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
39
+ static constexpr bool Varlen = CollectiveMainloop::Varlen;
40
+ static constexpr bool Split = CollectiveMainloop::Split;
41
+ static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
42
+ static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
43
+ static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
44
+ static constexpr bool HasQv = CollectiveMainloop::HasQv;
45
+ static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q;
46
+ static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV;
47
+ static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O;
48
+ static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
49
+ static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
50
+ static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim;
51
+ static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV;
52
+ static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV);
53
+ using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
54
+
55
+ // Mainloop derived types
56
+ using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV;
57
+ using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV;
58
+ using ArchTag = typename CollectiveMainloop::ArchTag;
59
+ using ClusterShape = typename CollectiveMainloop::ClusterShape;
60
+ using MainloopArguments = typename CollectiveMainloop::Arguments;
61
+ using MainloopParams = typename CollectiveMainloop::Params;
62
+ using BarrierQ = std::conditional_t<Use_TMA_Q, cutlass::arch::ClusterTransactionBarrier, cutlass::arch::ClusterBarrier>;
63
+
64
+ // Epilogue derived types
65
+ using EpilogueArguments = typename CollectiveEpilogue::Arguments;
66
+ using EpilogueParams = typename CollectiveEpilogue::Params;
67
+
68
+ static_assert(ArchTag::kMinComputeCapability >= 90);
69
+
70
+ using TileScheduler = TileScheduler_;
71
+ using TileSchedulerArguments = typename flash::TileSchedulerArguments;
72
+ using TileSchedulerParams = typename TileScheduler::Params;
73
+
74
+ static constexpr uint32_t NumLoadWarpGroups = 1;
75
+ static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaPV{})) / cutlass::NumThreadsPerWarpGroup;
76
+ static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaPV{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
77
+ static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
78
+ static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
79
+
80
+ /// Register requirement for Load and Math WGs
81
+ // If we use cp.async to load K and V, we need more registers for the producer WG.
82
+ static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 24 : 40) : 32);
83
+ static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 240 : 232) : 160);
84
+ // If you want to print from the producer warp, you'd need to increase the number of registers
85
+ // Otherwise you'll get a CUDA error.
86
+ // static constexpr uint32_t LoadRegisterRequirement = 40;
87
+ // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;
88
+
89
+ // Kernel level shared memory storage
90
+ // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v
91
+ // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v).
92
+ static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)));
93
+ static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
94
+ struct SharedStorage {
95
+ struct TensorStorage : cute::aligned_struct<128, _1> {
96
+ union {
97
+ struct {
98
+ cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;
99
+ typename CollectiveMainloop::TensorStorage mainloop;
100
+ };
101
+ // We want smem_o to line up with the start of smem_v
102
+ typename CollectiveEpilogue::TensorStorage epilogue;
103
+ };
104
+ } tensors;
105
+ struct PipelineStorage : cute::aligned_struct<16, _1> {
106
+ alignas(16) BarrierQ barrier_Q;
107
+ alignas(16) BarrierQ barrier_Qv;
108
+ alignas(16) cutlass::arch::ClusterBarrier barrier_O;
109
+ alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k;
110
+ alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v;
111
+ alignas(16) typename CollectiveMainloop::MainloopPipelineVt::SharedStorage pipeline_vt;
112
+ alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_k_new;
113
+ alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_v_new;
114
+ alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
115
+ } pipelines;
116
+
117
+ };
118
+
119
+ static constexpr int SharedStorageSize = sizeof(SharedStorage);
120
+
121
+ // Device side arguments
122
+ struct Arguments {
123
+ MainloopArguments mainloop{};
124
+ EpilogueArguments epilogue{};
125
+ cutlass::KernelHardwareInfo hw_info{};
126
+ TileSchedulerArguments scheduler{};
127
+ };
128
+
129
+ // Kernel entry point API
130
+ struct Params {
131
+ MainloopParams mainloop{};
132
+ EpilogueParams epilogue{};
133
+ cutlass::KernelHardwareInfo hw_info{};
134
+ TileSchedulerParams scheduler{};
135
+ };
136
+
137
+ //
138
+ // Methods
139
+ //
140
+
141
+ // Convert to underlying arguments. In this case, a simple copy for the aliased type.
142
+ static
143
+ Params
144
+ to_underlying_arguments(Arguments const& args) {
145
+ CUTLASS_TRACE_HOST("to_underlying_arguments():");
146
+
147
+ // Get SM count if needed, otherwise use user supplied SM count
148
+ int sm_count = args.hw_info.sm_count;
149
+ if (sm_count <= 0) {
150
+ CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
151
+ " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
152
+ sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
153
+ }
154
+
155
+ CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
156
+
157
+ cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
158
+ return {
159
+ CollectiveMainloop::to_underlying_arguments(args.mainloop),
160
+ CollectiveEpilogue::to_underlying_arguments(args.epilogue),
161
+ hw_info,
162
+ TileScheduler::to_underlying_arguments(args.scheduler)
163
+ };
164
+ }
165
+
166
+ // Computes the kernel launch grid shape based on runtime parameters
167
+ static dim3
168
+ get_grid_shape(Params const& params) {
169
+ return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
170
+ }
171
+
172
+ static dim3
173
+ get_block_shape() {
174
+ return dim3(MaxThreadsPerBlock, 1, 1);
175
+ }
176
+
177
+ CUTLASS_DEVICE
178
+ void
179
+ operator()(Params const& params, char* smem_buf) {
180
+
181
+ static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
182
+ static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
183
+ static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
184
+
185
+ using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK;
186
+ using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV;
187
+ using MainloopPipelineVt = typename CollectiveMainloop::MainloopPipelineVt;
188
+ using MainloopPipelineKVNew = typename CollectiveMainloop::MainloopPipelineKVNew;
189
+ using PipelineState = typename CollectiveMainloop::PipelineState;
190
+ using PipelineParamsK = typename MainloopPipelineK::Params;
191
+ using PipelineParamsV = typename MainloopPipelineV::Params;
192
+ using PipelineParamsVt = typename MainloopPipelineVt::Params;
193
+ using PipelineParamsKVNew = typename MainloopPipelineKVNew::Params;
194
+
195
+ SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
196
+
197
+ int const lane_predicate = cute::elect_one_sync();
198
+ int const warp_idx = cutlass::canonical_warp_idx_sync();
199
+
200
+ // Issue Tma Descriptor Prefetch from a single thread
201
+ if (warp_idx == 0 && lane_predicate) {
202
+ CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
203
+ CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
204
+ }
205
+
206
+ // Obtain warp index
207
+ int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
208
+ int warp_group_idx = cutlass::canonical_warp_group_idx();
209
+
210
+ if (warp_idx == 0 && lane_predicate) {
211
+ shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
212
+ if constexpr (HasQv) {
213
+ shared_storage.pipelines.barrier_Qv.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
214
+ }
215
+ shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/);
216
+ }
217
+
218
+ // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
219
+ PipelineParamsK pipeline_params_k;
220
+ pipeline_params_k.role = warp_group_idx == 0
221
+ ? MainloopPipelineK::ThreadCategory::Producer
222
+ : MainloopPipelineK::ThreadCategory::Consumer;
223
+ if constexpr (Use_TMA_KV) {
224
+ pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
225
+ pipeline_params_k.is_leader = warp_group_thread_idx == 0;
226
+ pipeline_params_k.num_consumers = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
227
+ } else {
228
+ pipeline_params_k.consumer_arv_count = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
229
+ pipeline_params_k.producer_arv_count = NumProducerThreads;
230
+ }
231
+
232
+ static_assert(is_same_v<PipelineParamsK, PipelineParamsVt>);
233
+ PipelineParamsVt pipeline_params_vt = pipeline_params_k;
234
+ if constexpr (Use_TMA_KV && !SameHeadDim) {
235
+ pipeline_params_vt.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
236
+ if constexpr (LargeHeadDimV) { pipeline_params_vt.num_consumers = NumMmaThreads; }
237
+ } else {
238
+ if constexpr (LargeHeadDimV) { pipeline_params_vt.consumer_arv_count = NumMmaThreads; }
239
+ }
240
+
241
+ MainloopPipelineK pipeline_k = [&] {
242
+ if constexpr (Use_TMA_KV) {
243
+ return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{});
244
+ } else {
245
+ return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k);
246
+ }
247
+ }();
248
+ // MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{});
249
+ MainloopPipelineV pipeline_v = [&] {
250
+ if constexpr (!Transpose_V) {
251
+ static_assert(is_same_v<PipelineParamsK, PipelineParamsV>);
252
+ if constexpr (Use_TMA_KV) {
253
+ return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt, ClusterShape{});
254
+ } else {
255
+ return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt);
256
+ }
257
+ } else {
258
+ PipelineParamsV pipeline_params_v;
259
+ pipeline_params_v.role = warp_group_idx == 0
260
+ ? MainloopPipelineV::ThreadCategory::Producer
261
+ : MainloopPipelineV::ThreadCategory::Consumer;
262
+ pipeline_params_v.producer_arv_count = NumProducerThreads;
263
+ pipeline_params_v.consumer_arv_count = NumMmaThreads;
264
+ return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v);
265
+ }
266
+ }();
267
+ // If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then
268
+ // the producer WG will read from pipeline_vt and write to pipeline_v.
269
+ // If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used.
270
+ // Technically for pipeline_params_vt, warp0 of WG0 is the producer and all of WG0 are consumers.
271
+ // However, the thread role isn't used in the pipeline implementation.
272
+ MainloopPipelineVt pipeline_vt = [&] {
273
+ if constexpr (Use_TMA_KV) {
274
+ pipeline_params_vt.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG
275
+ return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt, ClusterShape{});
276
+ } else {
277
+ pipeline_params_vt.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG
278
+ return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt);
279
+ }
280
+ }();
281
+
282
+ PipelineParamsKVNew pipeline_params_kv_new;
283
+ pipeline_params_kv_new.role = warp_group_idx == 0
284
+ ? MainloopPipelineKVNew::ThreadCategory::Producer
285
+ : MainloopPipelineKVNew::ThreadCategory::Consumer;
286
+ pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
287
+ pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0;
288
+ pipeline_params_kv_new.num_consumers = NumMmaThreads;
289
+ auto pipeline_k_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr);
290
+ if constexpr (!SameHeadDim) {
291
+ pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
292
+ }
293
+ auto pipeline_v_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr);
294
+
295
+ CollectiveMainloop mainloop;
296
+ CollectiveEpilogue epilogue;
297
+
298
+ // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
299
+ if constexpr (size(ClusterShape{}) > 1) {
300
+ cute::cluster_arrive_relaxed();
301
+ cute::cluster_wait();
302
+ } else {
303
+ __syncthreads();
304
+ }
305
+
306
+ TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
307
+
308
+ if (warp_group_idx == 0) { // Producer
309
+ cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
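+ // The producer warpgroup gives up registers here so the consumer (MMA) warpgroups can claim more via warpgroup_reg_alloc below.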
310
+
311
+ // The pipelines for AppendKV and main attention are different, since e.g. main attention
312
+ // might use cp.async to load KV (if PagedKVNonTMA) while AppendKV always uses TMA to load
313
+ // KV_new. Since the pipeline states are different, we have to manually sync to make
314
+ // sure the two pipelines don't race when accessing smem_k and smem_v.
315
+ PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipelineK>();
316
+ PipelineState smem_pipe_write_new = cutlass::make_producer_start_state<MainloopPipelineKVNew>();
317
+ int work_idx = 0;
318
+ int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
319
+ static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp;
320
+ if constexpr (SingleProducerWarp) {
321
+ if (warp_idx_in_warpgroup != 0) { return; }
322
+ }
323
+ if (!SingleProducerWarp && warp_idx_in_warpgroup != 0) { scheduler.init_consumer(); }
324
+
325
+ cutlass::arch::wait_on_dependent_grids();
326
+
327
+ // Load Q, K, V
328
+ for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
329
+ work_tile_info.is_valid(params.scheduler);
330
+ work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
331
+
332
+ auto block_coord = work_tile_info.get_block_coord(params.scheduler);
333
+ SeqlenInfo_t seqlen_info{
334
+ get<2>(block_coord) /*bidb*/,
335
+ get<0>(params.mainloop.shape_Q),
336
+ !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
337
+ get<0>(params.mainloop.shape_K_new),
338
+ params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
339
+ params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
340
+ params.mainloop.seqlens_rotary
341
+ };
342
+ if constexpr (AppendKV) {
343
+ bool tile_new_valid = mainloop.load_kv_new(
344
+ params.mainloop, pipeline_k_new, pipeline_v_new,
345
+ smem_pipe_write_new, shared_storage, seqlen_info, block_coord, work_idx);
346
+ if (tile_new_valid) {
347
+ // if (threadIdx.x == 0) { printf("Producer: Before sync\n"); }
348
+ cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);
349
+ // if (threadIdx.x == 0) { printf("Producer: After sync\n"); }
350
+ }
351
+ }
352
+ auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() {
353
+ scheduler.prefetch_next_work(params.scheduler, work_tile_info);
354
+ };
355
+ // pipeline_vt won't be used if we don't need to transpose V.
356
+ mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write,
357
+ shared_storage, scheduler_prefetch, seqlen_info, block_coord, work_idx);
358
+ }
359
+ mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx);
360
+ } else { // Consumer
361
+ cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
362
+
363
+ // Initialize matmul objects.
364
+ TiledMmaPV tiled_mma_pv;
365
+
366
+ PipelineState smem_pipe_read;
367
+ PipelineState smem_pipe_read_new;
368
+ // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
369
+ // (like in Cutlass's gemm) because the read and release pipeline states are always the same.
370
+
371
+ scheduler.init_consumer();
372
+ mainloop.mma_init();
373
+
374
+ int work_idx = 0;
375
+ CUTLASS_PRAGMA_NO_UNROLL
376
+ for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
377
+ work_tile_info.is_valid(params.scheduler);
378
+ // get_next_work will be called before the epilogue
379
+ ) {
380
+ auto block_coord = work_tile_info.get_block_coord(params.scheduler);
381
+ int const bidb = get<2>(block_coord);
382
+ SeqlenInfo_t seqlen_info{
383
+ bidb,
384
+ get<0>(params.mainloop.shape_Q),
385
+ !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
386
+ get<0>(params.mainloop.shape_K_new),
387
+ params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
388
+ params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
389
+ params.mainloop.seqlens_rotary
390
+ };
391
+ if constexpr (AppendKV) {
392
+ bool tile_new_valid = mainloop.store_kv_new(
393
+ params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_read_new,
394
+ threadIdx.x - MmaThreadOffset, shared_storage, seqlen_info, block_coord);
395
+ if (tile_new_valid) {
396
+ // if (threadIdx.x == 128) { printf("Consumer: Before sync\n"); }
397
+ // We need this sync so that the gmem writes from the consumers are visible to the producer,
398
+ // which might do a TMA read right after.
399
+ asm volatile ("fence.proxy.async.global;");
400
+ cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);
401
+ // An arrive is enough; we don't need a full sync. The producer will sync, which means
402
+ // after that sync we're guaranteed that the AppendKV pipeline has finished
403
+ // loading and consuming smem_k and smem_v.
404
+ // if (threadIdx.x == 128) { printf("Consumer: After sync\n"); }
405
+ }
406
+ }
407
+ // If there's tanh softcap, the scaling will be done before tanh.
408
+ float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
409
+ if constexpr (Is_FP8 && !Has_softcap) {
410
+ int const bidh = get<1>(block_coord);
411
+ int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
412
+ float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
413
+ float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
414
+ softmax_scale_log2 *= q_descale * k_descale;
415
+ }
416
+ flash::Softmax<!LargeHeadDimV ? 2 * (2 * kBlockM / NumMmaThreads) : 2, /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);
417
+ // Attention output (GEMM-II) accumulator.
418
+ Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{}));
419
+ bool tile_valid;
420
+ if constexpr (!LargeHeadDimV) {
421
+ tile_valid = mainloop.mma(
422
+ params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
423
+ tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);
424
+ } else { // mma_pv might not compile if !LargeHeadDimV
425
+ if (warp_group_idx == 1) {
426
+ tile_valid = mainloop.mma(
427
+ params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
428
+ tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);
429
+ } else {
430
+ tile_valid = mainloop.mma_pv(
431
+ params.mainloop, pipeline_v, smem_pipe_read,
432
+ tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage);
433
+ }
434
+ }
435
+ // Do this here before the epilogue so that the next tile is ready to go.
436
+ work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info);
437
+ if constexpr (Split && Varlen) {
438
+ if (!work_tile_info.is_valid(params.scheduler)) { // Last tile
439
+ cutlass::arch::launch_dependent_grids();
440
+ }
441
+ }
442
+ if (tile_valid) {
443
+ // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }
444
+ epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv,
445
+ threadIdx.x - MmaThreadOffset, block_coord);
446
+ } else {
447
+ // Write 0 to gO and -inf to gLSE.
448
+ epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord);
449
+ }
450
+ }
451
+ epilogue.store_tail();
452
+ }
453
+
454
+ }
455
+
456
+ };
457
+
458
+ } // namespace flash
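Note on the control flow above: warp group 0 acts as the producer (it shrinks its register allocation with warpgroup_reg_dealloc and only feeds K/V tiles into the shared-memory pipelines), while the remaining warp groups are consumers (they reclaim those registers with warpgroup_reg_alloc and run the softmax plus the two GEMMs). A minimal sketch of that warp-specialized split is given below; it is illustrative only, the register counts are placeholders and the comments stand in for the real mainloop.

    #include "cutlass/cutlass.h"            // cutlass::canonical_warp_group_idx
    #include "cutlass/arch/reg_reconfig.h"  // warpgroup_reg_alloc / warpgroup_reg_dealloc

    __global__ void warp_specialized_sketch() {
        int warp_group_idx = cutlass::canonical_warp_group_idx();
        if (warp_group_idx == 0) {
            // Producer warp group: give registers back, then spend the block issuing
            // K/V loads into smem and arriving on the pipeline barriers.
            cutlass::arch::warpgroup_reg_dealloc<24>();   // placeholder count
            // for each work tile: acquire stage -> issue TMA / cp.async -> commit stage
        } else {
            // Consumer warp groups: claim the freed registers, wait on the same barriers,
            // then run QK^T, softmax and PV on each stage before releasing it.
            cutlass::arch::warpgroup_reg_alloc<240>();    // placeholder count
            // for each work tile: wait on stage -> MMA + softmax -> release stage
        }
    }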
flash-attn/flash_fwd_launch_template.h ADDED
@@ -0,0 +1,223 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "cute/tensor.hpp"
8
+
9
+ #include "cutlass/cutlass.h"
10
+ #include "cutlass/device_kernel.h" // For device_kernel
11
+ #include <cutlass/kernel_hardware_info.h>
12
+ #include "cutlass/cluster_launch.hpp"
13
+ #include "cutlass/kernel_launch.h"
14
+
15
+ #include "static_switch.h"
16
+ #include "flash.h"
17
+ #include "tile_size.h"
18
+ #include "tile_scheduler.hpp"
19
+ #include "flash_fwd_kernel_sm90.h"
20
+ #include "flash_fwd_kernel_sm80.h"
21
+ #include "mainloop_fwd_sm90_tma_gmma_ws.hpp"
22
+ #include "mainloop_fwd_sm80.hpp"
23
+ #include "epilogue_fwd.hpp"
24
+
25
+ using namespace cute;
26
+
27
+ template <int Arch, int kHeadDim, int kHeadDimV, int ClusterM, typename Element, typename ElementOut,
28
+ bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool PagedKVNonTMA, bool AppendKV, bool HasQv,
29
+ bool PackGQA, bool Split, bool V_colmajor>
30
+ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
31
+ static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time");
32
+ static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time");
33
+ static_assert(!(AppendKV && !Varlen), "AppendKV requires Varlen");
34
+ static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;
35
+ static constexpr bool FP8_TransposeV = Is_FP8 && !V_colmajor;
36
+ using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
37
+
38
+ // Can't use structured bindings since they can't be declared constexpr
39
+ static constexpr std::tuple<int, int, bool, bool> kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap);
40
+ static constexpr std::tuple<int, int, int, int, bool> kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV);
41
+ static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS);
42
+ static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS);
43
+ static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap);
44
+ static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap);
45
+ static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS);
46
+ static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS);
47
+ static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS);
48
+
49
+ using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
50
+ using TileShape_MNK_PV = cute::Shape<Int<kBlockM>, Int<kHeadDimV>, Int<kBlockN>>;
51
+ using ClusterShape = cute::Shape<Int<ClusterM>, _1, _1>;
52
+ using CollectiveMainloop = std::conditional_t<
53
+ Arch >= 90,
54
+ flash::CollectiveMainloopFwdSm90<kStages, ClusterShape, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm90, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, HasQv, MmaPV_is_RS, IntraWGOverlap, PackGQA, Split, V_colmajor>,
55
+ flash::CollectiveMainloopFwdSm80<kNWarps, kStages, Q_in_regs, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm80, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, PackGQA, Split>
56
+ >;
57
+ using CollectiveEpilogue = flash::CollectiveEpilogueFwd<TileShape_MNK_PV, ClusterShape, ElementOut, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, PackGQA, Split, FP8_TransposeV>;
58
+
59
+ static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads;
60
+ using SchedulerPersistent = std::conditional_t<Varlen,
61
+ flash::VarlenDynamicPersistentTileScheduler<kBlockM, CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/>,
62
+ std::conditional_t<!Is_causal && !Is_local,
63
+ flash::StaticPersistentTileScheduler<Split>,
64
+ flash::DynamicPersistentTileScheduler<CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/>
65
+ >
66
+ >;
67
+ using SchedulerSingleTile = flash::SingleTileScheduler<Varlen, Split, PackGQA, kBlockM>;
68
+ // If Split then we probably don't have enough work for PersistentScheduler to be useful.
69
+ // However, if Varlen (e.g., during decode where we have max_seqlens), using PersistentScheduler is better
70
+ // since we'll avoid launching a bunch of thread blocks that immediately exit.
71
+ // On Sm80, noncausal persistent seems a bit slower.
72
+ static constexpr bool UsePersistentScheduler = Arch >= 90 ? !(Split && !Varlen) : ((Is_causal && !Varlen) || (Varlen && Split));
73
+ using Scheduler = std::conditional_t<!UsePersistentScheduler, SchedulerSingleTile, SchedulerPersistent>;
74
+ using AttnKernel = std::conditional_t<
75
+ Arch >= 90,
76
+ flash::enable_sm90_or_later<flash::FlashAttnFwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,
77
+ flash::enable_sm80_to_sm89<flash::FlashAttnFwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
78
+ >;
79
+
80
+ bool const is_varlen_q = params.cu_seqlens_q;
81
+ bool const is_varlen_k = params.cu_seqlens_k;
82
+ bool const is_varlen_k_new = params.cu_seqlens_knew;
83
+ int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
84
+ int batch_q = !is_varlen_q ? params.b : 1;
85
+ int batch_k = !is_varlen_k ? (params.kv_batch_idx ? params.b_k : params.b) : 1;
86
+ typename CollectiveMainloop::StrideV v_strides =
87
+ cute::conditional_return<!V_colmajor>(
88
+ make_stride(params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0),
89
+ make_stride(_1{}, params.v_dim_stride, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0));
90
+ typename CollectiveMainloop::Arguments mainloop_args {
91
+ static_cast<Element const*>(params.q_ptr),
92
+ {seqlen_q, params.d, params.h, batch_q}, // shape_Q
93
+ {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q
94
+ static_cast<Element*>(params.k_ptr),
95
+ {!params.page_table ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size,
96
+ params.d, params.h_k, !params.page_table ? batch_k : params.num_pages}, // shape_K
97
+ {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K
98
+ static_cast<Element*>(params.v_ptr),
99
+ params.dv, // headdim_v
100
+ v_strides, // stride_V
101
+ static_cast<Element const*>(params.knew_ptr),
102
+ {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1}, // shape_K_new
103
+ {params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0}, // stride_K_new
104
+ static_cast<Element const*>(params.vnew_ptr),
105
+ {params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? params.vnew_batch_stride : 0}, // stride_V_new
106
+ static_cast<Element const*>(params.qv_ptr),
107
+ {params.qv_row_stride, _1{}, params.qv_head_stride, !is_varlen_q ? params.qv_batch_stride : 0}, // stride_Qv
108
+ static_cast<Element const*>(params.rotary_cos_ptr),
109
+ {params.seqlen_k, params.rotary_dim / 2}, // shape_rotary, the seqlen shape doesn't matter
110
+ {params.rotary_dim / 2, _1{}}, // stride_rotary_cos
111
+ static_cast<Element const*>(params.rotary_sin_ptr),
112
+ {params.rotary_dim / 2, _1{}}, // stride_rotary_sin
113
+ params.is_rotary_interleaved,
114
+ params.page_table,
115
+ // if page_size is not set, avoid dividing by zero
116
+ {params.kv_batch_idx ? params.b_k : params.b, !params.page_table ? 0 : params.seqlen_k / params.page_size}, // shape_page_table
117
+ {params.page_table_batch_stride, _1{}}, // stride_page_table
118
+ params.scale_softmax,
119
+ params.q_descale_ptr, params.k_descale_ptr, params.v_descale_ptr,
120
+ {params.q_descale_batch_stride, params.q_descale_head_stride},
121
+ {params.k_descale_batch_stride, params.k_descale_head_stride},
122
+ {params.v_descale_batch_stride, params.v_descale_head_stride},
123
+ params.window_size_left, params.window_size_right, params.attention_chunk,
124
+ params.softcap,
125
+ params.num_splits,
126
+ params.kv_batch_idx,
127
+ params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
128
+ params.seqused_q, params.seqused_k,
129
+ params.leftpad_k, params.seqlens_rotary
130
+ };
131
+ typename CollectiveEpilogue::Arguments epilogue_args {
132
+ static_cast<ElementOut*>(params.o_ptr),
133
+ {seqlen_q, params.dv, params.h, batch_q, params.num_splits}, // shape_O
134
+ {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0, 0}, // stride_O
135
+ static_cast<float*>(params.oaccum_ptr),
136
+ {params.oaccum_row_stride, _1{}, params.oaccum_head_stride, !is_varlen_q ? params.oaccum_batch_stride : 0, params.oaccum_split_stride}, // stride_O_partial
137
+ static_cast<float*>(params.softmax_lse_ptr),
138
+ {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, 0}, // stride_LSE
139
+ static_cast<float*>(params.softmax_lseaccum_ptr),
140
+ {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, params.h * seqlen_q * batch_q}, // stride_LSE_partial
141
+ params.h_k,
142
+ params.cu_seqlens_q, params.seqused_q
143
+ };
144
+
145
+ int qhead_per_khead = !PackGQA ? 1 : cutlass::ceil_div(params.h, params.h_k);
146
+ int num_blocks_m = cutlass::ceil_div(params.seqlen_q * qhead_per_khead, get<0>(TileShape_MNK{}));
147
+ num_blocks_m = cutlass::round_up(num_blocks_m, size<0>(ClusterShape{}));
148
+ typename flash::TileSchedulerArguments scheduler_args {
149
+ num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits,
150
+ params.h / params.h_k,
151
+ params.seqlen_q,
152
+ params.seqlen_k, params.d, params.dv, sizeof(Element),
153
+ params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q,
154
+ // params.num_m_blocks_ptr,
155
+ params.num_splits_dynamic_ptr,
156
+ };
157
+
158
+ if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) {
159
+ prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/);
160
+ CHECK_CUDA_KERNEL_LAUNCH();
161
+ }
162
+
163
+ int device;
164
+ CHECK_CUDA(cudaGetDevice(&device));
165
+ typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
166
+ mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args
167
+ });
168
+
169
+ dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
170
+ dim3 block_dims = AttnKernel::get_block_shape();
171
+ int smem_size = AttnKernel::SharedStorageSize;
172
+ // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q));
173
+ // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));
174
+ // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
175
+ // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
176
+ // Get the ptr to kernel function.
177
+ if constexpr (size(ClusterShape{}) > 1) {
178
+ void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
179
+ if (smem_size >= 48 * 1024) {
180
+ CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
181
+ }
182
+ dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
183
+ cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
184
+ cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params);
185
+ } else {
186
+ auto kernel = cutlass::device_kernel<AttnKernel>;
187
+ if (smem_size >= 48 * 1024) {
188
+ CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
189
+ }
190
+ // kernel<<<grid_dims, block_dims, smem_size, stream>>>(kernel_params);
191
+ cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params,
192
+ Arch >= 90 && Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/);
193
+ }
194
+ CHECK_CUDA_KERNEL_LAUNCH();
195
+ }
196
+
197
+ template<int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKVNonTMA, bool Has_softcap, bool PackGQA>
198
+ void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
199
+ static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported");
200
+ static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>;
201
+ using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>;
202
+ CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
203
+ VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] {
204
+ static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1;
205
+ VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] {
206
+ // Only needed here to decide if we should use cluster
207
+ static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128;
208
+
209
+ static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen;
210
+ BOOL_SWITCH(params.qv_ptr, HasQV_, [&] {
211
+ static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256;
212
+ APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] {
213
+ // Only use Cluster if the number of tiles along seqlen_q is even and we're not varlen
214
+ CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] {
215
+ static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1;
216
+ run_flash_fwd<Arch, kHeadDim, kHeadDimV, ClusterM, T, T_out, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV && Varlen, HasQv, PackGQA, Split, V_colmajor>(params, stream);
217
+ });
218
+ });
219
+ });
220
+ });
221
+ });
222
+ });
223
+ }
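Note: every branch of the switches above is resolved at compile time, so each (Arch, dtype, head dim, flag) combination becomes its own instantiation of run_mha_fwd_; that is why this commit carries one small .cu file per variant under instantiations/. Purely as an illustration (the function name and flag values here are hypothetical, not taken from flash_api.cpp), calling one such variant explicitly would look like:

    // Template arguments are <Arch, T, kHeadDim, kHeadDimV, Split, PagedKVNonTMA, Has_softcap, PackGQA>.
    void launch_fwd_hdim128_bf16_sm90(Flash_fwd_params &params, cudaStream_t stream) {
        run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128,
                     /*Split=*/false, /*PagedKVNonTMA=*/false,
                     /*Has_softcap=*/false, /*PackGQA=*/false>(params, stream);
    }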
flash-attn/flash_prepare_scheduler.cu ADDED
@@ -0,0 +1,124 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include "cutlass/fast_math.h"
6
+ #include "cutlass/barrier.h"
7
+ #include "cutlass/arch/barrier.h"
8
+
9
+ #include "cutlass/arch/grid_dependency_control.h"
10
+
11
+ #include "flash.h"
12
+
13
+ namespace flash {
14
+
15
+ __global__ void prepare_varlen_num_blocks_kernel(
16
+ int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static,
17
+ int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new,
18
+ int const* const seqused_q, int const* const seqused_k, int const* const leftpad_k_ptr,
19
+ int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static,
20
+ cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod,
21
+ int* const tile_count_semaphore,
22
+ // int* const num_m_blocks_ptr,
23
+ int* const num_splits_dynamic_ptr,
24
+ bool enable_pdl) {
25
+
26
+ static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1;
27
+ static constexpr int kSmemSize = 1;
28
+ // Assume that there's only one block in the grid
29
+ __shared__ int total_blocks_smem[kSmemSize];
30
+
31
+ // There's only 1 block in the grid, so might as well start launching the main attn kernel
32
+ if (enable_pdl) { cutlass::arch::launch_dependent_grids(); }
33
+
34
+ if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; }
35
+ __syncthreads();
36
+
37
+ if (threadIdx.x == 0 && tile_count_semaphore) { *tile_count_semaphore = 0; }
38
+
39
+ int lane = threadIdx.x % cutlass::NumThreadsPerWarp;
40
+
41
+ auto get_num_m_blocks = [&](int bidb_start) {
42
+ int batch_idx = lane + bidb_start;
43
+ int seqlen;
44
+ if (seqused_q) {
45
+ seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0;
46
+ } else if (cu_seqlens_q) {
47
+ int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_q[batch_idx] : 0;
48
+ int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);
49
+ seqlen = next_cu_seqlen - cur_cu_seqlen;
50
+ } else {
51
+ seqlen = seqlen_q_static;
52
+ }
53
+ seqlen *= qhead_per_khead;
54
+ return batch_idx < num_batch && lane < kNumBatchPerWarp
55
+ ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0;
56
+ };
57
+
58
+ auto get_num_n_blocks = [&](int bidb_start) {
59
+ int batch_idx = lane + bidb_start;
60
+ int leftpad_k = batch_idx < num_batch && leftpad_k_ptr != nullptr ? leftpad_k_ptr[batch_idx] : 0;
61
+ int seqlen;
62
+ if (seqused_k) {
63
+ seqlen = batch_idx < num_batch ? seqused_k[batch_idx] : 0;
64
+ } else if (cu_seqlens_k) {
65
+ int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_k[batch_idx] : 0;
66
+ int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);
67
+ seqlen = next_cu_seqlen - cur_cu_seqlen;
68
+ } else {
69
+ seqlen = seqlen_k_static;
70
+ }
71
+ int seqlen_new;
72
+ if (cu_seqlens_k_new) {
73
+ int cur_cu_seqlen_new = batch_idx <= num_batch ? cu_seqlens_k_new[batch_idx] : 0;
74
+ int next_cu_seqlen_new = __shfl_down_sync(0xffffffff, cur_cu_seqlen_new, 1);
75
+ seqlen_new = next_cu_seqlen_new - cur_cu_seqlen_new;
76
+ } else {
77
+ seqlen_new = seqlen_k_new_static;
78
+ }
79
+ // if (threadIdx.x == 0) { printf("seqlen = %d, seqlen_new = %d, leftpad_k = %d\n", seqlen, seqlen_new, leftpad_k); }
80
+ seqlen = seqlen - leftpad_k + seqlen_new;
81
+ return batch_idx < num_batch && lane < kNumBatchPerWarp
82
+ ? blockn_divmod.div(seqlen + blockn_divmod.divisor - 1) : 0;
83
+ };
84
+
85
+ int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp;
86
+ int bidb_start = kNumBatchPerWarp * warp_idx;
87
+ int num_m_blocks = get_num_m_blocks(bidb_start);
88
+ int num_n_blocks = get_num_n_blocks(bidb_start);
89
+
90
+ int total_blocks = num_m_blocks * num_n_blocks;
91
+ // Warp sum
92
+ #pragma unroll
93
+ for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) {
94
+ total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i);
95
+ }
96
+ if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); }
97
+ __syncthreads();
98
+ total_blocks = total_blocks_smem[0];
99
+ // 10% margin
100
+ int blocks_per_sm = static_cast<int>(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm)));
101
+ // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM
102
+ int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1);
103
+ if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) {
104
+ num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic;
105
+ // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic);
106
+ }
107
+ }
108
+
109
+ } // flash
110
+
111
+ void prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bool packgqa,
112
+ int blockM, int blockN, bool enable_pdl) {
113
+ // Only support batch <= 992 (32 warps, each with 31 batches)
114
+ int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k);
115
+ flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>(
116
+ params.seqlen_q, params.seqlen_k, params.seqlen_knew,
117
+ params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
118
+ params.seqused_q, params.seqused_k, params.leftpad_k,
119
+ params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits,
120
+ cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN),
121
+ params.tile_count_semaphore,
122
+ // params.num_m_blocks_ptr,
123
+ params.num_splits_dynamic_ptr, enable_pdl);
124
+ }
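Note: the kernel above turns the total amount of tile work into a per-batch split count. A host-side restatement of the blocks_per_sm / num_splits_dynamic arithmetic is sketched below for clarity only; the real computation stays on the device so it can read the varlen sequence lengths without a host synchronization.

    #include <algorithm>
    #include <cmath>

    // Sketch of the same heuristic the kernel applies per batch element.
    int num_splits_dynamic_ref(int total_blocks, int num_n_blocks, int num_head,
                               int num_sm, int num_splits_static) {
        // ~10% margin on the average number of (m, n) tiles each SM should absorb.
        int blocks_per_sm = static_cast<int>(std::ceil(float(total_blocks) * 1.1f * float(num_head) / float(num_sm)));
        // Ceil-divide this batch element's KV blocks by that budget, clamped to [1, num_splits_static].
        return std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1);
    }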
flash-attn/heuristics.h ADDED
@@ -0,0 +1,59 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <vector>
8
+
9
+ inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) {
10
+ // If varlen, we don't actually know seqlen_q but only max_seqlen_q.
11
+ if (varlen_q) return true;
12
+ // Heuristic: PackGQA is a bit slower but can help if seqlen_q is small or not near a multiple of kBlockM
13
+ auto round_up = [](int a, int b) { return (a + b - 1) / b * b; };
14
+ float nopack_gqa_efficiency = float(seqlen_q) / float(round_up(seqlen_q, blockM));
15
+ float pack_gqa_efficiency = float(seqlen_q * qhead_per_khead) / float(round_up(seqlen_q * qhead_per_khead, blockM));
16
+ return nopack_gqa_efficiency < 0.9 * pack_gqa_efficiency;
17
+ };
18
+
19
+ // Find the number of splits that maximizes the occupancy. For example, if we have
20
+ // batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
21
+ // better than having 3 splits (efficiency = 0.67). However, we also don't want too many
22
+ // splits as that would incur more HBM reads/writes.
23
+ // So we find the best efficiency, then find the smallest number of splits that gets 85%
24
+ // of the best efficiency.
25
+ inline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) {
26
+ // If we have enough to almost fill the SMs, then just use 1 split
27
+ // However, in the case of super long seqlen where each head of KV doesn't even fit into
28
+ // L2 (we assume that L2 size is 50MB), we want to split.
29
+ if (total_mblocks >= 0.8f * num_SMs) {
30
+ int const size_l2 = 50 * 1024 * 1024;
31
+ // Only split if there are enough queries to go over the KV at least twice
32
+ // Don't split if causal
33
+ if (size_one_kv_head > size_l2 && num_m_blocks >= num_SMs * 2 && !is_causal_or_local) {
34
+ return std::min((size_one_kv_head + size_l2 - 1) / size_l2, max_splits);
35
+ } else {
36
+ return 1;
37
+ }
38
+ }
39
+ // If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512.
40
+ if (num_n_blocks <= 4) { return 1; }
41
+ max_splits = std::min({max_splits, num_SMs, num_n_blocks});
42
+ float max_efficiency = 0.f;
43
+ std::vector<float> efficiency;
44
+ efficiency.reserve(max_splits);
45
+ for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
46
+ float n_waves = float(total_mblocks * num_splits) / num_SMs;
47
+ float eff = n_waves / ceil(n_waves);
48
+ // printf("num_splits = %d, eff = %f\n", num_splits, eff);
49
+ if (eff > max_efficiency) { max_efficiency = eff; }
50
+ efficiency.push_back(eff);
51
+ }
52
+ for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
53
+ if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
54
+ // printf("num_splits chosen = %d\n", num_splits);
55
+ return num_splits;
56
+ }
57
+ }
58
+ return 1;
59
+ }
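Note: a worked trace of num_splits_heuristic for the 48-m-block / 108-SM example mentioned in the comment above. The call below is hypothetical; the argument values were chosen only to reproduce that case and assume heuristics.h is included.

    // 2 splits ->  96 blocks -> 0.89 waves -> efficiency 0.89
    // 3 splits -> 144 blocks -> 1.33 waves -> efficiency 1.33 / 2 = 0.67
    // 9 splits -> 432 blocks -> exactly 4 waves -> efficiency 1.00 (the best the scan finds)
    // The smallest split count reaching 85% of that best efficiency is therefore 2.
    #include "heuristics.h"

    int num_splits_example() {
        return num_splits_heuristic(/*total_mblocks=*/48, /*num_SMs=*/108, /*num_n_blocks=*/32,
                                    /*num_m_blocks=*/1, /*size_one_kv_head=*/1 << 20,
                                    /*is_causal_or_local=*/false, /*max_splits=*/128);  // returns 2
    }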
flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim128<80, cutlass::bfloat16_t, false>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim128<86, cutlass::bfloat16_t, false>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim128<90, cutlass::bfloat16_t, false>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim128<80, cutlass::bfloat16_t, true>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim128<86, cutlass::bfloat16_t, true>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim128<90, cutlass::bfloat16_t, true>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu ADDED
@@ -0,0 +1,6 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_hdim128_bf16_sm90.cu"
6
+ #include "flash_bwd_hdim128_bf16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim128<80, cutlass::half_t, false>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim128<86, cutlass::half_t, false>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim128<90, cutlass::half_t, false>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim128<80, cutlass::half_t, true>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim128<86, cutlass::half_t, true>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM128
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim128<90, cutlass::half_t, true>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu ADDED
@@ -0,0 +1,6 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_hdim128_fp16_sm90.cu"
6
+ #include "flash_bwd_hdim128_fp16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim192<80, cutlass::bfloat16_t, false>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim192<86, cutlass::bfloat16_t, false>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim192<90, cutlass::bfloat16_t, false>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim192<80, cutlass::bfloat16_t, true>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim192<86, cutlass::bfloat16_t, true>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim192<90, cutlass::bfloat16_t, true>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu ADDED
@@ -0,0 +1,6 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_hdim192_bf16_sm90.cu"
6
+ #include "flash_bwd_hdim192_bf16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim192<80, cutlass::half_t, false>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim192<86, cutlass::half_t, false>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim192<90, cutlass::half_t, false>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim192<80, cutlass::half_t, true>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim192<86, cutlass::half_t, true>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM192
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim192<90, cutlass::half_t, true>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu ADDED
@@ -0,0 +1,6 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_hdim192_fp16_sm90.cu"
6
+ #include "flash_bwd_hdim192_fp16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim256<80, cutlass::bfloat16_t, false>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim256<86, cutlass::bfloat16_t, false>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim256<90, cutlass::bfloat16_t, false>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim256<80, cutlass::bfloat16_t, true>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim256<86, cutlass::bfloat16_t, true>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim256<90, cutlass::bfloat16_t, true>(params, stream);
11
+ }
12
+ #endif
flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu ADDED
@@ -0,0 +1,6 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_hdim256_bf16_sm90.cu"
6
+ #include "flash_bwd_hdim256_bf16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu ADDED
@@ -0,0 +1,18 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_SM8x
8
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
9
+ template<>
10
+ void run_mha_bwd_<80, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim256<80, cutlass::half_t, false>(params, stream);
12
+ }
13
+ template<>
14
+ void run_mha_bwd_<86, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
+ run_mha_bwd_hdim256<86, cutlass::half_t, false>(params, stream);
16
+ }
17
+ #endif
18
+ #endif
flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu ADDED
@@ -0,0 +1,12 @@
1
+ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
+ // Splitting the different template instantiations to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ #ifndef FLASHATTENTION_DISABLE_HDIM256
8
+ template<>
9
+ void run_mha_bwd_<90, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
+ run_mha_bwd_hdim256<90, cutlass::half_t, false>(params, stream);
11
+ }
12
+ #endif