Commit · eb8ddce0
Parent(s):
Convert FA3 to Kernel Hub format
This view is limited to 50 files because it contains too many changes.
- build.toml +593 -0
- flake.lock +168 -0
- flake.nix +17 -0
- flash-attn/block.h +139 -0
- flash-attn/copy_sm90_bulk_reduce.hpp +49 -0
- flash-attn/cuda_check.h +19 -0
- flash-attn/epilogue_bwd.hpp +533 -0
- flash-attn/epilogue_fwd.hpp +484 -0
- flash-attn/flash.h +218 -0
- flash-attn/flash_api.cpp +1720 -0
- flash-attn/flash_bwd_kernel_sm80.h +173 -0
- flash-attn/flash_bwd_kernel_sm90.h +282 -0
- flash-attn/flash_bwd_launch_template.h +390 -0
- flash-attn/flash_bwd_postprocess_kernel.h +256 -0
- flash-attn/flash_bwd_preprocess_kernel.h +252 -0
- flash-attn/flash_fwd_combine.cu +13 -0
- flash-attn/flash_fwd_combine_kernel.h +482 -0
- flash-attn/flash_fwd_combine_launch_template.h +80 -0
- flash-attn/flash_fwd_kernel_sm80.h +215 -0
- flash-attn/flash_fwd_kernel_sm90.h +458 -0
- flash-attn/flash_fwd_launch_template.h +223 -0
- flash-attn/flash_prepare_scheduler.cu +124 -0
- flash-attn/heuristics.h +59 -0
- flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu +18 -0
- flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu +12 -0
- flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu +18 -0
- flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu +12 -0
- flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu +6 -0
- flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu +18 -0
- flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu +12 -0
- flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu +18 -0
- flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu +12 -0
- flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu +6 -0
- flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu +18 -0
- flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu +12 -0
- flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu +18 -0
- flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu +12 -0
- flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu +6 -0
- flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu +18 -0
- flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu +12 -0
- flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu +18 -0
- flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu +12 -0
- flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu +6 -0
- flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu +18 -0
- flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu +12 -0
- flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu +18 -0
- flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu +12 -0
- flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu +6 -0
- flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu +18 -0
- flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu +12 -0
build.toml
ADDED
@@ -0,0 +1,593 @@
[general]
name = "flash_attn3"
universal = false
cuda-minver = "12.4"
cuda-maxver = "12.4"

[torch]
src = [
  "torch-ext/pytorch_shim.h",
  "torch-ext/torch_binding.cpp",
  "torch-ext/torch_binding.h",
]

[kernel.flash_attn]
backend = "cuda"
cuda-capabilities = ["8.0", "9.0a"]
cuda-flags = [
  "-O3",
  "-std=c++17",
  "--ftemplate-backtrace-limit=0", # To debug template code
  "--use_fast_math",
  "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
  "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
  "-DCUTLASS_ENABLE_GDC_FOR_SM90",
  "--expt-relaxed-constexpr",
  "--expt-extended-lambda",
  "--use_fast_math",
  "-DNDEBUG",
]

src = [
  "flash-attn/cuda_check.h",
  "flash-attn/flash_api.cpp",
  "flash-attn/flash_fwd_combine.cu",
  "flash-attn/flash_fwd_combine_kernel.h",
  "flash-attn/flash_fwd_combine_launch_template.h",
  "flash-attn/flash.h",
  "flash-attn/flash_prepare_scheduler.cu",
  "flash-attn/heuristics.h",
  "flash-attn/seqlen.h",
  "flash-attn/static_switch.h",
  "flash-attn/tile_size.h",
  "flash-attn/utils.h",
]
depends = ["torch", "cutlass_3_9"]

[kernel.flash_attn_sm80]
backend = "cuda"
cuda-capabilities = ["8.0", "9.0a"]
cuda-flags = [
  "-O3",
  "-std=c++17",
  "--ftemplate-backtrace-limit=0", # To debug template code
  "--use_fast_math",
  "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
  "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
  "-DCUTLASS_ENABLE_GDC_FOR_SM90",
  "--expt-relaxed-constexpr",
  "--expt-extended-lambda",
  "--use_fast_math",
  "-DNDEBUG",
]
src = [
  "flash-attn/block.h",
  "flash-attn/copy_sm90_bulk_reduce.hpp",
  "flash-attn/epilogue_bwd.hpp",
  "flash-attn/epilogue_fwd.hpp",
  "flash-attn/flash.h",
  "flash-attn/flash_bwd_kernel_sm80.h",
  "flash-attn/flash_bwd_kernel_sm90.h",
  "flash-attn/flash_bwd_launch_template.h",
  "flash-attn/flash_bwd_postprocess_kernel.h",
  "flash-attn/flash_bwd_preprocess_kernel.h",
  "flash-attn/flash_fwd_launch_template.h",
  "flash-attn/flash_fwd_kernel_sm80.h",
  "flash-attn/flash_fwd_kernel_sm90.h",
  "flash-attn/heuristics.h",
  "flash-attn/mainloop_bwd_sm80.hpp",
  "flash-attn/mainloop_fwd_sm80.hpp",
  "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
  "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
  "flash-attn/mask.h",
  "flash-attn/named_barrier.hpp",
  "flash-attn/pack_gqa.h",
  "flash-attn/paged_kv.h",
  "flash-attn/rotary.h",
  "flash-attn/sm90_pipeline_no_cluster.hpp",
  "flash-attn/softmax.h",
  "flash-attn/tile_size.h",
  "flash-attn/tile_scheduler.hpp",

  "flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu",
  "flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu",
  "flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu",
  "flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu",
  "flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu",
  "flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu",
  "flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_bwd_hdim64_bf16_sm80.cu",
  "flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_bwd_hdim64_fp16_sm80.cu",
  "flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_bwd_hdim96_bf16_sm80.cu",
  "flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_bwd_hdim96_fp16_sm80.cu",
  "flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_fp16_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_bf16_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_fp16_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_bf16_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_fp16_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_bf16_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_fp16_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_bf16_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_fp16_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu"
]
include = ["flash-attn"]
depends = ["torch", "cutlass_3_9"]

[kernel.flash_attn_sm90]
backend = "cuda"
cuda-capabilities = ["8.0", "9.0a"]
cuda-flags = [
  "-O3",
  "-std=c++17",
  "--ftemplate-backtrace-limit=0", # To debug template code
  "--use_fast_math",
  "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
  "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
  "-DCUTLASS_ENABLE_GDC_FOR_SM90",
  "--expt-relaxed-constexpr",
  "--expt-extended-lambda",
  "--use_fast_math",
  "-DNDEBUG",
]
src = [
  "flash-attn/block.h",
  "flash-attn/copy_sm90_bulk_reduce.hpp",
  "flash-attn/epilogue_bwd.hpp",
  "flash-attn/epilogue_fwd.hpp",
  "flash-attn/flash.h",
  "flash-attn/flash_bwd_kernel_sm80.h",
  "flash-attn/flash_bwd_kernel_sm90.h",
  "flash-attn/flash_bwd_launch_template.h",
  "flash-attn/flash_bwd_postprocess_kernel.h",
  "flash-attn/flash_bwd_preprocess_kernel.h",
  "flash-attn/flash_fwd_launch_template.h",
  "flash-attn/flash_fwd_kernel_sm80.h",
  "flash-attn/flash_fwd_kernel_sm90.h",
  "flash-attn/heuristics.h",
  "flash-attn/mainloop_bwd_sm80.hpp",
  "flash-attn/mainloop_fwd_sm80.hpp",
  "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
  "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
  "flash-attn/mask.h",
  "flash-attn/named_barrier.hpp",
  "flash-attn/pack_gqa.h",
  "flash-attn/paged_kv.h",
  "flash-attn/rotary.h",
  "flash-attn/sm90_pipeline_no_cluster.hpp",
  "flash-attn/softmax.h",
  "flash-attn/tile_size.h",
  "flash-attn/tile_scheduler.hpp",

  "flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu",
  "flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu",
  "flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu",
  "flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu",
  "flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu",
  "flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu",
  "flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_bwd_hdim64_bf16_sm90.cu",
  "flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_bwd_hdim64_fp16_sm90.cu",
  "flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_bwd_hdim96_bf16_sm90.cu",
  "flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_bwd_hdim96_fp16_sm90.cu",
  "flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_fp16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_bf16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_fp16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_bf16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_fp16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_bf16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_fp16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_bf16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_fp16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_bf16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_e4m3_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_fp16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu",
  "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu",
]
include = ["flash-attn"]
depends = ["torch", "cutlass_3_9"]

# [kernel.flash_attn_sm100]
# backend = "cuda"
# cuda-capabilities = ["8.0", "9.0a", "10.0"]
# cuda-flags = [
#   "-O3",
#   "-std=c++17",
#   "--ftemplate-backtrace-limit=0", # To debug template code
#   "--use_fast_math",
#   "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
#   "-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1",
#   "-DCUTLASS_ENABLE_GDC_FOR_SM90",
#   "--expt-relaxed-constexpr",
#   "--expt-extended-lambda",
#   "--use_fast_math",
#   "-DNDEBUG",
# ]
# src = [
#   "flash-attn/block.h",
#   "flash-attn/copy_sm90_bulk_reduce.hpp",
#   "flash-attn/epilogue_bwd.hpp",
#   "flash-attn/epilogue_fwd.hpp",
#   "flash-attn/flash.h",
#   "flash-attn/flash_bwd_kernel_sm80.h",
#   "flash-attn/flash_bwd_kernel_sm90.h",
#   "flash-attn/flash_bwd_launch_template.h",
#   "flash-attn/flash_bwd_postprocess_kernel.h",
#   "flash-attn/flash_bwd_preprocess_kernel.h",
#   "flash-attn/flash_fwd_launch_template.h",
#   "flash-attn/flash_fwd_kernel_sm80.h",
#   "flash-attn/flash_fwd_kernel_sm90.h",
#   "flash-attn/heuristics.h",
#   "flash-attn/mainloop_bwd_sm80.hpp",
#   "flash-attn/mainloop_fwd_sm80.hpp",
#   "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
#   "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
#   "flash-attn/mask.h",
#   "flash-attn/named_barrier.hpp",
#   "flash-attn/pack_gqa.h",
#   "flash-attn/paged_kv.h",
#   "flash-attn/rotary.h",
#   "flash-attn/sm90_pipeline_no_cluster.hpp",
#   "flash-attn/softmax.h",
#   "flash-attn/tile_size.h",
#   "flash-attn/tile_scheduler.hpp",
#
#   "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm100.cu",
# ]
# include = ["flash-attn"]
# depends = ["torch", "cutlass_3_9"]
flake.lock
ADDED
@@ -0,0 +1,168 @@
{
  "nodes": {
    "flake-compat": {
      "locked": {
        "lastModified": 1747046372,
        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-compat_2": {
      "locked": {
        "lastModified": 1733328505,
        "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-utils": {
      "inputs": {
        "systems": "systems"
      },
      "locked": {
        "lastModified": 1731533236,
        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "flake-utils_2": {
      "inputs": {
        "systems": "systems_2"
      },
      "locked": {
        "lastModified": 1731533236,
        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "hf-nix": {
      "inputs": {
        "flake-compat": "flake-compat_2",
        "flake-utils": "flake-utils_2",
        "nixpkgs": "nixpkgs"
      },
      "locked": {
        "lastModified": 1750234878,
        "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=",
        "owner": "huggingface",
        "repo": "hf-nix",
        "rev": "c7132f90763d756da3e77da62e01be0a4546dc57",
        "type": "github"
      },
      "original": {
        "owner": "huggingface",
        "repo": "hf-nix",
        "type": "github"
      }
    },
    "kernel-builder": {
      "inputs": {
        "flake-compat": "flake-compat",
        "flake-utils": "flake-utils",
        "hf-nix": "hf-nix",
        "nixpkgs": [
          "kernel-builder",
          "hf-nix",
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1750275112,
        "narHash": "sha256-gqAxmLLt0tYvuRYumOZHQgryMeEFdt6j3nEC8B5rT14=",
        "owner": "huggingface",
        "repo": "kernel-builder",
        "rev": "1b63210b2a1fc3cda2e3a579e7aa8f8c8532626f",
        "type": "github"
      },
      "original": {
        "owner": "huggingface",
        "repo": "kernel-builder",
        "type": "github"
      }
    },
    "nixpkgs": {
      "locked": {
        "lastModified": 1747820358,
        "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
        "owner": "danieldk",
        "repo": "nixpkgs",
        "rev": "d3c1681180717528068082103bf323147de6ab0b",
        "type": "github"
      },
      "original": {
        "owner": "danieldk",
        "ref": "cudatoolkit-12.9-kernel-builder",
        "repo": "nixpkgs",
        "type": "github"
      }
    },
    "root": {
      "inputs": {
        "kernel-builder": "kernel-builder"
      }
    },
    "systems": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "type": "github"
      }
    },
    "systems_2": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "type": "github"
      }
    }
  },
  "root": "root",
  "version": 7
}
flake.nix
ADDED
@@ -0,0 +1,17 @@
{
  description = "Flake for Hopper Flash Attention kernel";

  inputs = {
    kernel-builder.url = "github:huggingface/kernel-builder";
  };

  outputs =
    {
      self,
      kernel-builder,
    }:
    kernel-builder.lib.genFlakeOutputs {
      path = ./.;
      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
    };
}
flash-attn/block.h
ADDED
@@ -0,0 +1,139 @@
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
namespace flash {
|
8 |
+
|
9 |
+
template <class SeqlenInfo_t, int kBlockM, int kBlockN, bool Is_causal, bool Is_local, bool PackGQA=false, bool Split=false>
|
10 |
+
struct BlockMN {
|
11 |
+
|
12 |
+
static
|
13 |
+
CUTLASS_DEVICE
|
14 |
+
cute::tuple<int, int> get_n_block_min_max(
|
15 |
+
SeqlenInfo_t const& seqlen_info,
|
16 |
+
int const m_block, int const bidb, int const split_idx, int const num_splits,
|
17 |
+
int const window_size_left, int const window_size_right,
|
18 |
+
cutlass::FastDivmod const& attention_chunk_divmod,
|
19 |
+
cutlass::FastDivmod const& qhead_per_khead_divmod) {
|
20 |
+
|
21 |
+
int const seqlen_k = seqlen_info.seqlen_k;
|
22 |
+
int const seqlen_q = seqlen_info.seqlen_q;
|
23 |
+
int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
|
24 |
+
if constexpr (Is_causal || Is_local) {
|
25 |
+
int m_idx_max = (m_block + 1) * kBlockM;
|
26 |
+
// TODO: check off-by-1 error
|
27 |
+
if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; }
|
28 |
+
int const n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q;
|
29 |
+
int n_idx_right = !Is_local ? n_idx : n_idx + window_size_right;
|
30 |
+
if (Is_local && attention_chunk_divmod.divisor > 0) {
|
31 |
+
n_idx_right = std::min(n_idx_right, flash::round_up(attention_chunk_divmod, n_idx));
|
32 |
+
}
|
33 |
+
n_block_max = std::min(n_block_max, cute::ceil_div(n_idx_right, kBlockN));
|
34 |
+
}
|
35 |
+
int n_block_min = 0;
|
36 |
+
if constexpr (Is_local) {
|
37 |
+
int m_idx_min = m_block * kBlockM;
|
38 |
+
if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); }
|
39 |
+
int const n_idx = m_idx_min + seqlen_k - seqlen_q;
|
40 |
+
            int n_idx_left = n_idx - window_size_left;
            if (attention_chunk_divmod.divisor > 0) {
                n_idx_left = std::max(n_idx_left, flash::round_down(attention_chunk_divmod, n_idx));
            }
            n_block_min = std::max(int(0), n_idx_left / kBlockN);
        }
        // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
        if constexpr (Split) {
            uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
            int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
            int split_idx_actual = split_idx & 0x0000FFFF;
            int num_splits_actual = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
            int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits_actual);
            n_block_min = n_block_min + split_idx_actual * num_n_blocks_per_split;
            n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max);
            // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); }
        }
        // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
        return {n_block_min, n_block_max};
    }

    static
    CUTLASS_DEVICE
    cute::tuple<int, int> get_n_block_k_new_min_max(
        SeqlenInfo_t const& seqlen_info,
        int const m_block, int const bidb, int const split_idx, int const num_splits,
        int const window_size_left, int const window_size_right,
        cutlass::FastDivmod const& attention_chunk_divmod,
        cutlass::FastDivmod const& qhead_per_khead_divmod) {

        auto [n_block_min, n_block_max] = get_n_block_min_max(
            seqlen_info, m_block, bidb, split_idx, num_splits,
            window_size_left, window_size_right, attention_chunk_divmod, qhead_per_khead_divmod);
        int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0);
        int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);
        int const n_block_new_min = idx_k_new_min / kBlockN;
        int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min;
        // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);}
        return {n_block_new_min, n_block_new_max};
    }

    static
    CUTLASS_DEVICE
    cute::tuple<int, int> get_m_block_min_max(
        SeqlenInfo_t const& seqlen_info,
        int const n_block, int const bidb,
        int const window_size_left, int const window_size_right, int const sink_token_length) {
        // TODO: support attention_chunk
        int const seqlen_q = seqlen_info.seqlen_q;
        int const seqlen_k = seqlen_info.seqlen_k;
        int m_block_max = cute::ceil_div(seqlen_q, kBlockM);
        if constexpr (Is_local) {
            if (n_block >= cute::ceil_div(sink_token_length, kBlockN)) {
                m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left, kBlockM));
            }
        }
        int m_block_min = 0;
        if constexpr (Is_causal || Is_local) {
            m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM);
        }
        return {m_block_min, m_block_max};
    }

    // If we have separate iterations with causal or local masking at the start, where do we stop
    static
    CUTLASS_DEVICE
    int get_n_block_min_causal_local_mask(
        SeqlenInfo_t const& seqlen_info,
        int const m_block, int const n_block_min, int const window_size_right,
        cutlass::FastDivmod const& attention_chunk_divmod,
        cutlass::FastDivmod const& qhead_per_khead_divmod) {
        int const m_idx_min = !PackGQA ? m_block * kBlockM : qhead_per_khead_divmod.divide(m_block * kBlockM);
        int const n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q;
        int n_idx_right = !Is_local ? n_idx : n_idx + window_size_right;
        if (Is_local && attention_chunk_divmod.divisor > 0) {
            n_idx_right = std::min(n_idx_right, flash::round_up(attention_chunk_divmod, n_idx));
        }
        return std::max(n_block_min, n_idx_right / kBlockN);
    }

    // If we have separate iterations with local masking at the end, where do we stop the non-masked iterations
    static
    CUTLASS_DEVICE
    int get_n_block_min_before_local_mask(
        SeqlenInfo_t const& seqlen_info,
        int const m_block, int const n_block_min, int const window_size_left,
        cutlass::FastDivmod const& attention_chunk_divmod,
        cutlass::FastDivmod const& qhead_per_khead_divmod) {
        int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1;
        int const n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q;
        int n_idx_left = !Is_local ? n_idx : n_idx - window_size_left;
        if (Is_local && attention_chunk_divmod.divisor > 0) {
            n_idx_left = std::max(n_idx_left, flash::round_down(attention_chunk_divmod, n_idx));
        }
        return !Is_local ? n_block_min : std::max(n_block_min, cute::ceil_div(n_idx_left, kBlockN));
    }

};

} // namespace flash
|
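The `if constexpr (Split)` branch above packs two values into `split_idx`: the upper 16 bits carry a dynamically chosen number of splits (0 meaning "fall back to the static `num_splits`") and the lower 16 bits carry this block's split index; the `[n_block_min, n_block_max)` range of K-blocks is then carved into near-equal chunks. A minimal host-side sketch of that decode and partition arithmetic follows; `pack_split_idx` and `split_block_range` are hypothetical helpers used only for illustration, not part of the kernel.

// Standalone sketch of the split_idx bit-packing and block-range partitioning used above.
#include <algorithm>
#include <cstdio>
#include <tuple>

// Pack the dynamic num_splits into the upper 16 bits and the split index into the lower 16 bits.
static int pack_split_idx(int split_idx, int num_splits_dynamic) {
    return (num_splits_dynamic << 16) | (split_idx & 0xFFFF);
}

// Mirrors the arithmetic in get_n_block_min_max's Split branch.
static std::tuple<int, int> split_block_range(int n_block_min, int n_block_max,
                                              int packed_split_idx, int num_splits_static) {
    int num_splits_dynamic = packed_split_idx >> 16;         // upper 16 bits
    int split_idx_actual   = packed_split_idx & 0x0000FFFF;  // lower 16 bits
    int num_splits_actual  = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits_static;
    int blocks             = n_block_max - n_block_min;
    int per_split          = blocks <= 0 ? 0 : (blocks + num_splits_actual - 1) / num_splits_actual; // ceil_div
    int lo = n_block_min + split_idx_actual * per_split;
    int hi = std::min(lo + per_split, n_block_max);
    return {lo, hi};
}

int main() {
    // 10 K-blocks split dynamically into 3 parts: prints [0, 4), [4, 8), [8, 10).
    for (int s = 0; s < 3; ++s) {
        auto [lo, hi] = split_block_range(0, 10, pack_split_idx(s, 3), /*num_splits_static=*/1);
        std::printf("split %d -> [%d, %d)\n", s, lo, hi);
    }
    return 0;
}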
flash-attn/copy_sm90_bulk_reduce.hpp
ADDED
@@ -0,0 +1,49 @@
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cute/arch/copy_sm90_tma.hpp>

namespace cute
{

////////////////////////////////////////////////////////////////////////////////////////////////////

struct SM90_BULK_REDUCE_ADD
{
  CUTE_HOST_DEVICE static void
  copy(float const* smem_ptr,
       float      * gmem_ptr, int32_t store_bytes)
  {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n"
                 :
                 : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes)
                 : "memory");
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
  }

  CUTE_HOST_DEVICE static void
  copy(float const* smem_ptr,
       float      * gmem_ptr, int32_t store_bytes, uint64_t cache_hint)
  {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [%0], [%1], %2, %3;\n"
                 :
                 : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes), "l"(cache_hint)
                 : "memory");
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // end namespace cute
|
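SM90_BULK_REDUCE_ADD wraps the Hopper `cp.reduce.async.bulk` PTX instruction, which adds a tile of floats held in shared memory into global memory as a single bulk transaction, optionally with an L2 cache hint. Below is a hedged usage sketch, not part of this commit: it assumes an SM90 build (so `CUTE_ARCH_TMA_SM90_ENABLED` is defined), a 16-byte-aligned shared buffer whose byte count is a multiple of 16, and the same fence plus commit/wait pattern that the epilogues in this commit use; the kernel name, tile size, and include paths are assumptions.

// Each block adds a kElems-float shared-memory tile into gmem_out (kElems is illustrative).
#include <cstdint>
#include <cute/arch/copy_sm90_tma.hpp>      // tma_store_arrive / tma_store_wait
#include <cutlass/arch/barrier.h>           // cutlass::arch::fence_view_async_shared
#include "copy_sm90_bulk_reduce.hpp"

template <int kElems>
__global__ void bulk_reduce_add_tile(float* gmem_out) {
    __shared__ __align__(16) float smem_tile[kElems];
    for (int i = threadIdx.x; i < kElems; i += blockDim.x) { smem_tile[i] = 1.0f; }
    cutlass::arch::fence_view_async_shared();   // make smem writes visible to the bulk-async proxy
    __syncthreads();
    if (threadIdx.x == 0) {
        cute::SM90_BULK_REDUCE_ADD::copy(smem_tile, gmem_out, int32_t(kElems * sizeof(float)));
        cute::tma_store_arrive();               // commit the bulk-async group
        cute::tma_store_wait<0>();              // wait until the reduce-add has drained
    }
    __syncthreads();
}

This is the same pattern the backward GQA epilogue below uses to accumulate dK/dV: write the tile to shared memory once, then have a single thread issue one bulk reduce-add instead of per-element atomics.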
flash-attn/cuda_check.h
ADDED
@@ -0,0 +1,19 @@
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <assert.h>
#include <stdio.h>   // fprintf
#include <stdlib.h>

#define CHECK_CUDA(call)                                                                                  \
    do {                                                                                                  \
        cudaError_t status_ = call;                                                                       \
        if (status_ != cudaSuccess) {                                                                     \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
            exit(1);                                                                                      \
        }                                                                                                 \
    } while(0)

#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
|
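CHECK_CUDA wraps any CUDA runtime call, prints the failing file, line, and error string, and exits; CHECK_CUDA_KERNEL_LAUNCH applies the same check to `cudaGetLastError()` right after a kernel launch. A small host-side usage sketch (the kernel and sizes here are illustrative, not part of this commit):

#include <cuda_runtime.h>
#include "cuda_check.h"

__global__ void fill(float* p, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) { p[i] = 1.0f; }
}

int main() {
    int const n = 1 << 20;
    float* d_ptr = nullptr;
    CHECK_CUDA(cudaMalloc(&d_ptr, n * sizeof(float)));   // aborts with file/line on failure
    fill<<<(n + 255) / 256, 256>>>(d_ptr, n);
    CHECK_CUDA_KERNEL_LAUNCH();                          // catches invalid launch configurations
    CHECK_CUDA(cudaDeviceSynchronize());                 // surfaces asynchronous kernel errors
    CHECK_CUDA(cudaFree(d_ptr));
    return 0;
}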
flash-attn/epilogue_bwd.hpp
ADDED
@@ -0,0 +1,533 @@
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
#include "cutlass/cutlass.h"
|
8 |
+
#include "cutlass/barrier.h"
|
9 |
+
#include "cute/tensor.hpp"
|
10 |
+
|
11 |
+
#include "cutlass/gemm/collective/builders/sm90_common.inl"
|
12 |
+
|
13 |
+
#include "seqlen.h"
|
14 |
+
#include "named_barrier.hpp"
|
15 |
+
#include "utils.h"
|
16 |
+
|
17 |
+
namespace flash {
|
18 |
+
|
19 |
+
using namespace cute;
|
20 |
+
|
21 |
+
template <class TileShape_MNK_, class Element_, class ArchTag_,
|
22 |
+
int NumEpilogueThreads_, bool Varlen_, bool dKV_swapAB_, int AtomLayoutKdKV=1>
|
23 |
+
struct CollectiveEpilogueBwd {
|
24 |
+
|
25 |
+
using TileShape_MNK = TileShape_MNK_;
|
26 |
+
using Element = Element_;
|
27 |
+
using ArchTag = ArchTag_;
|
28 |
+
static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
|
29 |
+
static constexpr bool Varlen = Varlen_;
|
30 |
+
static constexpr bool dKV_swapAB = dKV_swapAB_;
|
31 |
+
static constexpr bool Use_TMA = !Varlen && ArchTag::kMinComputeCapability >= 90;
|
32 |
+
|
33 |
+
static_assert(ArchTag::kMinComputeCapability >= 80);
|
34 |
+
|
35 |
+
using GmemTiledCopydKVTMA = cute::SM90_TMA_STORE;
|
36 |
+
|
37 |
+
// These are for storing the output tensor without TMA (e.g., for setting output to zero)
|
38 |
+
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
39 |
+
static_assert(get<2>(TileShape_MNK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
|
40 |
+
static constexpr int kHeadDim = get<2>(TileShape_MNK{});
|
41 |
+
static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, NumEpilogueThreads);
|
42 |
+
static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
|
43 |
+
using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
44 |
+
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
45 |
+
using GmemTiledCopydKV = decltype(
|
46 |
+
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
47 |
+
GmemLayoutAtom{},
|
48 |
+
Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per store
|
49 |
+
|
50 |
+
using SmemLayoutAtomdKVTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
51 |
+
// TODO: do we have to change this if dKV_swapAB is true?
|
52 |
+
decltype(cute::get<1>(TileShape_MNK{})), Int<CUTE_STATIC_V(cute::get<2>(TileShape_MNK{})) / AtomLayoutKdKV>>());
|
53 |
+
using SmemLayoutdKVTMA = decltype(tile_to_shape(SmemLayoutAtomdKVTMA{}, select<1, 2>(TileShape_MNK{})));
|
54 |
+
using SmemLayoutdKVtTMA =
|
55 |
+
decltype(cute::composition(SmemLayoutdKVTMA{},
|
56 |
+
make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
|
57 |
+
make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));
|
58 |
+
|
59 |
+
// If we don't use TMA
|
60 |
+
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : (kHeadDim % 32 == 0 ? 32 : 16);
|
61 |
+
static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
|
62 |
+
using SmemLayoutAtomdKVSTG =
|
63 |
+
decltype(composition(Swizzle<kSwizzle, 3, 3>{},
|
64 |
+
Layout<Shape<Int<8>, Int<kBlockKSmem>>,
|
65 |
+
Stride<Int<kBlockKSmem>, _1>>{}));
|
66 |
+
|
67 |
+
using SmemLayoutAtomdKV = std::conditional_t<Use_TMA, SmemLayoutAtomdKVTMA, SmemLayoutAtomdKVSTG>;
|
68 |
+
using SmemLayoutdKV = decltype(tile_to_shape(SmemLayoutAtomdKV{}, select<1, 2>(TileShape_MNK{})));
|
69 |
+
using SmemLayoutdKVt =
|
70 |
+
decltype(cute::composition(SmemLayoutdKV{},
|
71 |
+
make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
|
72 |
+
make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));
|
73 |
+
|
74 |
+
using SmemCopyAtomdKV = Copy_Atom<
|
75 |
+
std::conditional_t<
|
76 |
+
ArchTag::kMinComputeCapability >= 90,
|
77 |
+
std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
|
78 |
+
AutoVectorizingCopyWithAssumedAlignment<128>
|
79 |
+
>,
|
80 |
+
Element>;
|
81 |
+
|
82 |
+
static constexpr size_t SmemAlignmentdKV = ArchTag::kMinComputeCapability >= 90 ? cutlass::detail::alignment_for_swizzle(SmemLayoutdKV{}) : 128;
|
83 |
+
static_assert(SmemAlignmentdKV >= 128, "Require at least 128B alignment");
|
84 |
+
|
85 |
+
struct TensorStorage : cute::aligned_struct<SmemAlignmentdKV> {
|
86 |
+
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dk;
|
87 |
+
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dv;
|
88 |
+
};
|
89 |
+
|
90 |
+
using ShapedKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_k, d, head, batch)
|
91 |
+
using StridedKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
|
92 |
+
|
93 |
+
using TMA_dKV = std::conditional_t<
|
94 |
+
Use_TMA,
|
95 |
+
decltype(make_tma_copy(
|
96 |
+
GmemTiledCopydKVTMA{},
|
97 |
+
make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapedKV{}, StridedKV{}),
|
98 |
+
SmemLayoutdKVTMA{},
|
99 |
+
select<1, 2>(TileShape_MNK{}),
|
100 |
+
_1{})), // no mcast for dKV
|
101 |
+
std::nullptr_t
|
102 |
+
>;
|
103 |
+
|
104 |
+
// Host side kernel arguments
|
105 |
+
struct Arguments {
|
106 |
+
Element* ptr_dK;
|
107 |
+
ShapedKV const shape_dK;
|
108 |
+
StridedKV const stride_dK;
|
109 |
+
Element* ptr_dV;
|
110 |
+
ShapedKV const shape_dV;
|
111 |
+
StridedKV const stride_dV;
|
112 |
+
int const num_heads_q;
|
113 |
+
int* dk_semaphore;
|
114 |
+
int* dv_semaphore;
|
115 |
+
int const* cu_seqlens;
|
116 |
+
int const* seqused;
|
117 |
+
};
|
118 |
+
|
119 |
+
// Device side kernel params
|
120 |
+
struct Params {
|
121 |
+
Element* ptr_dK;
|
122 |
+
ShapedKV const shape_dK;
|
123 |
+
StridedKV const stride_dK;
|
124 |
+
Element* ptr_dV;
|
125 |
+
ShapedKV const shape_dV;
|
126 |
+
StridedKV const stride_dV;
|
127 |
+
TMA_dKV tma_store_dK, tma_store_dV;
|
128 |
+
int const* cu_seqlens = nullptr;
|
129 |
+
int const* seqused = nullptr;
|
130 |
+
};
|
131 |
+
|
132 |
+
static Params
|
133 |
+
to_underlying_arguments(Arguments const& args) {
|
134 |
+
Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK);
|
135 |
+
Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dV, args.stride_dV);
|
136 |
+
TMA_dKV tma_store_dK = [&] {
|
137 |
+
if constexpr (Use_TMA) {
|
138 |
+
return make_tma_copy(GmemTiledCopydKVTMA{}, mdK, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV
|
139 |
+
} else {
|
140 |
+
return nullptr;
|
141 |
+
}
|
142 |
+
}();
|
143 |
+
TMA_dKV tma_store_dV = [&] {
|
144 |
+
if constexpr (Use_TMA) {
|
145 |
+
return make_tma_copy(GmemTiledCopydKVTMA{}, mdV, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV
|
146 |
+
} else {
|
147 |
+
return nullptr;
|
148 |
+
}
|
149 |
+
}();
|
150 |
+
return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.shape_dV, args.stride_dV,
|
151 |
+
tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused};
|
152 |
+
}
|
153 |
+
|
154 |
+
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
155 |
+
CUTLASS_DEVICE
|
156 |
+
static void prefetch_tma_descriptors(Params const& params) {
|
157 |
+
if constexpr (Use_TMA) {
|
158 |
+
cute::prefetch_tma_descriptor(params.tma_store_dK.get_tma_descriptor());
|
159 |
+
cute::prefetch_tma_descriptor(params.tma_store_dV.get_tma_descriptor());
|
160 |
+
}
|
161 |
+
}
|
162 |
+
|
163 |
+
template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
|
164 |
+
CUTLASS_DEVICE void
|
165 |
+
store(Params const& params,
|
166 |
+
FrgTensorO const& tdKrdK,
|
167 |
+
FrgTensorO const& tdVrdV,
|
168 |
+
SharedStorage& shared_storage,
|
169 |
+
TiledMma tiled_mma,
|
170 |
+
int thread_idx,
|
171 |
+
cute::tuple<int32_t, int32_t, int32_t> const& block_coord
|
172 |
+
) {
|
173 |
+
|
174 |
+
auto [n_block, bidh, bidb] = block_coord;
|
175 |
+
Tensor sdK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKV{}));
|
176 |
+
Tensor sdV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKV{}));
|
177 |
+
Tensor sdKt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKVt{}));
|
178 |
+
Tensor sdVt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKVt{}));
|
179 |
+
auto smem_tiled_copy_dKV = make_tiled_copy_C(SmemCopyAtomdKV{}, tiled_mma);
|
180 |
+
auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(thread_idx);
|
181 |
+
|
182 |
+
Tensor tdVrdV_out = make_tensor_like<Element>(tdVrdV);
|
183 |
+
flash::convert_type_out(tdVrdV, tdVrdV_out);
|
184 |
+
Tensor tdKrdK_out = make_tensor_like<Element>(tdKrdK);
|
185 |
+
flash::convert_type_out(tdKrdK, tdKrdK_out);
|
186 |
+
Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N)
|
187 |
+
Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N)
|
188 |
+
// if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_dKV); print(sdK); printf("\n"); print(sdKt); printf("\n"); }
|
189 |
+
Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdK, sdKt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
190 |
+
Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdV, sdVt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
191 |
+
|
192 |
+
// Make sure all WGs have finished reading K and V
|
193 |
+
flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
194 |
+
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
|
195 |
+
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
|
196 |
+
if constexpr (Use_TMA) {
|
197 |
+
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
|
198 |
+
cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
|
199 |
+
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
200 |
+
|
201 |
+
Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK);
|
202 |
+
Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dV);
|
203 |
+
Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
|
204 |
+
Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
|
205 |
+
auto block_tma_dK = params.tma_store_dK.get_slice(_0{});
|
206 |
+
auto block_tma_dV = params.tma_store_dV.get_slice(_0{});
|
207 |
+
Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K)
|
208 |
+
Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K)
|
209 |
+
Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_M, TMA_K)
|
210 |
+
Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
|
211 |
+
int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
|
212 |
+
if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
|
213 |
+
cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
|
214 |
+
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
215 |
+
if (cute::elect_one_sync()) {
|
216 |
+
cute::copy(params.tma_store_dV, tdVsdV, tdVgdV);
|
217 |
+
cute::copy(params.tma_store_dK, tdKsdK, tdKgdK);
|
218 |
+
tma_store_arrive();
|
219 |
+
}
|
220 |
+
}
|
221 |
+
tma_store_wait<0>();
|
222 |
+
// // Tell warp 0 that smem_k and smem_v are ready
|
223 |
+
// cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
|
224 |
+
|
225 |
+
} else {
|
226 |
+
flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
227 |
+
static constexpr int kBlockN = get<1>(TileShape_MNK{});
|
228 |
+
flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
|
229 |
+
bool const is_varlen = Varlen && params.cu_seqlens;
|
230 |
+
Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
|
231 |
+
Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
|
232 |
+
Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
|
233 |
+
Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
|
234 |
+
|
235 |
+
GmemTiledCopydKV gmem_tiled_copy_dKV;
|
236 |
+
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
|
237 |
+
Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
|
238 |
+
Tensor tdKVsdV = gmem_thr_copy_dKV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
|
239 |
+
Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
|
240 |
+
Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK); // (TMA, TMA_M, TMA_K)
|
241 |
+
Tensor tdKVrdV = make_fragment_like(tdKVgdV);
|
242 |
+
Tensor tdKVrdK = make_fragment_like(tdKVgdK);
|
243 |
+
Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
244 |
+
// Repeat the partitioning with identity layouts
|
245 |
+
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
|
246 |
+
Tensor tdKVpdV = make_tensor<bool>(make_shape(size<2>(tdKVgdV)));
|
247 |
+
Tensor tdKVpdK = make_tensor<bool>(make_shape(size<2>(tdKVgdK)));
|
248 |
+
#pragma unroll
|
249 |
+
for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); }
|
250 |
+
#pragma unroll
|
251 |
+
for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
|
252 |
+
// Need to check OOB when reading from smem if kBlockN isn't evenly tiled
|
253 |
+
static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;
|
254 |
+
flash::copy</*Is_even_MN=*/EvenN, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
|
255 |
+
gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdV, kBlockN);
|
256 |
+
flash::copy</*Is_even_MN=*/EvenN, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
|
257 |
+
gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdK, kBlockN);
|
258 |
+
// // Tell warp 0 that smem_k and smem_v are ready
|
259 |
+
// cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_k/v
|
260 |
+
// flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
|
261 |
+
// Construct identity layout for gdKV
|
262 |
+
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
263 |
+
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
264 |
+
gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)
|
265 |
+
);
|
266 |
+
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
267 |
+
gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdK, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)
|
268 |
+
);
|
269 |
+
}
|
270 |
+
}
|
271 |
+
|
272 |
+
CUTLASS_DEVICE void
|
273 |
+
store_tail() {
|
274 |
+
// if constexpr (Use_TMA) { tma_store_wait<0>(); }
|
275 |
+
}
|
276 |
+
|
277 |
+
// Write 0 to dK and dV
|
278 |
+
CUTLASS_DEVICE void
|
279 |
+
store_zero(
|
280 |
+
Params const& params,
|
281 |
+
int thread_idx,
|
282 |
+
cute::tuple<int32_t, int32_t, int32_t> const& block_coord
|
283 |
+
) {
|
284 |
+
static constexpr int kBlockN = get<1>(TileShape_MNK{});
|
285 |
+
auto [n_block, bidh, bidb] = block_coord;
|
286 |
+
flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
|
287 |
+
bool const is_varlen = Varlen && params.cu_seqlens;
|
288 |
+
Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
|
289 |
+
Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
|
290 |
+
Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
|
291 |
+
Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
|
292 |
+
|
293 |
+
GmemTiledCopydKV gmem_tiled_copy_dKV;
|
294 |
+
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
|
295 |
+
Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
|
296 |
+
Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
|
297 |
+
Tensor tdKVrdKV = make_fragment_like(tdKVgdK);
|
298 |
+
clear(tdKVrdKV);
|
299 |
+
// Construct identity layout for gdKV
|
300 |
+
Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
301 |
+
// Repeat the partitioning with identity layouts
|
302 |
+
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
|
303 |
+
Tensor tdKVpdK = make_tensor<bool>(make_shape(size<2>(tdKVgdK)));
|
304 |
+
Tensor tdKVpdV = make_tensor<bool>(make_shape(size<2>(tdKVgdV)));
|
305 |
+
#pragma unroll
|
306 |
+
for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
|
307 |
+
#pragma unroll
|
308 |
+
for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); }
|
309 |
+
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
310 |
+
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
311 |
+
gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdK, seqlen_info.seqlen - n_block * kBlockN
|
312 |
+
);
|
313 |
+
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
314 |
+
gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdV, seqlen_info.seqlen - n_block * kBlockN
|
315 |
+
);
|
316 |
+
}
|
317 |
+
|
318 |
+
};
|
319 |
+
|
320 |
+
template <class TileShape_MNK_, class ElementAccum, class ArchTag_,
|
321 |
+
int NumEpilogueThreads_, bool Varlen_, bool Deterministic>
|
322 |
+
struct CollectiveEpilogueBwdGQA {
|
323 |
+
|
324 |
+
using TileShape_MNK = TileShape_MNK_;
|
325 |
+
using Element = ElementAccum;
|
326 |
+
using ArchTag = ArchTag_;
|
327 |
+
static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
|
328 |
+
static constexpr bool Varlen = Varlen_;
|
329 |
+
static constexpr bool Use_TMA = ArchTag::kMinComputeCapability >= 90;
|
330 |
+
|
331 |
+
static_assert(ArchTag::kMinComputeCapability >= 80);
|
332 |
+
|
333 |
+
static constexpr int kBlockN = get<1>(TileShape_MNK{});
|
334 |
+
static constexpr int kHeadDim = get<2>(TileShape_MNK{});
|
335 |
+
static_assert(NumEpilogueThreads % cutlass::NumThreadsPerWarp == 0, "NumEpilogueThreads must be a multiple of NumThreadsPerWarp");
|
336 |
+
static constexpr int NumWarpGroups = NumEpilogueThreads / cutlass::NumThreadsPerWarpGroup;
|
337 |
+
// Thread layout, 256 or 384 threads per row
|
338 |
+
// We split into NumWarpGroups so that we can use the same postprocessing kernel as dQ
|
339 |
+
using R2SLayoutAtomdKVaccum = Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumWarpGroups>>>;
|
340 |
+
using R2STiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdKVaccum{},
|
341 |
+
Layout<Shape < _4>>{})); // Val layout, 4 vals per store
|
342 |
+
// For Sm80
|
343 |
+
using R2GLayoutAtomdKVaccum = Layout<Shape<Int<NumEpilogueThreads>>>;
|
344 |
+
using R2GTiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2GLayoutAtomdKVaccum{},
|
345 |
+
Layout<Shape < _1>>{})); // Val layout, 1 val per store
|
346 |
+
|
347 |
+
using SmemLayoutdKVaccum = Layout<Shape<Int<kBlockN * kHeadDim / NumWarpGroups>, Int<NumWarpGroups>>>;
|
348 |
+
using SmemLayoutdKVaccumFlat = Layout<Shape<Int<kBlockN * kHeadDim>>>;
|
349 |
+
|
350 |
+
// Strangely without this SmemAlignment, the total smem for hdim 128 (80 x 128) is 228KB even though we
|
351 |
+
// only need 227KB. We use the same alignment as the non-GQA epilogue to avoid this issue.
|
352 |
+
static constexpr int SmemAlignment = kHeadDim % 64 == 0 ? 1024 : (kHeadDim % 32 == 0 ? 512 : 256);
|
353 |
+
struct TensorStorageTMA : cute::aligned_struct<SmemAlignment> {
|
354 |
+
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdKVaccum>, SmemAlignment> smem_dkv;
|
355 |
+
};
|
356 |
+
struct TensorStorageSTG {
|
357 |
+
cute::array<ElementAccum, 0> smem_dkv;
|
358 |
+
};
|
359 |
+
using TensorStorage = std::conditional_t<Use_TMA, TensorStorageTMA, TensorStorageSTG>;
|
360 |
+
|
361 |
+
using ShapedKV = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_k_rounded * d, head, batch)
|
362 |
+
using StridedKV = cute::Stride<_1, int64_t, int64_t>;
|
363 |
+
|
364 |
+
// Host side kernel arguments
|
365 |
+
struct Arguments {
|
366 |
+
ElementAccum* ptr_dKaccum;
|
367 |
+
ShapedKV const shape_dKaccum;
|
368 |
+
StridedKV const stride_dKaccum;
|
369 |
+
ElementAccum* ptr_dVaccum;
|
370 |
+
ShapedKV const shape_dVaccum;
|
371 |
+
StridedKV const stride_dVaccum;
|
372 |
+
int num_heads_q;
|
373 |
+
int* dk_semaphore;
|
374 |
+
int* dv_semaphore;
|
375 |
+
int const* cu_seqlens;
|
376 |
+
int const* seqused;
|
377 |
+
};
|
378 |
+
|
379 |
+
// Device side kernel params
|
380 |
+
struct Params {
|
381 |
+
ElementAccum* ptr_dKaccum;
|
382 |
+
ShapedKV const shape_dKaccum;
|
383 |
+
StridedKV const stride_dKaccum;
|
384 |
+
ElementAccum* ptr_dVaccum;
|
385 |
+
ShapedKV const shape_dVaccum;
|
386 |
+
StridedKV const stride_dVaccum;
|
387 |
+
cutlass::FastDivmod qhead_per_khead_divmod;
|
388 |
+
int* dk_semaphore;
|
389 |
+
int* dv_semaphore;
|
390 |
+
int const* cu_seqlens = nullptr;
|
391 |
+
int const* seqused = nullptr;
|
392 |
+
};
|
393 |
+
|
394 |
+
static Params
|
395 |
+
to_underlying_arguments(Arguments const& args) {
|
396 |
+
if constexpr (Deterministic) {
|
397 |
+
assert(args.dk_semaphore != nullptr);
|
398 |
+
assert(args.dv_semaphore != nullptr);
|
399 |
+
}
|
400 |
+
return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.shape_dVaccum, args.stride_dVaccum,
|
401 |
+
cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))),
|
402 |
+
args.dk_semaphore, args.dv_semaphore,
|
403 |
+
args.cu_seqlens, args.seqused};
|
404 |
+
}
|
405 |
+
|
406 |
+
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
407 |
+
CUTLASS_DEVICE
|
408 |
+
static void prefetch_tma_descriptors(Params const& params) {
|
409 |
+
}
|
410 |
+
|
411 |
+
template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
|
412 |
+
CUTLASS_DEVICE void
|
413 |
+
store(Params const& params,
|
414 |
+
FrgTensorO const& tdKrdK,
|
415 |
+
FrgTensorO const& tdVrdV,
|
416 |
+
SharedStorage& shared_storage,
|
417 |
+
TiledMma tiled_mma,
|
418 |
+
int thread_idx,
|
419 |
+
cute::tuple<int32_t, int32_t, int32_t> const& block_coord
|
420 |
+
) {
|
421 |
+
|
422 |
+
auto [n_block, bidh, bidb] = block_coord;
|
423 |
+
int bidh_idx_in_group;
|
424 |
+
int bidh_kv = params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh);
|
425 |
+
Tensor sdKV = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccum{});
|
426 |
+
Tensor sdKV_flat = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccumFlat{});
|
427 |
+
static constexpr int dKV_TMA_num_bytes = CUTE_STATIC_V(size(sdKV_flat)) * sizeof(ElementAccum);
|
428 |
+
|
429 |
+
flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused};
|
430 |
+
bool const is_varlen = Varlen && params.cu_seqlens;
|
431 |
+
Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
|
432 |
+
Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dVaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
|
433 |
+
Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block)); // (M * K)
|
434 |
+
Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block)); // (M * K)
|
435 |
+
|
436 |
+
R2STiledCopydKVaccum r2s_tiled_copy_dKVaccum;
|
437 |
+
auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx);
|
438 |
+
Tensor tdKVsdKVaccum = r2s_thr_copy_dKVaccum.partition_D(sdKV);
|
439 |
+
|
440 |
+
// Only used if !Use_TMA
|
441 |
+
R2GTiledCopydKVaccum r2g_tiled_copy_dKVaccum;
|
442 |
+
auto r2g_thr_copy_dKVaccum = r2g_tiled_copy_dKVaccum.get_thread_slice(thread_idx);
|
443 |
+
|
444 |
+
// Make sure all WGs have finished reading K and V, otherwise we get racy dQ
|
445 |
+
// because smem_q could be changed.
|
446 |
+
flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
447 |
+
if constexpr (Use_TMA) {
|
448 |
+
Tensor taccdKVrdV = r2s_thr_copy_dKVaccum.retile_S(tdVrdV); // ((Atom,AtomNum), MMA_M, MMA_N)
|
449 |
+
cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum);
|
450 |
+
}
|
451 |
+
|
452 |
+
// int const num_batch = params.num_batch;
|
453 |
+
int const num_batch = get<2>(params.shape_dKaccum);
|
454 |
+
int const num_head_kv = get<1>(params.shape_dKaccum);
|
455 |
+
int *lock_ptr = !Deterministic ? nullptr : params.dv_semaphore + bidb * num_head_kv + bidh_kv;
|
456 |
+
using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;
|
457 |
+
|
458 |
+
// if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}
|
459 |
+
|
460 |
+
if constexpr (Deterministic) {
|
461 |
+
Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
|
462 |
+
}
|
463 |
+
// if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore);}
|
464 |
+
if constexpr (Use_TMA) {
|
465 |
+
cutlass::arch::fence_view_async_shared();
|
466 |
+
cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
467 |
+
if (thread_idx == 0) {
|
468 |
+
SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdVaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
|
469 |
+
tma_store_arrive();
|
470 |
+
tma_store_wait<0>();
|
471 |
+
}
|
472 |
+
} else {
|
473 |
+
Tensor tdVrdV_atomic = r2g_thr_copy_dKVaccum.retile_S(tdVrdV);
|
474 |
+
Tensor tdVgdV_atomic = r2g_thr_copy_dKVaccum.partition_D(gdVaccum);
|
475 |
+
static_assert(CUTE_STATIC_V(size(tdVrdV_atomic)) == CUTE_STATIC_V(size(tdVgdV_atomic)));
|
476 |
+
#pragma unroll
|
477 |
+
for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdV_atomic(i), tdVrdV_atomic(i)); }
|
478 |
+
}
|
479 |
+
if constexpr (Deterministic) {
|
480 |
+
Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
|
481 |
+
}
|
482 |
+
|
483 |
+
if constexpr (Use_TMA) {
|
484 |
+
cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
485 |
+
Tensor taccdKVrdK = r2s_thr_copy_dKVaccum.retile_S(tdKrdK); // ((Atom,AtomNum), MMA_M, MMA_N)
|
486 |
+
cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdK, tdKVsdKVaccum);
|
487 |
+
}
|
488 |
+
lock_ptr = !Deterministic ? nullptr : params.dk_semaphore + bidb * num_head_kv + bidh_kv;
|
489 |
+
// if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}
|
490 |
+
|
491 |
+
if constexpr (Deterministic) {
|
492 |
+
Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
|
493 |
+
}
|
494 |
+
// if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore);}
|
495 |
+
if constexpr (Use_TMA) {
|
496 |
+
cutlass::arch::fence_view_async_shared();
|
497 |
+
cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
498 |
+
if (thread_idx == 0) {
|
499 |
+
SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdKaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
|
500 |
+
tma_store_arrive();
|
501 |
+
tma_store_wait<0>();
|
502 |
+
}
|
503 |
+
} else {
|
504 |
+
Tensor tdKrdK_atomic = r2g_thr_copy_dKVaccum.retile_S(tdKrdK);
|
505 |
+
Tensor tdKgdK_atomic = r2g_thr_copy_dKVaccum.partition_D(gdKaccum);
|
506 |
+
static_assert(CUTE_STATIC_V(size(tdKrdK_atomic)) == CUTE_STATIC_V(size(tdKgdK_atomic)));
|
507 |
+
#pragma unroll
|
508 |
+
for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdK_atomic(i), tdKrdK_atomic(i)); }
|
509 |
+
}
|
510 |
+
if constexpr (Deterministic) {
|
511 |
+
Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
|
512 |
+
}
|
513 |
+
// // Tell warp 0 that smem_k and smem_v are ready
|
514 |
+
// flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
|
515 |
+
}
|
516 |
+
|
517 |
+
CUTLASS_DEVICE void
|
518 |
+
store_tail() {
|
519 |
+
}
|
520 |
+
|
521 |
+
// Write 0 to dK and dV
|
522 |
+
CUTLASS_DEVICE void
|
523 |
+
store_zero(
|
524 |
+
Params const& params,
|
525 |
+
int thread_idx,
|
526 |
+
cute::tuple<int32_t, int32_t, int32_t> const& block_coord
|
527 |
+
) {
|
528 |
+
// Don't need to do anything since dKaccum and dVaccum are already zero-initialized
|
529 |
+
}
|
530 |
+
|
531 |
+
};
|
532 |
+
|
533 |
+
} // namespace flash
|
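In CollectiveEpilogueBwdGQA above, every query head of a GQA group accumulates its dK/dV contribution into one shared FP32 buffer, either via the TMA bulk reduce-add or via per-element `atomicAdd`; when `Deterministic` is set, a per-(batch, KV-head) semaphore (`GenericBarrier::wait_eq` / `arrive_inc`) serializes the group so the floating-point additions always happen in the same order. The following standalone sketch shows that ticket-ordering idea with a plain atomic counter; names are illustrative, it is not the CUTLASS API, and it assumes the group's blocks can all be resident at once, as with the small group sizes used here.

#include <cuda_runtime.h>

// Spin until the counter equals this member's ticket (mirrors GenericBarrier::wait_eq).
__device__ void wait_for_ticket(int* counter, int ticket) {
    while (atomicAdd(counter, 0) != ticket) { /* spin */ }
}

// Block b holds the contribution of query head b within its GQA group; blocks are serialized
// by the semaphore, so each accumulator element always receives its adds in the same order.
__global__ void accumulate_dkv_deterministic(float* dkv_accum, float const* contrib,
                                             int elems_per_head, int* semaphore) {
    int member = blockIdx.x;                            // position of this query head in the group
    if (threadIdx.x == 0) { wait_for_ticket(semaphore, member); }
    __syncthreads();
    float const* my_contrib = contrib + member * elems_per_head;
    for (int i = threadIdx.x; i < elems_per_head; i += blockDim.x) {
        atomicAdd(&dkv_accum[i], my_contrib[i]);
    }
    __syncthreads();
    __threadfence();                                    // publish the adds before releasing the next head
    if (threadIdx.x == 0) { atomicAdd(semaphore, 1); }  // mirrors GenericBarrier::arrive_inc
}

Launched as accumulate_dkv_deterministic<<<qheads_per_khead, 256>>>(...) with the semaphore zero-initialized, repeated runs produce bit-identical dK/dV, at the cost of serializing the heads within each group.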
flash-attn/epilogue_fwd.hpp
ADDED
@@ -0,0 +1,484 @@
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
#include <cutlass/cutlass.h>
|
8 |
+
#include <cutlass/fast_math.h> // For FastDivMod
|
9 |
+
#include "cute/tensor.hpp"
|
10 |
+
|
11 |
+
#include "cutlass/gemm/collective/builders/sm90_common.inl"
|
12 |
+
#include "cutlass/epilogue/collective/builders/sm90_common.inl"
|
13 |
+
|
14 |
+
#include "seqlen.h"
|
15 |
+
#include "named_barrier.hpp"
|
16 |
+
#include "pack_gqa.h"
|
17 |
+
#include "utils.h"
|
18 |
+
|
19 |
+
namespace flash {
|
20 |
+
|
21 |
+
using namespace cute;
|
22 |
+
|
23 |
+
template <class TileShape_MNK_PV_, class ClusterShape_, class Element_, class ArchTag_,
|
24 |
+
int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool Split_, bool FP8PermuteCol=false>
|
25 |
+
struct CollectiveEpilogueFwd {
|
26 |
+
|
27 |
+
using TileShape_MNK_PV = TileShape_MNK_PV_;
|
28 |
+
using ClusterShape = ClusterShape_;
|
29 |
+
using Element = Element_;
|
30 |
+
using ElementPartial = float;
|
31 |
+
using ArchTag = ArchTag_;
|
32 |
+
static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
|
33 |
+
static constexpr bool Varlen = Varlen_;
|
34 |
+
static constexpr bool PackGQA = PackGQA_;
|
35 |
+
static constexpr bool Split = Split_;
|
36 |
+
static constexpr bool Use_smem = !(Split && !Varlen);
|
37 |
+
static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA;
|
38 |
+
|
39 |
+
static_assert(ArchTag::kMinComputeCapability >= 80);
|
40 |
+
static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1);
|
41 |
+
static_assert(sizeof(Element) <= 2);
|
42 |
+
|
43 |
+
static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
|
44 |
+
static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{});
|
45 |
+
|
46 |
+
static constexpr bool LargeHeadDimV = kHeadDimV > 256;
|
47 |
+
|
48 |
+
using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
|
49 |
+
|
50 |
+
// These are for storing the output tensor without TMA (e.g., for setting output to zero)
|
51 |
+
static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element);
|
52 |
+
static_assert(kHeadDimV % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore");
|
53 |
+
// We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements
|
54 |
+
// in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times
|
55 |
+
// we need to call divmod.
|
56 |
+
static constexpr int kBytePerRow = kHeadDimV * sizeof(Element);
|
57 |
+
static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
|
58 |
+
static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore;
|
59 |
+
// If PackGQA, we split the work of computing O_ptr among threads in the same row, so this has to fit within a warp
|
60 |
+
static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0);
|
61 |
+
static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
|
62 |
+
using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
63 |
+
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
64 |
+
static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumEpilogueThreads / kGmemThreadsPerRow");
|
65 |
+
using GmemTiledCopyO = decltype(
|
66 |
+
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
67 |
+
GmemLayoutAtom{},
|
68 |
+
Layout<Shape<_1, Int<kGmemElemsPerStore>>>{})); // Val layout, 8 or 16 vals per store
|
69 |
+
|
70 |
+
using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
71 |
+
decltype(cute::get<0>(TileShape_MNK_PV{})), decltype(cute::get<1>(TileShape_MNK_PV{}))>());
|
72 |
+
using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 1>(TileShape_MNK_PV{})));
|
73 |
+
static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1));
|
74 |
+
static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4);
|
75 |
+
using SmemLayoutAtomO = decltype(
|
76 |
+
composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},
|
77 |
+
Layout<Shape<_8, Int<kBlockKGmem>>,
|
78 |
+
Stride<Int<kBlockKGmem>, _1>>{}));
|
79 |
+
using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_MNK_PV{})));
|
80 |
+
using SmemLayoutO = std::conditional_t<ArchTag::kMinComputeCapability >= 90, SmemLayoutOTMA, SmemLayoutOSTS>;
|
81 |
+
|
82 |
+
using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch, num_splits)
|
83 |
+
using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
|
84 |
+
using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen_q, head, batch, num_splits)
|
85 |
+
// ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits)
|
86 |
+
using ShapeOPacked = std::conditional_t<!PackGQA, ShapeO, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t, int32_t>>;
|
87 |
+
using StrideOPacked = std::conditional_t<!PackGQA, StrideO, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t, int64_t>>;
|
88 |
+
// ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits)
|
89 |
+
using ShapeLSEPacked = std::conditional_t<!PackGQA, cute::Shape<int32_t, int32_t, int32_t, int32_t>, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;
|
90 |
+
using StrideLSEPacked = std::conditional_t<!PackGQA, StrideLSE, cute::Stride<cute::Stride<int64_t, _1>, int64_t, int64_t, int64_t>>;
|
91 |
+
|
92 |
+
using CopyOpR2S = std::conditional_t<
|
93 |
+
ArchTag::kMinComputeCapability >= 90,
|
94 |
+
// cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16)
|
95 |
+
decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<StrideO, Element>()),
|
96 |
+
AutoVectorizingCopyWithAssumedAlignment<128>
|
97 |
+
>;
|
98 |
+
using SmemCopyAtomO = Copy_Atom<CopyOpR2S, Element>;
|
99 |
+
|
100 |
+
// static constexpr size_t SmemAlignmentO = cutlass::detail::alignment_for_swizzle(SmemLayoutO{});
|
101 |
+
// static_assert(SmemAlignmentO >= 128, "Require at least 128B alignment");
|
102 |
+
// struct TensorStorage : cute::aligned_struct<SmemAlignmentO> {
|
103 |
+
// cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0, SmemAlignmentO> smem_o;
|
104 |
+
// };
|
105 |
+
struct TensorStorage : cute::aligned_struct<128> {
|
106 |
+
cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0> smem_o;
|
107 |
+
};
|
108 |
+
|
109 |
+
using TMA_O = std::conditional_t<
|
110 |
+
Use_TMA_O,
|
111 |
+
decltype(make_tma_copy(
|
112 |
+
GmemTiledCopyOTMA{},
|
113 |
+
make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeO{}, StrideO{}),
|
114 |
+
SmemLayoutOTMA{},
|
115 |
+
select<0, 1>(TileShape_MNK_PV{}),
|
116 |
+
_1{})), // no mcast for O
|
117 |
+
std::nullptr_t
|
118 |
+
>;
|
119 |
+
|
120 |
+
// Host side kernel arguments
|
121 |
+
struct Arguments {
|
122 |
+
Element* ptr_O;
|
123 |
+
ShapeO const shape_O;
|
124 |
+
StrideO const stride_O;
|
125 |
+
ElementPartial* ptr_O_partial;
|
126 |
+
StrideO const stride_O_partial;
|
127 |
+
float* ptr_LSE;
|
128 |
+
StrideLSE const stride_LSE;
|
129 |
+
float* ptr_LSE_partial;
|
130 |
+
StrideLSE const stride_LSE_partial;
|
131 |
+
int32_t const nheads_kv;
|
132 |
+
int const* cu_seqlens = nullptr;
|
133 |
+
int const* seqused = nullptr;
|
134 |
+
};
|
135 |
+
|
136 |
+
// Device side kernel params
|
137 |
+
struct Params {
|
138 |
+
Element* ptr_O;
|
139 |
+
ShapeO const shape_O;
|
140 |
+
StrideO const stride_O;
|
141 |
+
ShapeOPacked const shape_O_packed;
|
142 |
+
StrideOPacked const stride_O_packed;
|
143 |
+
ElementPartial* ptr_O_partial;
|
144 |
+
StrideO const stride_O_partial;
|
145 |
+
StrideOPacked const stride_O_partial_packed;
|
146 |
+
float* ptr_LSE;
|
147 |
+
StrideLSE const stride_LSE;
|
148 |
+
ShapeLSEPacked const shape_LSE_packed;
|
149 |
+
StrideLSEPacked const stride_LSE_packed;
|
150 |
+
float* ptr_LSE_partial;
|
151 |
+
StrideLSE const stride_LSE_partial;
|
152 |
+
StrideLSEPacked const stride_LSE_partial_packed;
|
153 |
+
cutlass::FastDivmod qhead_per_khead_divmod;
|
154 |
+
TMA_O tma_store_O;
|
155 |
+
int const* cu_seqlens = nullptr;
|
156 |
+
int const* seqused = nullptr;
|
157 |
+
};
|
158 |
+
|
159 |
+
static Params
|
160 |
+
to_underlying_arguments(Arguments const& args) {
|
161 |
+
Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
|
162 |
+
TMA_O tma_store_O = [&]{
|
163 |
+
if constexpr (Use_TMA_O) {
|
164 |
+
return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast
|
165 |
+
} else {
|
166 |
+
return nullptr;
|
167 |
+
}
|
168 |
+
}();
|
169 |
+
// If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits)
|
170 |
+
int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv);
|
171 |
+
auto const shape_O_packed = cute::conditional_return<!PackGQA>(
|
172 |
+
args.shape_O,
|
173 |
+
make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
|
174 |
+
);
|
175 |
+
auto const stride_O_packed = cute::conditional_return<!PackGQA>(
|
176 |
+
args.stride_O,
|
177 |
+
make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O))
|
178 |
+
);
|
179 |
+
auto const stride_O_partial_packed = cute::conditional_return<!PackGQA>(
|
180 |
+
args.stride_O_partial,
|
181 |
+
make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial))
|
182 |
+
);
|
183 |
+
// If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits)
|
184 |
+
auto const shape_LSE_packed = cute::conditional_return<!PackGQA>(
|
185 |
+
select<0, 2, 3, 4>(args.shape_O),
|
186 |
+
make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
|
187 |
+
);
|
188 |
+
auto const stride_LSE_packed = cute::conditional_return<!PackGQA>(
|
189 |
+
args.stride_LSE,
|
190 |
+
make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE))
|
191 |
+
);
|
192 |
+
auto const stride_LSE_partial_packed = cute::conditional_return<!PackGQA>(
|
193 |
+
args.stride_LSE_partial,
|
194 |
+
make_stride(make_stride(get<1>(args.stride_LSE_partial), get<0>(args.stride_LSE_partial)), get<1>(args.stride_LSE_partial) * qhead_per_khead, get<2>(args.stride_LSE_partial), get<3>(args.stride_LSE_partial))
|
195 |
+
);
|
196 |
+
return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed,
|
197 |
+
args.ptr_O_partial, args.stride_O_partial, stride_O_partial_packed,
|
198 |
+
args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed,
|
199 |
+
args.ptr_LSE_partial, args.stride_LSE_partial, stride_LSE_partial_packed,
|
200 |
+
cutlass::FastDivmod(qhead_per_khead),
|
201 |
+
tma_store_O, args.cu_seqlens, args.seqused};
|
202 |
+
}
|
203 |
+
|
204 |
+
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
205 |
+
CUTLASS_DEVICE
|
206 |
+
static void prefetch_tma_descriptors(Params const& params) {
|
207 |
+
if constexpr (Use_TMA_O) {
|
208 |
+
cute::prefetch_tma_descriptor(params.tma_store_O.get_tma_descriptor());
|
209 |
+
}
|
210 |
+
}
|
211 |
+
|
212 |
+
template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
|
213 |
+
CUTLASS_DEVICE void
|
214 |
+
store(Params const& params,
|
215 |
+
FrgTensorO& tOrO,
|
216 |
+
FrgTensorLSE const& lse,
|
217 |
+
SharedStorage& shared_storage,
|
218 |
+
TiledMma tiled_mma,
|
219 |
+
int thread_idx,
|
220 |
+
cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
|
221 |
+
) {
|
222 |
+
|
223 |
+
auto [m_block, bidh, bidb, split_idx] = block_coord;
|
224 |
+
int num_splits = get<4>(params.shape_O_packed);
|
225 |
+
if constexpr (Split && Varlen) {
|
226 |
+
uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
|
227 |
+
int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
|
228 |
+
num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
|
229 |
+
split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx
|
230 |
+
}
|
231 |
+
bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);
|
232 |
+
|
233 |
+
Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{});
|
234 |
+
// Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO);
|
235 |
+
|
236 |
+
static constexpr bool NeedFP8Permute = FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4);
|
237 |
+
// If we will possibly need tOrO in FP32, we'd want to permute tOrO before type conversion.
|
238 |
+
// Otherwise we can permute after conversion.
|
239 |
+
if constexpr (NeedFP8Permute && Split) { flash::permute_output_fp8_Vcolmajor(tOrO); }
|
240 |
+
Tensor tOrO_out = make_tensor_like<Element>(tOrO);
|
241 |
+
flash::convert_type_out(tOrO, tOrO_out);
|
242 |
+
if constexpr (NeedFP8Permute && !Split) { flash::permute_output_fp8_Vcolmajor(tOrO_out); }
|
243 |
+
|
244 |
+
// Make sure all WGs have finished reading V
|
245 |
+
// Technically we don't need this if we're not using smem, but the mainloop makes the assumption that
|
246 |
+
// all epilogue threads sync at least once during the epilogue (so that we can start loading Q with
|
247 |
+
// cp.async if we need).
|
248 |
+
flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
249 |
+
|
250 |
+
// Step 1: Write O from rmem -> smem
|
251 |
+
if constexpr (Use_smem) {
|
252 |
+
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
|
253 |
+
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
|
254 |
+
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N)
|
255 |
+
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
256 |
+
// Tensor taccOsO = smem_thr_copy_O.partition_D(sO_pi); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
257 |
+
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
|
258 |
+
if constexpr (Use_TMA_O) {
|
259 |
+
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
|
260 |
+
cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
|
261 |
+
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
262 |
+
} else {
|
263 |
+
flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
264 |
+
}
|
265 |
+
} else {
|
266 |
+
if constexpr (ArchTag::kMinComputeCapability >= 90) {
|
267 |
+
#pragma unroll
|
268 |
+
for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
|
269 |
+
shared_storage.pipelines.barrier_O.arrive(cta_id);
|
270 |
+
}
|
271 |
+
}
|
272 |
+
}
|
273 |
+
|
274 |
+
flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
|
275 |
+
bool is_varlen = Varlen && params.cu_seqlens;
|
276 |
+
int offset_o = seqlen_info.offset;
|
277 |
+
int seqlen_o = seqlen_info.seqlen;
|
278 |
+
int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);
|
279 |
+
|
280 |
+
// Step 2: Write LSE from rmem -> gmem
|
281 |
+
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
|
282 |
+
// (MMA,MMA_M,MMA_K)
|
283 |
+
Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
|
284 |
+
static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
|
285 |
+
static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
|
286 |
+
Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()));
|
287 |
+
Tensor taccOcO_row = taccOcO_rowcol(_, _0{});
|
288 |
+
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
|
289 |
+
|
290 |
+
using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;
|
291 |
+
using PackGQApartial_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, ElementPartial>;
|
292 |
+
|
293 |
+
Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
|
294 |
+
params.shape_LSE_packed,
|
295 |
+
!is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);
|
296 |
+
// if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); }
|
297 |
+
if (!LargeHeadDimV || warp_group_idx == 0) {
|
298 |
+
if constexpr (!PackGQA) {
|
299 |
+
#pragma unroll
|
300 |
+
for (int mi = 0; mi < size(lse); ++mi) {
|
301 |
+
int const row = m_block * kBlockM + get<0>(taccOcO_row(mi));
|
302 |
+
if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); }
|
303 |
+
}
|
304 |
+
} else {
|
305 |
+
PackGQA_t::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
|
306 |
+
}
|
307 |
+
}
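// [Illustrative note, not part of the original file] lse(mi) holds the per-row log-sum-exp of
// the scaled attention scores for this tile's rows; it is written alongside O because both the
// backward pass and the split-combine kernel need it to re-normalize partial results.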
|
308 |
+
|
309 |
+
// Step 3: Write O from smem -> gmem
|
310 |
+
if constexpr (Use_TMA_O) {
|
311 |
+
Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx);
|
312 |
+
Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
|
313 |
+
auto block_tma_O = params.tma_store_O.get_slice(_0{});
|
314 |
+
Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K)
|
315 |
+
Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K)
|
316 |
+
int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
|
317 |
+
if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
|
318 |
+
cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
|
319 |
+
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
320 |
+
if (cute::elect_one_sync()) {
|
321 |
+
cute::copy(params.tma_store_O, tOsO, tOgO);
|
322 |
+
tma_store_arrive();
|
323 |
+
tma_store_wait<0>();
|
324 |
+
#pragma unroll
|
325 |
+
for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
|
326 |
+
shared_storage.pipelines.barrier_O.arrive(cta_id);
|
327 |
+
}
|
328 |
+
}
|
329 |
+
}
|
330 |
+
} else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence
|
331 |
+
if (!is_split) {
|
332 |
+
Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{});
|
333 |
+
Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
|
334 |
+
// if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast<int>(&mO(0)) - reinterpret_cast<int>(params.ptr_O)); }
|
335 |
+
GmemTiledCopyO gmem_tiled_copy_O;
|
336 |
+
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
|
337 |
+
Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
338 |
+
// Tensor tOsO = gmem_thr_copy_O.partition_S(sO_pi); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
339 |
+
Tensor tOrO = make_fragment_like(tOsO);
|
340 |
+
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
|
341 |
+
if constexpr (ArchTag::kMinComputeCapability >= 90) {
|
342 |
+
cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_v
|
343 |
+
#pragma unroll
|
344 |
+
for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
|
345 |
+
shared_storage.pipelines.barrier_O.arrive(cta_id);
|
346 |
+
}
|
347 |
+
}
|
348 |
+
if constexpr (!PackGQA) {
|
349 |
+
// (BLK_M,BLK_K) -> (blk_m,blk_k)
|
350 |
+
Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
|
351 |
+
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOsO)));
|
352 |
+
#pragma unroll
|
353 |
+
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
|
354 |
+
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
355 |
+
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
356 |
+
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
357 |
+
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
|
358 |
+
);
|
359 |
+
} else {
|
360 |
+
// If PackGQA, we split the work of computing O_ptr among threads in the same row
|
361 |
+
PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
|
362 |
+
}
|
363 |
+
} else {
|
364 |
+
Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx);
|
365 |
+
Tensor gOpartial = local_tile(mOpartial, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
|
366 |
+
// We already arrived on barrier_O earlier if !Use_smem
|
367 |
+
if constexpr (Use_smem) {
|
368 |
+
if constexpr (ArchTag::kMinComputeCapability >= 90) {
|
369 |
+
#pragma unroll
|
370 |
+
for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
|
371 |
+
shared_storage.pipelines.barrier_O.arrive(cta_id);
|
372 |
+
}
|
373 |
+
}
|
374 |
+
}
|
375 |
+
if constexpr (!PackGQA) {
|
376 |
+
static constexpr int kGmemElemsPerStoreDirect = 2;
|
377 |
+
cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial> gmem_copy_direct;
|
378 |
+
// Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
|
379 |
+
Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout()));
|
380 |
+
Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
|
381 |
+
Tensor tOgO = thread_mma.partition_C(gOpartial);
|
382 |
+
Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout()));
|
383 |
+
Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
|
384 |
+
Tensor taccOcO_col = taccOcO_rowcol(_0{}, _);
|
385 |
+
#pragma unroll
|
386 |
+
for (int m = 0; m < size(taccOcO_row); ++m) {
|
387 |
+
if (get<0>(taccOcO_row(m)) < seqlen_o - m_block * kBlockM) {
|
388 |
+
#pragma unroll
|
389 |
+
for (int k = 0; k < size(taccOcO_col) / kGmemElemsPerStoreDirect; ++k) {
|
390 |
+
if (get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)) < get<1>(params.shape_O)) {
|
391 |
+
cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), tOgO_copy(_, m, k));
|
392 |
+
}
|
393 |
+
}
|
394 |
+
}
|
395 |
+
}
|
396 |
+
} else {
|
397 |
+
PackGQApartial_t::store_O_direct(mOpartial, tOrO, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
|
398 |
+
}
|
399 |
+
}
|
400 |
+
}
|
401 |
+
}
|
402 |
+
|
403 |
+
CUTLASS_DEVICE void
|
404 |
+
store_tail() {
|
405 |
+
// Don't need to do tma_store_wait<0>() here since we already did in @store
|
406 |
+
}
|
407 |
+
|
408 |
+
// Write 0 to output and -inf to LSE
|
409 |
+
CUTLASS_DEVICE void
|
410 |
+
store_zero(
|
411 |
+
Params const& params,
|
412 |
+
int thread_idx,
|
413 |
+
cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
|
414 |
+
) {
|
415 |
+
static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
|
416 |
+
auto [m_block, bidh, bidb, split_idx] = block_coord;
|
417 |
+
int num_splits = get<4>(params.shape_O_packed);
|
418 |
+
if constexpr (Split && Varlen) {
|
419 |
+
uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // the upper 16 bits of split_idx encode num_splits
|
420 |
+
int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
|
421 |
+
num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
|
422 |
+
split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx
|
423 |
+
}
|
424 |
+
bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);
|
425 |
+
|
426 |
+
flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
|
427 |
+
bool const is_varlen = Varlen && params.cu_seqlens;
|
428 |
+
int offset_o = seqlen_info.offset;
|
429 |
+
int seqlen_o = seqlen_info.seqlen;
|
430 |
+
int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor;
|
431 |
+
Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
|
432 |
+
params.shape_LSE_packed,
|
433 |
+
!is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);
|
434 |
+
Tensor gLSE = local_tile(mLSE, Shape<Int<kBlockM>>{}, make_coord(m_block));
|
435 |
+
|
436 |
+
static_assert(kBlockM <= NumEpilogueThreads);
|
437 |
+
if (thread_idx < kBlockM) {
|
438 |
+
const int row = m_block * kBlockM + thread_idx;
|
439 |
+
if constexpr (!PackGQA) {
|
440 |
+
if (row < seqlen_o) { mLSE(row) = -INFINITY; }
|
441 |
+
} else {
|
442 |
+
if (row < seqlen_o * qhead_per_khead) {
|
443 |
+
int m_idx, h_idx;
|
444 |
+
m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row);
|
445 |
+
// mLSE has shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord"
|
446 |
+
mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY;
|
447 |
+
}
|
448 |
+
}
|
449 |
+
}
|
450 |
+
|
451 |
+
// If split, we don't have to write 0 to mOpartial if the mha_combine kernel is used,
|
452 |
+
// since it will not use the value of O if LSE is -inf.
|
453 |
+
if (!is_split) {
|
454 |
+
Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{});
|
455 |
+
|
456 |
+
GmemTiledCopyO gmem_tiled_copy_O;
|
457 |
+
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
|
458 |
+
Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
|
459 |
+
if constexpr (!PackGQA) {
|
460 |
+
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
|
461 |
+
#pragma unroll
|
462 |
+
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
|
463 |
+
Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
|
464 |
+
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
465 |
+
Tensor tOrO = make_fragment_like(tOgO);
|
466 |
+
cute::clear(tOrO);
|
467 |
+
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
468 |
+
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
469 |
+
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
|
470 |
+
);
|
471 |
+
} else {
|
472 |
+
// If PackGQA, we split the work of computing O_ptr among threads in the same row
|
473 |
+
using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;
|
474 |
+
Tensor tOrO = make_tensor<Element>(make_shape(Shape<_1, Int<kGmemElemsPerStore>>{}, size<1>(tOcO), size<2>(tOcO)));
|
475 |
+
cute::clear(tOrO);
|
476 |
+
PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
|
477 |
+
}
|
478 |
+
}
|
479 |
+
|
480 |
+
}
|
481 |
+
|
482 |
+
};
|
483 |
+
|
484 |
+
} // namespace flash
|
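The epilogue above writes per-split partial results (ptr_O_partial, ptr_LSE_partial) and leaves the merge to a separate combine kernel; store_zero writes -inf to LSE precisely so that empty splits drop out of that merge. Below is a minimal C++ sketch of the reduction the combine step has to perform, kept deliberately simple (plain vectors, one row and one head at a time); the function name and signature are illustrative, not the actual kernel interface.

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Merge per-split partial outputs using their log-sum-exp values:
    //   LSE = log(sum_i exp(LSE_i)),  O = sum_i exp(LSE_i - LSE) * O_i.
    // A split whose LSE is -inf gets weight exp(-inf) = 0, so its O values are never read,
    // which is why store_zero only needs to write -inf to LSE in the split case.
    void combine_splits(const std::vector<std::vector<float>>& o_partial,  // [num_splits][head_dim_v]
                        const std::vector<float>& lse_partial,             // [num_splits]
                        std::vector<float>& o_out, float& lse_out) {
        const int num_splits = static_cast<int>(lse_partial.size());
        const int head_dim_v = static_cast<int>(o_partial[0].size());
        float lse_max = -INFINITY;
        for (float lse : lse_partial) { lse_max = std::max(lse_max, lse); }
        if (lse_max == -INFINITY) {  // every split was empty
            lse_out = -INFINITY;
            o_out.assign(head_dim_v, 0.f);
            return;
        }
        float sum_exp = 0.f;
        for (float lse : lse_partial) { sum_exp += std::exp(lse - lse_max); }
        lse_out = lse_max + std::log(sum_exp);
        o_out.assign(head_dim_v, 0.f);
        for (int s = 0; s < num_splits; ++s) {
            float w = std::exp(lse_partial[s] - lse_out);  // 0 for -inf splits
            for (int d = 0; d < head_dim_v; ++d) { o_out[d] += w * o_partial[s][d]; }
        }
    }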
flash-attn/flash.h
ADDED
@@ -0,0 +1,218 @@
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
#include <cuda.h>
|
8 |
+
#include <vector>
|
9 |
+
|
10 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
11 |
+
|
12 |
+
struct Qkv_params {
|
13 |
+
using index_t = int64_t;
|
14 |
+
// The QKV matrices.
|
15 |
+
void *__restrict__ q_ptr;
|
16 |
+
void *__restrict__ k_ptr;
|
17 |
+
void *__restrict__ v_ptr;
|
18 |
+
|
19 |
+
// The stride between rows of the Q, K and V matrices.
|
20 |
+
index_t q_batch_stride;
|
21 |
+
index_t k_batch_stride;
|
22 |
+
index_t v_batch_stride;
|
23 |
+
index_t q_row_stride;
|
24 |
+
index_t k_row_stride;
|
25 |
+
index_t v_row_stride;
|
26 |
+
index_t q_head_stride;
|
27 |
+
index_t k_head_stride;
|
28 |
+
index_t v_head_stride;
|
29 |
+
index_t v_dim_stride;
|
30 |
+
|
31 |
+
// The number of heads.
|
32 |
+
int h, h_k;
|
33 |
+
};
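// [Illustrative note, not part of the original file] h is the number of query heads and h_k the
// number of key/value heads; grouped-query attention requires h % h_k == 0 (e.g. h = 32,
// h_k = 8 gives qhead_per_khead = 4), which flash_api.cpp checks before launching.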
|
34 |
+
|
35 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
36 |
+
|
37 |
+
struct Flash_fwd_params : public Qkv_params {
|
38 |
+
using index_t = int64_t;
|
39 |
+
|
40 |
+
// The O matrix (output).
|
41 |
+
void * __restrict__ o_ptr;
|
42 |
+
void * __restrict__ oaccum_ptr;
|
43 |
+
|
44 |
+
// The stride between rows of O.
|
45 |
+
index_t o_batch_stride;
|
46 |
+
index_t o_row_stride;
|
47 |
+
index_t o_head_stride;
|
48 |
+
|
49 |
+
// The pointer to the softmax sum.
|
50 |
+
void * __restrict__ softmax_lse_ptr;
|
51 |
+
void * __restrict__ softmax_lseaccum_ptr;
|
52 |
+
|
53 |
+
// For FP8 scaling
|
54 |
+
float * __restrict__ q_descale_ptr;
|
55 |
+
float * __restrict__ k_descale_ptr;
|
56 |
+
float * __restrict__ v_descale_ptr;
|
57 |
+
index_t q_descale_batch_stride;
|
58 |
+
index_t q_descale_head_stride;
|
59 |
+
index_t k_descale_batch_stride;
|
60 |
+
index_t k_descale_head_stride;
|
61 |
+
index_t v_descale_batch_stride;
|
62 |
+
index_t v_descale_head_stride;
|
63 |
+
|
64 |
+
// The dimensions.
|
65 |
+
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
|
66 |
+
int total_q, total_k, total_knew;
|
67 |
+
int b_k; // When using a KV cache with cache_batch_idx, K & V might have a larger batch size than Q
|
68 |
+
int dv, dv_rounded; // For the case where V headdim is different from Q/K headdim
|
69 |
+
|
70 |
+
// The scaling factors for the kernel.
|
71 |
+
float scale_softmax;
|
72 |
+
float softcap;
|
73 |
+
|
74 |
+
// Array of length b+1 holding the starting offset of each sequence.
|
75 |
+
int * __restrict__ cu_seqlens_q;
|
76 |
+
int * __restrict__ cu_seqlens_k;
|
77 |
+
int * __restrict__ cu_seqlens_knew;
|
78 |
+
int * __restrict__ leftpad_k;
|
79 |
+
|
80 |
+
// If provided, the actual length of each q/k sequence.
|
81 |
+
int *__restrict__ seqused_q;
|
82 |
+
int *__restrict__ seqused_k;
|
83 |
+
|
84 |
+
// The stride between rows of Oaccum.
|
85 |
+
index_t oaccum_split_stride;
|
86 |
+
index_t oaccum_batch_stride;
|
87 |
+
index_t oaccum_row_stride;
|
88 |
+
index_t oaccum_head_stride;
|
89 |
+
|
90 |
+
// The stride between rows of LSEaccum.
|
91 |
+
index_t lseaccum_split_stride;
|
92 |
+
index_t lseaccum_batch_stride;
|
93 |
+
index_t lseaccum_head_stride;
|
94 |
+
|
95 |
+
// The K_new and V_new matrices.
|
96 |
+
void * __restrict__ knew_ptr;
|
97 |
+
void * __restrict__ vnew_ptr;
|
98 |
+
|
99 |
+
// The stride between rows of the Q, K and V matrices.
|
100 |
+
index_t knew_batch_stride;
|
101 |
+
index_t vnew_batch_stride;
|
102 |
+
index_t knew_row_stride;
|
103 |
+
index_t vnew_row_stride;
|
104 |
+
index_t knew_head_stride;
|
105 |
+
index_t vnew_head_stride;
|
106 |
+
|
107 |
+
void *__restrict__ qv_ptr;
|
108 |
+
index_t qv_batch_stride;
|
109 |
+
index_t qv_row_stride;
|
110 |
+
index_t qv_head_stride;
|
111 |
+
|
112 |
+
// The cos and sin matrices for rotary embedding.
|
113 |
+
void * __restrict__ rotary_cos_ptr;
|
114 |
+
void * __restrict__ rotary_sin_ptr;
|
115 |
+
int *__restrict__ seqlens_rotary;
|
116 |
+
|
117 |
+
// The indices to index into the KV cache.
|
118 |
+
int * __restrict__ kv_batch_idx;
|
119 |
+
|
120 |
+
// Paged KV cache
|
121 |
+
int * __restrict__ page_table;
|
122 |
+
index_t page_table_batch_stride;
|
123 |
+
int page_size;
|
124 |
+
int num_pages;
|
125 |
+
bool pagedkv_tma;
|
126 |
+
|
127 |
+
// The dropout probability (probability of keeping an activation).
|
128 |
+
float p_dropout;
|
129 |
+
// uint32_t p_dropout_in_uint;
|
130 |
+
// uint16_t p_dropout_in_uint16_t;
|
131 |
+
uint8_t p_dropout_in_uint8_t;
|
132 |
+
|
133 |
+
// Scale factor of 1 / (1 - p_dropout).
|
134 |
+
float rp_dropout;
|
135 |
+
|
136 |
+
// Local window size
|
137 |
+
int window_size_left, window_size_right;
|
138 |
+
int attention_chunk;
|
139 |
+
|
140 |
+
// Pointer to the RNG seed (idx 0) and offset (idx 1).
|
141 |
+
uint64_t * rng_state;
|
142 |
+
|
143 |
+
bool is_bf16;
|
144 |
+
bool is_fp32;
|
145 |
+
bool is_e4m3;
|
146 |
+
bool is_causal;
|
147 |
+
bool is_local;
|
148 |
+
|
149 |
+
bool is_rotary_interleaved;
|
150 |
+
|
151 |
+
int num_splits; // For split-KV version
|
152 |
+
bool pack_gqa;
|
153 |
+
|
154 |
+
int * __restrict__ tile_count_semaphore;
|
155 |
+
// int * __restrict__ num_m_blocks_ptr;
|
156 |
+
// int * __restrict__ num_n_blocks_ptr;
|
157 |
+
int * __restrict__ num_splits_dynamic_ptr;
|
158 |
+
bool skip_scheduler_metadata_computation;
|
159 |
+
|
160 |
+
int arch;
|
161 |
+
int num_sm;
|
162 |
+
};
|
163 |
+
|
164 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
165 |
+
|
166 |
+
struct Flash_bwd_params : public Flash_fwd_params {
|
167 |
+
using index_t = int64_t;
|
168 |
+
|
169 |
+
// The dO and dQKV matrices.
|
170 |
+
void *__restrict__ do_ptr;
|
171 |
+
void *__restrict__ dq_ptr;
|
172 |
+
void *__restrict__ dk_ptr;
|
173 |
+
void *__restrict__ dv_ptr;
|
174 |
+
|
175 |
+
// To accumulate dQ
|
176 |
+
void *__restrict__ dq_accum_ptr;
|
177 |
+
void *__restrict__ dk_accum_ptr;
|
178 |
+
void *__restrict__ dv_accum_ptr;
|
179 |
+
|
180 |
+
// // To accumulate dK and dV in case we're splitting the bwd along the seqlen_q dimension:
|
181 |
+
// // void *__restrict__ dk_accum_ptr;
|
182 |
+
// // void *__restrict__ dv_accum_ptr;
|
183 |
+
|
184 |
+
// The stride between rows of the dO, dQ, dK and dV matrices.
|
185 |
+
index_t do_batch_stride;
|
186 |
+
index_t do_row_stride;
|
187 |
+
index_t do_head_stride;
|
188 |
+
index_t dq_batch_stride;
|
189 |
+
index_t dk_batch_stride;
|
190 |
+
index_t dv_batch_stride;
|
191 |
+
index_t dq_row_stride;
|
192 |
+
index_t dk_row_stride;
|
193 |
+
index_t dv_row_stride;
|
194 |
+
index_t dq_head_stride;
|
195 |
+
index_t dk_head_stride;
|
196 |
+
index_t dv_head_stride;
|
197 |
+
|
198 |
+
// The pointer to the softmax d sum.
|
199 |
+
void *__restrict__ dsoftmax_sum;
|
200 |
+
void *__restrict__ softmax_lse_log2_ptr;
|
201 |
+
|
202 |
+
int *__restrict__ dq_semaphore;
|
203 |
+
int *__restrict__ dk_semaphore;
|
204 |
+
int *__restrict__ dv_semaphore;
|
205 |
+
|
206 |
+
bool deterministic;
|
207 |
+
index_t dq_accum_split_stride;
|
208 |
+
};
|
209 |
+
|
210 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
211 |
+
|
212 |
+
template <int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKVNonTMA, bool Has_softcap, bool PackGQA>
|
213 |
+
void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream);
|
214 |
+
void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl);
|
215 |
+
template <int Arch, typename T, int kHeadDim, bool Has_softcap>
|
216 |
+
void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream);
|
217 |
+
template <typename T, typename Tpartial, int kBlockK>
|
218 |
+
void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl);
|
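All strides in the structs above are element counts, not bytes; flash_api.cpp fills them from the tensors' actual strides. The following is a small sketch of how those fields relate to a contiguous [batch, seqlen, heads, head_dim] layout; the struct and function names here are made up for illustration and are not part of the API.

    #include <cstdint>

    // Element-unit strides for a contiguous tensor laid out as [b, s, h, d].
    struct PackedLayout {
        int64_t batch_stride, row_stride, head_stride;
    };

    inline PackedLayout contiguous_qkv_strides(int64_t seqlen, int64_t num_heads, int64_t head_dim) {
        return {
            seqlen * num_heads * head_dim,  // q_batch_stride: elements between batches
            num_heads * head_dim,           // q_row_stride: elements between sequence positions
            head_dim                        // q_head_stride: elements between heads
        };
    }

    // Offset (in elements) of Q[b][s][h][d] under this layout:
    inline int64_t q_offset(const PackedLayout& L, int64_t b, int64_t s, int64_t h, int64_t d) {
        return b * L.batch_stride + s * L.row_stride + h * L.head_stride + d;
    }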
flash-attn/flash_api.cpp
ADDED
@@ -0,0 +1,1720 @@
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#include <Python.h>
|
6 |
+
#include <torch/nn/functional/padding.h>
|
7 |
+
#include <ATen/cuda/CUDAContextLight.h>
|
8 |
+
#include <c10/cuda/CUDAGuard.h>
|
9 |
+
|
10 |
+
#include <cutlass/numeric_types.h>
|
11 |
+
|
12 |
+
#include "flash.h"
|
13 |
+
#include "static_switch.h"
|
14 |
+
#include "tile_size.h"
|
15 |
+
#include "heuristics.h"
|
16 |
+
#include "cuda_check.h"
|
17 |
+
|
18 |
+
|
19 |
+
extern "C" {
|
20 |
+
/* Creates a dummy empty _C module that can be imported from Python.
|
21 |
+
The import from Python will load the .so built from this file
|
22 |
+
in this extension, so that the TORCH_LIBRARY static initializers
|
23 |
+
below are run. */
|
24 |
+
PyObject* PyInit__C(void)
|
25 |
+
{
|
26 |
+
static struct PyModuleDef module_def = {
|
27 |
+
PyModuleDef_HEAD_INIT,
|
28 |
+
"_C", /* name of module */
|
29 |
+
NULL, /* module documentation, may be NULL */
|
30 |
+
-1, /* size of per-interpreter state of the module,
|
31 |
+
or -1 if the module keeps state in global variables. */
|
32 |
+
NULL, /* methods */
|
33 |
+
};
|
34 |
+
return PyModule_Create(&module_def);
|
35 |
+
}
|
36 |
+
}
|
37 |
+
|
38 |
+
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
|
39 |
+
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
40 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
41 |
+
|
42 |
+
void set_params_fprop(Flash_fwd_params ¶ms,
|
43 |
+
// sizes
|
44 |
+
const size_t b,
|
45 |
+
const size_t seqlen_q,
|
46 |
+
const size_t seqlen_k,
|
47 |
+
const size_t seqlen_q_rounded,
|
48 |
+
const size_t seqlen_k_rounded,
|
49 |
+
const size_t h,
|
50 |
+
const size_t h_k,
|
51 |
+
const size_t d,
|
52 |
+
const size_t d_rounded,
|
53 |
+
// device pointers
|
54 |
+
const at::Tensor q,
|
55 |
+
const at::Tensor k,
|
56 |
+
const at::Tensor v,
|
57 |
+
at::Tensor out,
|
58 |
+
void *cu_seqlens_q_d,
|
59 |
+
void *cu_seqlens_k_d,
|
60 |
+
void *seqused_q,
|
61 |
+
void *seqused_k,
|
62 |
+
void *softmax_lse_d,
|
63 |
+
float p_dropout,
|
64 |
+
float softmax_scale,
|
65 |
+
int window_size_left,
|
66 |
+
int window_size_right,
|
67 |
+
int attention_chunk,
|
68 |
+
const float softcap=0.f,
|
69 |
+
const int sm_margin=0) {
|
70 |
+
|
71 |
+
// Reset the parameters
|
72 |
+
params = {};
|
73 |
+
|
74 |
+
params.is_bf16 = q.dtype() == torch::kBFloat16;
|
75 |
+
params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;
|
76 |
+
|
77 |
+
// Set the pointers and strides.
|
78 |
+
params.q_ptr = q.data_ptr();
|
79 |
+
params.k_ptr = k.data_ptr();
|
80 |
+
params.v_ptr = v.data_ptr();
|
81 |
+
// All strides are in elements, not bytes.
|
82 |
+
params.q_row_stride = q.stride(-3);
|
83 |
+
params.k_row_stride = k.stride(-3);
|
84 |
+
params.v_row_stride = v.stride(-3);
|
85 |
+
params.q_head_stride = q.stride(-2);
|
86 |
+
params.k_head_stride = k.stride(-2);
|
87 |
+
params.v_head_stride = v.stride(-2);
|
88 |
+
params.v_dim_stride = v.stride(-1);
|
89 |
+
params.o_ptr = out.data_ptr();
|
90 |
+
params.o_row_stride = out.stride(-3);
|
91 |
+
params.o_head_stride = out.stride(-2);
|
92 |
+
|
93 |
+
if (cu_seqlens_q_d == nullptr) {
|
94 |
+
params.q_batch_stride = q.stride(0);
|
95 |
+
params.o_batch_stride = out.stride(0);
|
96 |
+
}
|
97 |
+
if (cu_seqlens_k_d == nullptr) {
|
98 |
+
params.k_batch_stride = k.stride(0);
|
99 |
+
params.v_batch_stride = v.stride(0);
|
100 |
+
}
|
101 |
+
|
102 |
+
params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
|
103 |
+
params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
|
104 |
+
params.seqused_q = static_cast<int *>(seqused_q);
|
105 |
+
params.seqused_k = static_cast<int *>(seqused_k);
|
106 |
+
|
107 |
+
// Softmax sum
|
108 |
+
params.softmax_lse_ptr = softmax_lse_d;
|
109 |
+
|
110 |
+
// Set the dimensions.
|
111 |
+
params.b = b;
|
112 |
+
params.h = h;
|
113 |
+
params.h_k = h_k;
|
114 |
+
params.seqlen_q = seqlen_q;
|
115 |
+
params.seqlen_k = seqlen_k;
|
116 |
+
params.seqlen_q_rounded = seqlen_q_rounded;
|
117 |
+
params.seqlen_k_rounded = seqlen_k_rounded;
|
118 |
+
params.d = d;
|
119 |
+
params.d_rounded = d_rounded;
|
120 |
+
|
121 |
+
// Set the different scale values.
|
122 |
+
params.scale_softmax = softmax_scale;
|
123 |
+
params.softcap = softcap;
|
124 |
+
|
125 |
+
// Set this to probability of keeping an element to simplify things.
|
126 |
+
params.p_dropout = 1.f - p_dropout;
|
127 |
+
// Convert p from float to int so we don't have to convert the random uint to float to compare.
|
128 |
+
// [Minor] We want to round down since when we do the comparison we use <= instead of <
|
129 |
+
// params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
|
130 |
+
// params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
|
131 |
+
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
|
132 |
+
params.rp_dropout = 1.f / params.p_dropout;
|
133 |
+
TORCH_CHECK(p_dropout < 1.f);
|
134 |
+
#ifdef FLASHATTENTION_DISABLE_DROPOUT
|
135 |
+
TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
|
136 |
+
#endif
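// [Illustrative note, not part of the original file] Example of the dropout conversion above:
//   p_dropout = 0.1 (drop 10%)  ->  params.p_dropout = 0.9 (keep probability),
//   p_dropout_in_uint8_t = floor(0.9 * 255) = 229,  rp_dropout = 1 / 0.9 ~= 1.111,
// so kept activations can be rescaled by rp_dropout.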
|
137 |
+
|
138 |
+
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
|
139 |
+
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
|
140 |
+
params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0;
|
141 |
+
params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal;
|
142 |
+
|
143 |
+
// TODO: check this
|
144 |
+
if (window_size_left < 0) { window_size_left = seqlen_k - 1; }
|
145 |
+
if (window_size_right < 0) { window_size_right = seqlen_q - 1; }
|
146 |
+
if (attention_chunk > 0) {
|
147 |
+
window_size_left = std::min(window_size_left, attention_chunk - 1);
|
148 |
+
window_size_right = std::min(window_size_right, attention_chunk - 1);
|
149 |
+
}
|
150 |
+
params.window_size_left = window_size_left;
|
151 |
+
params.window_size_right = window_size_right;
|
152 |
+
params.attention_chunk = attention_chunk;
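// [Illustrative note, not part of the original file] Examples of the window handling above:
//   window_size_left = -1, window_size_right = 0, attention_chunk = 0  ->  is_causal = true,
//     and the windows are then widened to (seqlen_k - 1, 0) before being stored.
//   window_size_left = 128, window_size_right = 0  ->  is_local = true
//     (sliding-window attention over the previous 128 keys plus the diagonal).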
|
153 |
+
|
154 |
+
params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
|
155 |
+
params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
|
156 |
+
|
157 |
+
#ifdef FLASHATTENTION_DISABLE_LOCAL
|
158 |
+
TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
|
159 |
+
#endif
|
160 |
+
}
|
161 |
+
|
162 |
+
void set_params_dgrad(Flash_bwd_params ¶ms,
|
163 |
+
// sizes
|
164 |
+
const size_t b,
|
165 |
+
const size_t seqlen_q,
|
166 |
+
const size_t seqlen_k,
|
167 |
+
const size_t seqlen_q_rounded,
|
168 |
+
const size_t seqlen_k_rounded,
|
169 |
+
const size_t h,
|
170 |
+
const size_t h_k,
|
171 |
+
const size_t d,
|
172 |
+
const size_t d_rounded,
|
173 |
+
// device pointers
|
174 |
+
const at::Tensor q,
|
175 |
+
const at::Tensor k,
|
176 |
+
const at::Tensor v,
|
177 |
+
const at::Tensor out,
|
178 |
+
const at::Tensor dout,
|
179 |
+
at::Tensor dq,
|
180 |
+
at::Tensor dk,
|
181 |
+
at::Tensor dv,
|
182 |
+
void *cu_seqlens_q_d,
|
183 |
+
void *cu_seqlens_k_d,
|
184 |
+
void *seqused_q,
|
185 |
+
void *seqused_k,
|
186 |
+
void *dq_accum_d,
|
187 |
+
void *dk_accum_d,
|
188 |
+
void *dv_accum_d,
|
189 |
+
void *softmax_lse_d,
|
190 |
+
void *dsoftmax_sum_d,
|
191 |
+
float p_dropout,
|
192 |
+
float softmax_scale,
|
193 |
+
int window_size_left,
|
194 |
+
int window_size_right,
|
195 |
+
int attention_chunk,
|
196 |
+
const float softcap=0.f,
|
197 |
+
bool deterministic=false,
|
198 |
+
int const sm_margin=0) {
|
199 |
+
|
200 |
+
set_params_fprop(params,
|
201 |
+
b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
|
202 |
+
q, k, v, out,
|
203 |
+
cu_seqlens_q_d,
|
204 |
+
cu_seqlens_k_d,
|
205 |
+
seqused_q,
|
206 |
+
seqused_k,
|
207 |
+
softmax_lse_d,
|
208 |
+
p_dropout,
|
209 |
+
softmax_scale,
|
210 |
+
window_size_left,
|
211 |
+
window_size_right,
|
212 |
+
attention_chunk,
|
213 |
+
softcap,
|
214 |
+
sm_margin);
|
215 |
+
|
216 |
+
// Set the pointers and strides.
|
217 |
+
params.do_ptr = dout.data_ptr();
|
218 |
+
params.do_row_stride = dout.stride(-3);
|
219 |
+
params.do_head_stride = dout.stride(-2);
|
220 |
+
params.dq_ptr = dq.data_ptr();
|
221 |
+
params.dk_ptr = dk.data_ptr();
|
222 |
+
params.dv_ptr = dv.data_ptr();
|
223 |
+
params.dq_row_stride = dq.stride(-3);
|
224 |
+
params.dk_row_stride = dk.stride(-3);
|
225 |
+
params.dv_row_stride = dv.stride(-3);
|
226 |
+
params.dq_head_stride = dq.stride(-2);
|
227 |
+
params.dk_head_stride = dk.stride(-2);
|
228 |
+
params.dv_head_stride = dv.stride(-2);
|
229 |
+
|
230 |
+
if (cu_seqlens_q_d == nullptr) {
|
231 |
+
params.do_batch_stride = dout.stride(0);
|
232 |
+
params.dq_batch_stride = dq.stride(0);
|
233 |
+
params.dk_batch_stride = dk.stride(0);
|
234 |
+
params.dv_batch_stride = dv.stride(0);
|
235 |
+
}
|
236 |
+
|
237 |
+
params.dq_accum_ptr = dq_accum_d;
|
238 |
+
params.dk_accum_ptr = dk_accum_d;
|
239 |
+
params.dv_accum_ptr = dv_accum_d;
|
240 |
+
|
241 |
+
// Softmax sum
|
242 |
+
params.dsoftmax_sum = dsoftmax_sum_d;
|
243 |
+
|
244 |
+
params.deterministic = deterministic;
|
245 |
+
}
|
246 |
+
|
247 |
+
template <int Arch, int Split, bool PagedKVNonTMA, bool PackGQA, bool Has_softcap>
|
248 |
+
void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
249 |
+
if (!params.is_e4m3) {
|
250 |
+
if (params.is_bf16) {
|
251 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM64
|
252 |
+
if (params.d <= 64) {
|
253 |
+
if constexpr (Arch == 90) {
|
254 |
+
if (params.dv > 256) {
|
255 |
+
return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
256 |
+
} else if (params.dv > 64) {
|
257 |
+
return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
258 |
+
}
|
259 |
+
}
|
260 |
+
return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
261 |
+
}
|
262 |
+
#endif
|
263 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM96
|
264 |
+
if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
265 |
+
#endif
|
266 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
267 |
+
if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
268 |
+
#endif
|
269 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
270 |
+
if (params.d <= 192) {
|
271 |
+
if constexpr (Arch == 90) {
|
272 |
+
if (params.dv <= 128) {
|
273 |
+
return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
274 |
+
}
|
275 |
+
}
|
276 |
+
return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
277 |
+
}
|
278 |
+
#endif
|
279 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM256
|
280 |
+
if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
281 |
+
#endif
|
282 |
+
} else {
|
283 |
+
#ifndef FLASHATTENTION_DISABLE_FP16
|
284 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM64
|
285 |
+
if (params.d <= 64) {
|
286 |
+
if constexpr (Arch == 90) {
|
287 |
+
if (params.dv > 256) {
|
288 |
+
return run_mha_fwd_<Arch, cutlass::half_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
289 |
+
} else if (params.dv > 64) {
|
290 |
+
return run_mha_fwd_<Arch, cutlass::half_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
291 |
+
}
|
292 |
+
}
|
293 |
+
return run_mha_fwd_<Arch, cutlass::half_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
294 |
+
}
|
295 |
+
#endif
|
296 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM96
|
297 |
+
if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::half_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
298 |
+
#endif
|
299 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
300 |
+
if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::half_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
301 |
+
#endif
|
302 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
303 |
+
if (params.d <= 192) {
|
304 |
+
if constexpr (Arch == 90) {
|
305 |
+
if (params.dv <= 128) {
|
306 |
+
return run_mha_fwd_<Arch, cutlass::half_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
307 |
+
}
|
308 |
+
}
|
309 |
+
return run_mha_fwd_<Arch, cutlass::half_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
310 |
+
}
|
311 |
+
#endif
|
312 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM256
|
313 |
+
if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::half_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
314 |
+
#endif
|
315 |
+
#else
|
316 |
+
TORCH_CHECK(false, "This flash attention build does not support FP16.");
|
317 |
+
#endif
|
318 |
+
}
|
319 |
+
} else {
|
320 |
+
#ifndef FLASHATTENTION_DISABLE_FP8
|
321 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM64
|
322 |
+
if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
323 |
+
#endif
|
324 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM96
|
325 |
+
if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
326 |
+
#endif
|
327 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
328 |
+
if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
329 |
+
#endif
|
330 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
331 |
+
if (params.d <= 192) {
|
332 |
+
if constexpr (Arch == 90) {
|
333 |
+
if (params.dv <= 128) {
|
334 |
+
return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
335 |
+
}
|
336 |
+
}
|
337 |
+
return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
338 |
+
}
|
339 |
+
#endif
|
340 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM256
|
341 |
+
if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
342 |
+
#endif
|
343 |
+
#else
|
344 |
+
TORCH_CHECK(false, "This flash attention build does not support FP8.");
|
345 |
+
#endif
|
346 |
+
}
|
347 |
+
}
|
348 |
+
|
349 |
+
void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
350 |
+
// HEADDIM_SWITCH(params.d, [&] {
|
351 |
+
// run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
|
352 |
+
// });
|
353 |
+
TORCH_CHECK(params.num_splits >= 1);
|
354 |
+
ARCH_SWITCH(params.arch, Arch, [&] {
|
355 |
+
SPLIT_SWITCH(params.num_splits > 1, Split, [&] {
|
356 |
+
PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] {
|
357 |
+
PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] {
|
358 |
+
// Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation
|
359 |
+
static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split;
|
360 |
+
SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] {
|
361 |
+
run_mha_fwd_constexpr<Arch, Split, PagedKVNonTMA, PackGQA, Has_softcap>(params, stream);
|
362 |
+
});
|
363 |
+
});
|
364 |
+
});
|
365 |
+
});
|
366 |
+
});
|
367 |
+
}
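// [Illustrative note, not part of the original file] The nested *_SWITCH macros above turn
// runtime flags into compile-time template parameters. Conceptually each one expands to
// something like (a sketch, not the actual macro definition):
//
//   if (cond) { constexpr bool Flag = true;  body(); }
//   else      { constexpr bool Flag = false; body(); }
//
// so run_mha_fwd_constexpr is instantiated once per (Arch, Split, PagedKVNonTMA,
// PackGQA, Has_softcap) combination and the matching instantiation is selected at runtime.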
|
368 |
+
|
369 |
+
void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl=false) {
|
370 |
+
#ifndef FLASHATTENTION_DISABLE_SPLIT
|
371 |
+
// If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively
|
372 |
+
// so that kBlockM is smaller and we have more parallelism.
|
373 |
+
if (params.is_fp32) {
|
374 |
+
if (params.dv <= 64) {
|
375 |
+
run_mha_fwd_combine_<float, float, 64>(params, stream, enable_pdl);
|
376 |
+
} else {
|
377 |
+
run_mha_fwd_combine_<float, float, 128>(params, stream, enable_pdl);
|
378 |
+
}
|
379 |
+
} else if (params.is_bf16) {
|
380 |
+
if (params.dv <= 64) {
|
381 |
+
run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(params, stream, enable_pdl);
|
382 |
+
} else {
|
383 |
+
run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(params, stream, enable_pdl);
|
384 |
+
}
|
385 |
+
} else {
|
386 |
+
if (params.dv <= 64) {
|
387 |
+
run_mha_fwd_combine_<cutlass::half_t, float, 64>(params, stream, enable_pdl);
|
388 |
+
} else {
|
389 |
+
run_mha_fwd_combine_<cutlass::half_t, float, 128>(params, stream, enable_pdl);
|
390 |
+
}
|
391 |
+
}
|
392 |
+
#else
|
393 |
+
TORCH_CHECK(false, "This flash attention build does not support combine kernels.");
|
394 |
+
#endif
|
395 |
+
}
|
396 |
+
|
397 |
+
inline bool get_pagedkv_tma(Flash_fwd_params const& params) {
|
398 |
+
if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; }
|
399 |
+
// This needs to match the kernel configs
|
400 |
+
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f);
|
401 |
+
int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
|
402 |
+
int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90);
|
403 |
+
// Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower,
|
404 |
+
// at least for MLA.
|
405 |
+
return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM;
|
406 |
+
}
|
407 |
+
|
408 |
+
inline bool get_pack_gqa(Flash_fwd_params const& params) {
|
409 |
+
// Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size.
|
410 |
+
// Has little effect on speed.
|
411 |
+
if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; }
|
412 |
+
#ifdef FLASHATTENTION_DISABLE_PACKGQA
|
413 |
+
return false;
|
414 |
+
#else
|
415 |
+
// params.page_table must already be set
|
416 |
+
if (params.h == params.h_k) { return false; }
|
417 |
+
// This needs to match the kernel configs
|
418 |
+
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
|
419 |
+
int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
|
420 |
+
return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM);
|
421 |
+
#endif
|
422 |
+
}
|
423 |
+
|
424 |
+
inline int get_num_splits(Flash_fwd_params const& params) {
|
425 |
+
#ifdef FLASHATTENTION_DISABLE_SPLIT
|
426 |
+
return 1;
|
427 |
+
#else
|
428 |
+
// Always enable PackGQA for Split
|
429 |
+
// params.page_table must already be set
|
430 |
+
// This needs to match the kernel configs
|
431 |
+
bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k;
|
432 |
+
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
|
433 |
+
// Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits
|
434 |
+
// has not been set here. It's OK though because we might just underestimate kBlockN a bit
|
435 |
+
auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);
|
436 |
+
int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
|
437 |
+
int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
|
438 |
+
int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k);
|
439 |
+
// If is_local, we're not going to load all of seqlen_k
|
440 |
+
int const seqlen_k_loaded = !params.is_local
|
441 |
+
? params.seqlen_k
|
442 |
+
: std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM));
|
443 |
+
int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN;
|
444 |
+
int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM;
|
445 |
+
int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2);
|
446 |
+
// Always enable PackGQA for Split
|
447 |
+
// If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits.
|
448 |
+
// We assume the case where there's 1 long sequence and the rest are short, i.e. pretending
|
449 |
+
// that batch = 1.
|
450 |
+
int total_mblocks = (params.num_splits_dynamic_ptr ? 1 : params.b) * params.h_k * num_m_blocks;
|
451 |
+
return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128);
|
452 |
+
#endif
|
453 |
+
}
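// [Illustrative note, not part of the original file] Rough shape of the computation above for
// the non-varlen, non-local case:
//   num_m_blocks = ceil(seqlen_q * (h / h_k) / kBlockM),  num_n_blocks = ceil(seqlen_k / kBlockN),
// and the heuristic then chooses how many ways to split the KV sequence so that
// batch * h_k * num_m_blocks * num_splits blocks are enough to keep all SMs busy.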
|
454 |
+
|
455 |
+
inline int get_max_headdim() {
|
456 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM256
|
457 |
+
return 256;
|
458 |
+
#endif
|
459 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
460 |
+
return 192;
|
461 |
+
#endif
|
462 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
463 |
+
return 128;
|
464 |
+
#endif
|
465 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM96
|
466 |
+
return 96;
|
467 |
+
#endif
|
468 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM64
|
469 |
+
return 64;
|
470 |
+
#endif
|
471 |
+
return 0;
|
472 |
+
}
|
473 |
+
|
474 |
+
inline int round_up_headdim(int head_size) {
|
475 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM64
|
476 |
+
if (head_size <= 64) { return 64; }
|
477 |
+
#endif
|
478 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM96
|
479 |
+
if (head_size <= 96) { return 96; }
|
480 |
+
#endif
|
481 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
482 |
+
if (head_size <= 128) { return 128; }
|
483 |
+
#endif
|
484 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
485 |
+
if (head_size <= 192) { return 192; }
|
486 |
+
#endif
|
487 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM256
|
488 |
+
if (head_size <= 256) { return 256; }
|
489 |
+
#endif
|
490 |
+
return 256;
|
491 |
+
}
|
492 |
+
|
493 |
+
inline int round_up_headdimv(int head_size) {
|
494 |
+
if (head_size <= 64) { return 64; }
|
495 |
+
if (head_size <= 96) { return 96; }
|
496 |
+
if (head_size <= 128) { return 128; }
|
497 |
+
if (head_size <= 192) { return 192; }
|
498 |
+
if (head_size <= 256) { return 256; }
|
499 |
+
return 512;
|
500 |
+
}
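// [Illustrative note, not part of the original file] Examples of the rounding above
// (with all head dims enabled at compile time):
//   round_up_headdim(80)   -> 96      round_up_headdim(200)  -> 256
//   round_up_headdimv(300) -> 512     round_up_headdimv(128) -> 128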
|
501 |
+
|
502 |
+
// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
|
503 |
+
at::Tensor
|
504 |
+
mha_fwd_get_scheduler_metadata(
|
505 |
+
int64_t batch_size,
|
506 |
+
int64_t max_seqlen_q,
|
507 |
+
int64_t max_seqlen_k,
|
508 |
+
int64_t num_heads,
|
509 |
+
int64_t num_heads_k,
|
510 |
+
int64_t headdim,
|
511 |
+
int64_t headdim_v,
|
512 |
+
at::ScalarType qkv_dtype,
|
513 |
+
at::Tensor seqused_k, // b
|
514 |
+
std::optional<at::Tensor> cu_seqlens_q_, // b+1
|
515 |
+
std::optional<at::Tensor> cu_seqlens_k_, // b+1
|
516 |
+
std::optional<at::Tensor> cu_seqlens_k_new_, // b+1
|
517 |
+
std::optional<at::Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
|
518 |
+
std::optional<at::Tensor> leftpad_k_, // b
|
519 |
+
std::optional<int64_t> page_size,
|
520 |
+
int64_t max_seqlen_k_new, // 0 means we're not appending new KV
|
521 |
+
bool is_causal,
|
522 |
+
int64_t window_size_left,
|
523 |
+
int64_t window_size_right,
|
524 |
+
int64_t attention_chunk,
|
525 |
+
bool has_softcap,
|
526 |
+
int64_t num_splits,
|
527 |
+
std::optional<bool> pack_gqa_,
|
528 |
+
int64_t sm_margin
|
529 |
+
) {
|
530 |
+
|
531 |
+
TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn,
|
532 |
+
"FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
|
533 |
+
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
534 |
+
|
535 |
+
// Reset the parameters
|
536 |
+
Flash_fwd_params params{};
|
537 |
+
params.is_bf16 = qkv_dtype == at::ScalarType::BFloat16;
|
538 |
+
params.is_e4m3 = qkv_dtype == at::ScalarType::Float8_e4m3fn;
|
539 |
+
params.b = batch_size;
|
540 |
+
params.seqlen_q = max_seqlen_q;
|
541 |
+
params.seqlen_k = max_seqlen_k;
|
542 |
+
params.h = num_heads;
|
543 |
+
params.h_k = num_heads_k;
|
544 |
+
params.d = headdim;
|
545 |
+
params.dv = headdim_v;
|
546 |
+
params.d_rounded = round_up_headdim(headdim);
|
547 |
+
params.dv_rounded = headdim_v == headdim ? params.d_rounded : round_up_headdimv(headdim_v);
|
548 |
+
params.seqlen_knew = max_seqlen_k_new;
|
549 |
+
|
550 |
+
bool const is_varlen_q = cu_seqlens_q_.has_value();
|
551 |
+
params.cu_seqlens_q = is_varlen_q ? cu_seqlens_q_.value().data_ptr<int>() : nullptr;
|
552 |
+
bool const is_varlen_k = cu_seqlens_k_.has_value();
|
553 |
+
params.cu_seqlens_k = is_varlen_k ? cu_seqlens_k_.value().data_ptr<int>() : nullptr;
|
554 |
+
params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? cu_seqlens_k_new_.value().data_ptr<int>() : nullptr;
|
555 |
+
params.seqused_q = seqused_q_.has_value() ? seqused_q_.value().data_ptr<int>() : nullptr;
|
556 |
+
params.seqused_k = seqused_k.data_ptr<int>();
|
557 |
+
params.leftpad_k = leftpad_k_.has_value() ? leftpad_k_.value().data_ptr<int>() : nullptr;
|
558 |
+
params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast<int*>(1) : nullptr;
|
559 |
+
if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; }
|
560 |
+
if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; }
|
561 |
+
// causal=true is the same as causal=false in this case
|
562 |
+
if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) {
|
563 |
+
// Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA
|
564 |
+
if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) {
|
565 |
+
is_causal = false;
|
566 |
+
}
|
567 |
+
}
|
568 |
+
if (is_causal) { window_size_right = 0; }
|
569 |
+
|
570 |
+
params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0;
|
571 |
+
params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal;
|
572 |
+
if (window_size_left < 0) { window_size_left = max_seqlen_k - 1; }
|
573 |
+
if (window_size_right < 0) { window_size_right = max_seqlen_q - 1; }
|
574 |
+
if (attention_chunk > 0) {
|
575 |
+
window_size_left = std::min(window_size_left, attention_chunk - 1);
|
576 |
+
window_size_right = std::min(window_size_right, attention_chunk - 1);
|
577 |
+
}
|
578 |
+
params.window_size_left = window_size_left;
|
579 |
+
params.window_size_right = window_size_right;
|
580 |
+
params.attention_chunk = attention_chunk;
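For intuition, the window normalization above can be read as a per-query key range. The following standalone sketch (illustrative only; none of these helpers exist in this file) computes the inclusive key interval a query may attend to under the bottom-right-aligned causal/local convention; attention_chunk simply clamps both window sides, as done a few lines above.

#include <algorithm>
#include <cstdio>
#include <utility>

// Inclusive key range for one query under the causal/local convention
// (key positions aligned to the diagonal, i.e. shifted by seqlen_k - seqlen_q).
static std::pair<int, int> allowed_key_range(int query_idx, int seqlen_q, int seqlen_k,
                                             int window_size_left, int window_size_right) {
    int diag = query_idx + seqlen_k - seqlen_q;   // key index on the diagonal for this query
    int lo = window_size_left  < 0 ? 0            : std::max(0, diag - window_size_left);
    int hi = window_size_right < 0 ? seqlen_k - 1 : std::min(seqlen_k - 1, diag + window_size_right);
    return {lo, hi};
}

int main() {
    // Causal attention (window_size_right == 0) over 8 keys with 4 queries.
    for (int i = 0; i < 4; ++i) {
        auto [lo, hi] = allowed_key_range(i, /*seqlen_q=*/4, /*seqlen_k=*/8,
                                          /*window_size_left=*/-1, /*window_size_right=*/0);
        std::printf("query %d attends to keys [%d, %d]\n", i, lo, hi);  // [0,4] [0,5] [0,6] [0,7]
    }
}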
|
581 |
+
params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
|
582 |
+
params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
|
583 |
+
params.softcap = has_softcap ? 1.0f : 0.0f;
|
584 |
+
|
585 |
+
params.page_size = page_size.has_value() ? page_size.value() : 1;
|
586 |
+
params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast<int*>(1);
|
587 |
+
|
588 |
+
bool const use_dynamic_split = params.b <= 992;
|
589 |
+
params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);
|
590 |
+
|
591 |
+
params.pagedkv_tma = get_pagedkv_tma(params);
|
592 |
+
params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
|
593 |
+
// Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide
|
594 |
+
params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
|
595 |
+
|
596 |
+
bool is_varlen = true;
|
597 |
+
|
598 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
599 |
+
// Cast to char to avoid compiler warning about narrowing
|
600 |
+
at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()};
|
601 |
+
|
602 |
+
auto opts = seqused_k.options();
|
603 |
+
// This needs to be set after get_num_splits
|
604 |
+
at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic
|
605 |
+
bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1;
|
606 |
+
if (scheduler_needs_semaphore || use_dynamic_split) {
|
607 |
+
tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32));
|
608 |
+
if (scheduler_needs_semaphore) {
|
609 |
+
if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing
|
610 |
+
params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
|
611 |
+
} else {
|
612 |
+
params.tile_count_semaphore = nullptr;
|
613 |
+
}
|
614 |
+
params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
|
615 |
+
}
|
616 |
+
|
617 |
+
if (params.num_splits_dynamic_ptr) {
|
618 |
+
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
|
619 |
+
auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);
|
620 |
+
int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
|
621 |
+
int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
|
622 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
623 |
+
prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/);
|
624 |
+
CHECK_CUDA_KERNEL_LAUNCH();
|
625 |
+
}
|
626 |
+
return tile_count_semaphore;
|
627 |
+
}
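The tensor returned by this function is just the scheduler metadata laid out as [semaphore, num_splits_dynamic[0..b-1]], with either piece optional. A minimal host-side sketch of that sizing rule (illustrative only, not part of this file):

#include <cstdio>

int main() {
    int batch_size = 4;
    bool needs_semaphore   = true;               // here: arch >= 90 || num_splits > 1
    bool use_dynamic_split = batch_size <= 992;  // max batch handled by prepare_varlen_num_blocks
    int metadata_size = int(needs_semaphore) + int(use_dynamic_split) * batch_size;
    // Element 0               : tile-count semaphore (when present)
    // Elements 1..batch_size  : per-batch dynamic num_splits
    std::printf("metadata tensor has %d int32 elements\n", metadata_size);  // 5
}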
|
628 |
+
|
629 |
+
// b: batch_size
|
630 |
+
// b_k: batch_size_k
|
631 |
+
// s_q: seqlen_q
|
632 |
+
// s_k: seqlen_k
|
633 |
+
// s_k_new: seqlen_k_new
|
634 |
+
// h: num_heads
|
635 |
+
// h_k: num_heads_k
|
636 |
+
// d: head_size
|
637 |
+
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
638 |
+
mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
|
639 |
+
at::Tensor k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.
|
640 |
+
at::Tensor v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table.
|
641 |
+
std::optional<at::Tensor> k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
|
642 |
+
std::optional<at::Tensor> v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
|
643 |
+
std::optional<at::Tensor> q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
|
644 |
+
std::optional<at::Tensor> out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
|
645 |
+
std::optional<at::Tensor> cu_seqlens_q_, // b+1
|
646 |
+
std::optional<at::Tensor> cu_seqlens_k_, // b+1
|
647 |
+
std::optional<at::Tensor> cu_seqlens_k_new_, // b+1
|
648 |
+
std::optional<at::Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
|
649 |
+
std::optional<at::Tensor> seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
|
650 |
+
std::optional<int64_t> max_seqlen_q_,
|
651 |
+
// TODO: check if we need max_seqlen_k
|
652 |
+
std::optional<int64_t> max_seqlen_k_,
|
653 |
+
std::optional<at::Tensor> page_table_, // (b_k, max_num_pages_per_seq)
|
654 |
+
std::optional<at::Tensor> kv_batch_idx_, // b. indices to index into the KV cache
|
655 |
+
std::optional<at::Tensor> leftpad_k_, // b
|
656 |
+
std::optional<at::Tensor> rotary_cos_, // seqlen_ro x (rotary_dim / 2)
|
657 |
+
std::optional<at::Tensor> rotary_sin_, // seqlen_ro x (rotary_dim / 2)
|
658 |
+
std::optional<at::Tensor> seqlens_rotary_, // b
|
659 |
+
std::optional<at::Tensor> q_descale_, // (b, h_k), not (b, h)
|
660 |
+
std::optional<at::Tensor> k_descale_, // (b, h_k)
|
661 |
+
std::optional<at::Tensor> v_descale_, // (b, h_k)
|
662 |
+
std::optional<double> softmax_scale_,
|
663 |
+
bool is_causal,
|
664 |
+
int64_t window_size_left,
|
665 |
+
int64_t window_size_right,
|
666 |
+
int64_t attention_chunk,
|
667 |
+
double softcap,
|
668 |
+
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
|
669 |
+
std::optional<at::Tensor> scheduler_metadata_, // (b + 1)
|
670 |
+
int64_t num_splits,
|
671 |
+
std::optional<bool> pack_gqa_,
|
672 |
+
int64_t sm_margin
|
673 |
+
) {
|
674 |
+
|
675 |
+
auto dprops = at::cuda::getCurrentDeviceProperties();
|
676 |
+
bool is_sm8x = dprops->major >= 8;
|
677 |
+
TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
|
678 |
+
|
679 |
+
auto q_type = q.scalar_type();
|
680 |
+
TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn,
|
681 |
+
"FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
|
682 |
+
if (dprops->major < 9) {
|
683 |
+
TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
|
684 |
+
"FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data type");
|
685 |
+
}
|
686 |
+
TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype");
|
687 |
+
TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype");
|
688 |
+
|
689 |
+
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
|
690 |
+
|
691 |
+
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
692 |
+
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
693 |
+
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
694 |
+
|
695 |
+
at::Tensor page_table;
|
696 |
+
const bool paged_KV = page_table_.has_value();
|
697 |
+
if (paged_KV) {
|
698 |
+
page_table = page_table_.value();
|
699 |
+
CHECK_DEVICE(page_table);
|
700 |
+
TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must have dtype torch.int32");
|
701 |
+
TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension");
|
702 |
+
}
|
703 |
+
|
704 |
+
at::Tensor cu_seqlens_q;
|
705 |
+
bool const is_varlen_q = cu_seqlens_q_.has_value();
|
706 |
+
if (is_varlen_q) {
|
707 |
+
cu_seqlens_q = cu_seqlens_q_.value();
|
708 |
+
CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
|
709 |
+
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
|
710 |
+
TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
|
711 |
+
}
|
712 |
+
at::Tensor cu_seqlens_k;
|
713 |
+
bool const is_varlen_k = cu_seqlens_k_.has_value();
|
714 |
+
if (is_varlen_k) {
|
715 |
+
cu_seqlens_k = cu_seqlens_k_.value();
|
716 |
+
CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
|
717 |
+
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
|
718 |
+
TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
|
719 |
+
TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported");
|
720 |
+
TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported");
|
721 |
+
}
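As a reminder of the varlen convention assumed by these checks, cu_seqlens_* is an exclusive prefix sum over per-sequence lengths into the packed (total, heads, dim) layout. A small illustrative sketch (the values are made up):

#include <cstdio>
#include <vector>

int main() {
    std::vector<int> seqlens = {3, 5, 2};            // three sequences in the batch
    std::vector<int> cu_seqlens = {0};
    for (int len : seqlens) cu_seqlens.push_back(cu_seqlens.back() + len);
    // cu_seqlens = {0, 3, 8, 10}; total = 10; max_seqlen must still be passed separately.
    for (size_t i = 0; i < seqlens.size(); ++i)
        std::printf("batch %zu -> rows [%d, %d)\n", i, cu_seqlens[i], cu_seqlens[i + 1]);
}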
|
722 |
+
|
723 |
+
auto const sizes = q.sizes();
|
724 |
+
const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
|
725 |
+
int seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
|
726 |
+
int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
|
727 |
+
int num_heads = q.size(-2);
|
728 |
+
int const head_size = q.size(-1);
|
729 |
+
int const head_size_v = v.size(-1);
|
730 |
+
int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1);
|
731 |
+
int const num_pages = !paged_KV ? 0 : k.size(0);
|
732 |
+
int const page_size = !paged_KV ? 1 : k.size(1);
|
733 |
+
int const seqlen_k = !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value();
|
734 |
+
int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
|
735 |
+
int const num_heads_k = k.size(-2);
|
736 |
+
int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0);
|
737 |
+
double softmax_scale = 1.0 / sqrt(double(head_size));
|
738 |
+
if (softmax_scale_.has_value()) {
|
739 |
+
softmax_scale = softmax_scale_.value();
|
740 |
+
}
|
741 |
+
if (!kv_batch_idx_.has_value()) {
|
742 |
+
TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k");
|
743 |
+
}
|
744 |
+
int const max_headdim = get_max_headdim();
|
745 |
+
TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
|
746 |
+
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
747 |
+
if (head_size_v != head_size) {
|
748 |
+
TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) ||
|
749 |
+
(head_size <= 64 && head_size_v <= 512),
|
750 |
+
"If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], "
|
751 |
+
"or (Q/K <= 64 and V <= 512).");
|
752 |
+
TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim");
|
753 |
+
if (head_size_v > 256) {
|
754 |
+
TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
|
755 |
+
"HeaddimV > 256 requires fp16 and bf16 data type");
|
756 |
+
}
|
757 |
+
}
|
758 |
+
|
759 |
+
// This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
|
760 |
+
// TODO: check this
|
761 |
+
if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
|
762 |
+
if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
|
763 |
+
// causal=true is the same as causal=false in this case
|
764 |
+
if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) {
|
765 |
+
// Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA
|
766 |
+
if ((head_size <= 64 || head_size > 128) || !paged_KV) {
|
767 |
+
is_causal = false;
|
768 |
+
}
|
769 |
+
}
|
770 |
+
if (is_causal) { window_size_right = 0; }
|
771 |
+
|
772 |
+
if (!is_varlen_q) {
|
773 |
+
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
|
774 |
+
} else {
|
775 |
+
CHECK_SHAPE(q, total_q, num_heads, head_size);
|
776 |
+
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
|
777 |
+
}
|
778 |
+
if (!paged_KV) {
|
779 |
+
if (!is_varlen_k) {
|
780 |
+
CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size);
|
781 |
+
CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v);
|
782 |
+
} else {
|
783 |
+
CHECK_SHAPE(k, total_k, num_heads_k, head_size);
|
784 |
+
CHECK_SHAPE(v, total_k, num_heads_k, head_size_v);
|
785 |
+
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
|
786 |
+
}
|
787 |
+
} else {
|
788 |
+
CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size);
|
789 |
+
CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v);
|
790 |
+
CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq);
|
791 |
+
}
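For the paged layout checked above, a logical key position is resolved through the page table into a (page, offset) pair inside the (num_pages, page_size, h_k, d) cache tensors. A tiny illustrative sketch (ignoring leftpad_k, which would shift the position first):

#include <cstdio>

int main() {
    const int page_size = 4;
    const int page_table[2][3] = { {7, 2, 5}, {0, 9, 1} };  // (b_k, max_num_pages_per_seq)
    int b = 1, pos = 6;                                     // batch 1, key position 6
    int page   = page_table[b][pos / page_size];            // -> 9
    int offset = pos % page_size;                           // -> 2
    std::printf("k[b=%d, pos=%d] lives at k_cache[page=%d, offset=%d]\n", b, pos, page, offset);
}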
|
792 |
+
|
793 |
+
if (seqused_q_.has_value()){
|
794 |
+
auto seqused_q = seqused_q_.value();
|
795 |
+
TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
|
796 |
+
CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
|
797 |
+
CHECK_SHAPE(seqused_q, batch_size);
|
798 |
+
}
|
799 |
+
if (seqused_k_.has_value()) {
|
800 |
+
auto seqused_k = seqused_k_.value();
|
801 |
+
TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
|
802 |
+
CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
|
803 |
+
CHECK_SHAPE(seqused_k, batch_size);
|
804 |
+
}
|
805 |
+
|
806 |
+
if (leftpad_k_.has_value()) {
|
807 |
+
auto leftpad_k = leftpad_k_.value();
|
808 |
+
TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
|
809 |
+
CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k);
|
810 |
+
CHECK_SHAPE(leftpad_k, batch_size);
|
811 |
+
}
|
812 |
+
|
813 |
+
// This is what we will template on
|
814 |
+
bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value();
|
815 |
+
#ifdef FLASHATTENTION_DISABLE_VARLEN
|
816 |
+
TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
|
817 |
+
#endif
|
818 |
+
|
819 |
+
int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8;
|
820 |
+
TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment));
|
821 |
+
TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment));
|
822 |
+
|
823 |
+
auto opts = q.options();
|
824 |
+
auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type;
|
825 |
+
at::Tensor out;
|
826 |
+
if (out_.has_value()) {
|
827 |
+
out = out_.value();
|
828 |
+
TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16");
|
829 |
+
CHECK_DEVICE(out);
|
830 |
+
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
831 |
+
if (!is_varlen_q) {
|
832 |
+
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v);
|
833 |
+
} else {
|
834 |
+
CHECK_SHAPE(out, total_q, num_heads, head_size_v);
|
835 |
+
}
|
836 |
+
} else {
|
837 |
+
out = !is_varlen_q
|
838 |
+
? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type))
|
839 |
+
: torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type));
|
840 |
+
}
|
841 |
+
|
842 |
+
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
843 |
+
int const head_size_rounded = round_up_headdim(head_size);
|
844 |
+
int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdimv(head_size_v);
|
845 |
+
int const seqlen_q_rounded = round_multiple(seqlen_q, 128);
|
846 |
+
int const seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
847 |
+
|
848 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
849 |
+
// Cast to char to avoid compiler warning about narrowing
|
850 |
+
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
851 |
+
|
852 |
+
at::Tensor softmax_lse;
|
853 |
+
if (!is_varlen_q) {
|
854 |
+
softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
855 |
+
} else {
|
856 |
+
softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
|
857 |
+
}
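The softmax_lse tensor allocated here stores one float per (head, query) row: the log-sum-exp of that row's scaled attention scores, which both the backward pass and the split combine rely on. A minimal scalar sketch of the quantity (illustrative only):

#include <cmath>
#include <cstdio>
#include <vector>

// Max-shifted (numerically stable) log-sum-exp; assumes a non-empty row.
static float logsumexp(const std::vector<float>& scores) {
    float m = -INFINITY;
    for (float s : scores) m = std::fmax(m, s);
    float sum = 0.f;
    for (float s : scores) sum += std::exp(s - m);
    return m + std::log(sum);
}

int main() {
    std::printf("lse = %f\n", logsumexp({0.5f, 1.5f, -2.0f}));
}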
|
858 |
+
|
859 |
+
Flash_fwd_params params;
|
860 |
+
set_params_fprop(params,
|
861 |
+
batch_size,
|
862 |
+
seqlen_q, seqlen_k,
|
863 |
+
seqlen_q_rounded, seqlen_k_rounded,
|
864 |
+
num_heads, num_heads_k,
|
865 |
+
head_size, head_size_rounded,
|
866 |
+
q, k, v, out,
|
867 |
+
!is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
|
868 |
+
!is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
|
869 |
+
seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
|
870 |
+
seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
|
871 |
+
softmax_lse.data_ptr(),
|
872 |
+
/*p_dropout=*/0.f,
|
873 |
+
softmax_scale,
|
874 |
+
window_size_left,
|
875 |
+
window_size_right,
|
876 |
+
attention_chunk,
|
877 |
+
softcap,
|
878 |
+
sm_margin);
|
879 |
+
params.total_q = total_q;
|
880 |
+
params.total_k = total_k;
|
881 |
+
params.b_k = batch_size_k;
|
882 |
+
params.dv = head_size_v;
|
883 |
+
params.dv_rounded = head_size_v_rounded;
|
884 |
+
if (leftpad_k_.has_value()) { // This needs to be set before get_pagedkv_tma
|
885 |
+
params.leftpad_k = static_cast<int *>(leftpad_k_.value().data_ptr());
|
886 |
+
}
|
887 |
+
if (paged_KV) {
|
888 |
+
params.page_table = page_table.data_ptr<int>();
|
889 |
+
params.page_table_batch_stride = page_table.stride(0);
|
890 |
+
}
|
891 |
+
params.page_size = page_size;
|
892 |
+
params.num_pages = num_pages;
|
893 |
+
|
894 |
+
if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma
|
895 |
+
at::Tensor k_new, v_new;
|
896 |
+
TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in");
|
897 |
+
TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqused_k must also be passed in");
|
898 |
+
TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache");
|
899 |
+
at::Tensor cu_seqlens_k_new;
|
900 |
+
bool const is_varlen_k_new = cu_seqlens_k_new_.has_value();
|
901 |
+
if (is_varlen_k_new) {
|
902 |
+
cu_seqlens_k_new = cu_seqlens_k_new_.value();
|
903 |
+
CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new);
|
904 |
+
TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, "cu_seqlens_k_new must have dtype torch.int32");
|
905 |
+
}
|
906 |
+
k_new = k_new_.value();
|
907 |
+
v_new = v_new_.value();
|
908 |
+
TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query");
|
909 |
+
TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query");
|
910 |
+
CHECK_DEVICE(k_new); CHECK_DEVICE(v_new);
|
911 |
+
TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension");
|
912 |
+
TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension");
|
913 |
+
// We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new
|
914 |
+
int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0;
|
915 |
+
int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0);
|
916 |
+
if (!is_varlen_k_new) {
|
917 |
+
CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size);
|
918 |
+
CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v);
|
919 |
+
} else {
|
920 |
+
CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size);
|
921 |
+
CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v);
|
922 |
+
CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1);
|
923 |
+
}
|
924 |
+
params.seqlen_knew = seqlen_k_new;
|
925 |
+
params.total_knew = total_k_new;
|
926 |
+
params.knew_ptr = k_new.data_ptr();
|
927 |
+
params.vnew_ptr = v_new.data_ptr();
|
928 |
+
// All stride are in elements, not bytes.
|
929 |
+
params.knew_row_stride = k_new.stride(-3);
|
930 |
+
params.vnew_row_stride = v_new.stride(-3);
|
931 |
+
params.knew_head_stride = k_new.stride(-2);
|
932 |
+
params.vnew_head_stride = v_new.stride(-2);
|
933 |
+
if (!is_varlen_k_new) {
|
934 |
+
params.knew_batch_stride = k_new.stride(0);
|
935 |
+
params.vnew_batch_stride = v_new.stride(0);
|
936 |
+
}
|
937 |
+
if (is_varlen_k_new) {
|
938 |
+
params.cu_seqlens_knew = static_cast<int*>(cu_seqlens_k_new.data_ptr());
|
939 |
+
}
|
940 |
+
}
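Conceptually, appending k_new/v_new writes the new tokens into the cache starting at each sequence's current used length, which is why seqused_k is required above. A simplified single-sequence, single-head sketch (illustrative only; the head and feature dimensions are folded away):

#include <cstdio>
#include <vector>

int main() {
    std::vector<float> k_cache(8, 0.f);        // cache capacity of 8 key positions
    std::vector<float> k_new = {1.f, 2.f, 3.f};
    int seqused_k = 4;                         // tokens already present for this sequence
    for (size_t i = 0; i < k_new.size(); ++i) k_cache[seqused_k + i] = k_new[i];
    // After the append, 7 positions are in use and attention runs over k_cache[0..6].
    for (float x : k_cache) std::printf("%.0f ", x);  // 0 0 0 0 1 2 3 0
    std::printf("\n");
}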
|
941 |
+
|
942 |
+
// 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel
|
943 |
+
bool const use_dynamic_split = is_varlen && params.b <= 992;
|
944 |
+
// Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it
|
945 |
+
params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);
|
946 |
+
|
947 |
+
params.pagedkv_tma = get_pagedkv_tma(params);
|
948 |
+
params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
|
949 |
+
// Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide
|
950 |
+
params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
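pack_gqa packs the query heads that share one KV head so they are processed together. The head mapping itself is plain integer division, as in this illustrative sketch (values are made up):

#include <cstdio>

int main() {
    int num_heads = 32, num_heads_k = 8;
    int qhead_per_khead = num_heads / num_heads_k;   // 4 query heads share each KV head
    for (int h_q = 0; h_q < num_heads; h_q += 7)
        std::printf("query head %2d -> kv head %d\n", h_q, h_q / qhead_per_khead);
}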
|
951 |
+
|
952 |
+
// This needs to be set after get_num_splits
|
953 |
+
at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic
|
954 |
+
// We don't use the persistent scheduler if Split and not Varlen
|
955 |
+
bool const scheduler_needs_semaphore = params.arch >= 90
|
956 |
+
? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen)
|
957 |
+
: ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1));
|
958 |
+
if (scheduler_needs_semaphore || use_dynamic_split) {
|
959 |
+
int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b;
|
960 |
+
params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value();
|
961 |
+
if (scheduler_metadata_.has_value()) {
|
962 |
+
at::Tensor scheduler_metadata = scheduler_metadata_.value();
|
963 |
+
CHECK_DEVICE(scheduler_metadata);
|
964 |
+
CHECK_SHAPE(scheduler_metadata, metadata_size);
|
965 |
+
CHECK_CONTIGUOUS(scheduler_metadata);
|
966 |
+
TORCH_CHECK(scheduler_metadata.dtype() == torch::kInt32, "scheduler_metadata must have dtype int32");
|
967 |
+
tile_count_semaphore = scheduler_metadata;
|
968 |
+
} else {
|
969 |
+
tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32));
|
970 |
+
}
|
971 |
+
if (scheduler_needs_semaphore && !use_dynamic_split) {
|
972 |
+
tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing
|
973 |
+
}
|
974 |
+
params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr<int>() : nullptr;
|
975 |
+
params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
|
976 |
+
}
|
977 |
+
|
978 |
+
if (q_v_.has_value()) {
|
979 |
+
TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64");
|
980 |
+
TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
|
981 |
+
"q_v is only supported for fp16 and bf16 data type");
|
982 |
+
TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs");
|
983 |
+
at::Tensor q_v = q_v_.value();
|
984 |
+
TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query");
|
985 |
+
CHECK_DEVICE(q_v);
|
986 |
+
TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension");
|
987 |
+
if (!is_varlen_q) {
|
988 |
+
CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v);
|
989 |
+
} else {
|
990 |
+
CHECK_SHAPE(q_v, total_q, num_heads, head_size_v);
|
991 |
+
}
|
992 |
+
params.qv_ptr = q_v.data_ptr();
|
993 |
+
// All stride are in elements, not bytes.
|
994 |
+
params.qv_row_stride = q_v.stride(-3);
|
995 |
+
params.qv_head_stride = q_v.stride(-2);
|
996 |
+
if (!is_varlen_q) {
|
997 |
+
params.qv_batch_stride = q_v.stride(0);
|
998 |
+
}
|
999 |
+
}
|
1000 |
+
|
1001 |
+
if (rotary_cos_.has_value()) {
|
1002 |
+
TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
|
1003 |
+
auto rotary_cos = rotary_cos_.value();
|
1004 |
+
CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos);
|
1005 |
+
params.rotary_dim = rotary_cos.size(1) * 2;
|
1006 |
+
TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
|
1007 |
+
TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
|
1008 |
+
const int seqlen_ro = rotary_cos.size(0);
|
1009 |
+
if (paged_KV) {
|
1010 |
+
TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
|
1011 |
+
}
|
1012 |
+
CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
|
1013 |
+
TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query");
|
1014 |
+
|
1015 |
+
TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
|
1016 |
+
auto rotary_sin = rotary_sin_.value();
|
1017 |
+
CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin);
|
1018 |
+
CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
|
1019 |
+
TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_sin must have the same dtype as query");
|
1020 |
+
params.rotary_cos_ptr = rotary_cos.data_ptr();
|
1021 |
+
params.rotary_sin_ptr = rotary_sin.data_ptr();
|
1022 |
+
params.is_rotary_interleaved = is_rotary_interleaved;
|
1023 |
+
if (seqlens_rotary_.has_value()) {
|
1024 |
+
at::Tensor seqlens_rotary = seqlens_rotary_.value();
|
1025 |
+
CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary);
|
1026 |
+
TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32");
|
1027 |
+
CHECK_SHAPE(seqlens_rotary, batch_size);
|
1028 |
+
params.seqlens_rotary = seqlens_rotary.data_ptr<int>();
|
1029 |
+
}
|
1030 |
+
} else {
|
1031 |
+
params.rotary_dim = 0;
|
1032 |
+
}
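The two rotary layouts selected by is_rotary_interleaved differ only in which dimensions are paired: interleaved rotates (0,1), (2,3), ...; the other layout rotates dimension i with i + rotary_dim/2. A host-side reference sketch (illustrative only, not the kernel code):

#include <cmath>
#include <cstdio>
#include <vector>

// cosv/sinv hold rotary_dim/2 values for one position.
static void apply_rotary(std::vector<float>& x, const std::vector<float>& cosv,
                         const std::vector<float>& sinv, bool interleaved) {
    int half = static_cast<int>(cosv.size());        // rotary_dim / 2
    for (int i = 0; i < half; ++i) {
        int i0 = interleaved ? 2 * i : i;
        int i1 = interleaved ? 2 * i + 1 : i + half;
        float x0 = x[i0], x1 = x[i1];
        x[i0] = x0 * cosv[i] - x1 * sinv[i];
        x[i1] = x0 * sinv[i] + x1 * cosv[i];
    }
}

int main() {
    std::vector<float> q = {1.f, 0.f, 0.f, 1.f};
    std::vector<float> cosv = {std::cos(0.5f), std::cos(0.1f)};
    std::vector<float> sinv = {std::sin(0.5f), std::sin(0.1f)};
    apply_rotary(q, cosv, sinv, /*interleaved=*/true);
    std::printf("%.3f %.3f %.3f %.3f\n", q[0], q[1], q[2], q[3]);
}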
|
1033 |
+
|
1034 |
+
if (kv_batch_idx_.has_value()) {
|
1035 |
+
auto kv_batch_idx = kv_batch_idx_.value();
|
1036 |
+
CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx);
|
1037 |
+
TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32");
|
1038 |
+
params.kv_batch_idx = reinterpret_cast<int *>(kv_batch_idx.data_ptr());
|
1039 |
+
}
|
1040 |
+
|
1041 |
+
at::Tensor out_accum, softmax_lse_accum;
|
1042 |
+
auto outaccum_type = at::ScalarType::Float;
|
1043 |
+
if (params.num_splits > 1) {
|
1044 |
+
TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported");
|
1045 |
+
if (!is_varlen_q) {
|
1046 |
+
out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(outaccum_type));
|
1047 |
+
softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
1048 |
+
params.oaccum_batch_stride = out_accum.stride(1);
|
1049 |
+
params.lseaccum_batch_stride = softmax_lse_accum.stride(1);
|
1050 |
+
} else {
|
1051 |
+
out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type));
|
1052 |
+
softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat));
|
1053 |
+
}
|
1054 |
+
params.is_fp32 = false;
|
1055 |
+
params.oaccum_ptr = out_accum.data_ptr();
|
1056 |
+
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
|
1057 |
+
params.oaccum_split_stride = out_accum.stride(0);
|
1058 |
+
params.oaccum_row_stride = out_accum.stride(-2);
|
1059 |
+
params.oaccum_head_stride = out_accum.stride(-3);
|
1060 |
+
params.lseaccum_split_stride = softmax_lse_accum.stride(0);
|
1061 |
+
params.lseaccum_head_stride = softmax_lse_accum.stride(-2);
|
1062 |
+
}
|
1063 |
+
|
1064 |
+
if (q_type == at::ScalarType::Float8_e4m3fn) {
|
1065 |
+
if (q_descale_.has_value()) {
|
1066 |
+
auto q_descale = q_descale_.value();
|
1067 |
+
CHECK_DEVICE(q_descale);
|
1068 |
+
CHECK_SHAPE(q_descale, batch_size, num_heads_k);
|
1069 |
+
params.q_descale_ptr = q_descale.data_ptr<float>();
|
1070 |
+
params.q_descale_batch_stride = q_descale.stride(0);
|
1071 |
+
params.q_descale_head_stride = q_descale.stride(1);
|
1072 |
+
} else {
|
1073 |
+
params.q_descale_ptr = nullptr;
|
1074 |
+
}
|
1075 |
+
if (k_descale_.has_value()) {
|
1076 |
+
auto k_descale = k_descale_.value();
|
1077 |
+
CHECK_DEVICE(k_descale);
|
1078 |
+
CHECK_SHAPE(k_descale, batch_size, num_heads_k);
|
1079 |
+
params.k_descale_ptr = k_descale.data_ptr<float>();
|
1080 |
+
params.k_descale_batch_stride = k_descale.stride(0);
|
1081 |
+
params.k_descale_head_stride = k_descale.stride(1);
|
1082 |
+
} else {
|
1083 |
+
params.k_descale_ptr = nullptr;
|
1084 |
+
}
|
1085 |
+
if (v_descale_.has_value()) {
|
1086 |
+
auto v_descale = v_descale_.value();
|
1087 |
+
CHECK_DEVICE(v_descale);
|
1088 |
+
CHECK_SHAPE(v_descale, batch_size, num_heads_k);
|
1089 |
+
params.v_descale_ptr = v_descale.data_ptr<float>();
|
1090 |
+
params.v_descale_batch_stride = v_descale.stride(0);
|
1091 |
+
params.v_descale_head_stride = v_descale.stride(1);
|
1092 |
+
} else {
|
1093 |
+
params.v_descale_ptr = nullptr;
|
1094 |
+
}
|
1095 |
+
}
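The per-(batch, kv-head) descale factors for FP8 inputs act as dequantization scales: q_descale and k_descale fold into the logit scale, while v_descale folds into the PV product. A scalar sketch of the logit case (illustrative only):

#include <cstdio>

int main() {
    float q_fp8 = 0.5f, k_fp8 = 0.25f;       // stand-ins for values read from FP8 tensors
    float q_descale = 8.f, k_descale = 4.f;  // one factor per (batch, kv head)
    float softmax_scale = 0.125f;
    // All three scalars can be folded into a single multiplier on the QK^T logits.
    float logit = (q_fp8 * k_fp8) * (q_descale * k_descale * softmax_scale);
    std::printf("scaled logit = %f\n", logit);
}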
|
1096 |
+
|
1097 |
+
#ifdef FLASHATTENTION_DISABLE_LOCAL
|
1098 |
+
TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
|
1099 |
+
#endif
|
1100 |
+
#ifdef FLASHATTENTION_DISABLE_SOFTCAP
|
1101 |
+
TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
|
1102 |
+
#endif
|
1103 |
+
#ifdef FLASHATTENTION_DISABLE_SPLIT
|
1104 |
+
TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits.");
|
1105 |
+
#endif
|
1106 |
+
#ifdef FLASHATTENTION_DISABLE_PACKGQA
|
1107 |
+
TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa.");
|
1108 |
+
#endif
|
1109 |
+
#ifdef FLASHATTENTION_DISABLE_PAGEDKV
|
1110 |
+
TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV.");
|
1111 |
+
#endif
|
1112 |
+
#ifdef FLASHATTENTION_DISABLE_APPENDKV
|
1113 |
+
TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV.");
|
1114 |
+
#endif
|
1115 |
+
|
1116 |
+
if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) {
|
1117 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
1118 |
+
run_mha_fwd(params, stream);
|
1119 |
+
if (params.num_splits > 1) {
|
1120 |
+
if (out_type == at::ScalarType::BFloat16) {
|
1121 |
+
// We want the final output in BF16; otherwise fwd_combine would write FP16
|
1122 |
+
params.is_bf16 = true;
|
1123 |
+
}
|
1124 |
+
// Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1
|
1125 |
+
// and seqlen = total_q, and don't need to dispatch to Varlen there.
|
1126 |
+
// However, with dynamic split, each row needs to know which batch it belongs to
|
1127 |
+
// to read the number of splits, so we just use the varlen version of combine kernel.
|
1128 |
+
// if (is_varlen_q && !seqused_q_.has_value()) {
|
1129 |
+
// if (is_varlen_q) {
|
1130 |
+
// params.b = 1;
|
1131 |
+
// params.seqlen_q = total_q;
|
1132 |
+
// }
|
1133 |
+
// This will zero out the semaphore if needed
|
1134 |
+
run_mha_fwd_combine(params, stream, true /*enable_pdl*/);
|
1135 |
+
} else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) {
|
1136 |
+
// need to zero out the semaphore in this case
|
1137 |
+
tile_count_semaphore.index({torch::indexing::Slice(0, 1)}).zero_();
|
1138 |
+
}
|
1139 |
+
} else if (total_q > 0 && num_heads_k > 0) {
|
1140 |
+
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
|
1141 |
+
out.zero_();
|
1142 |
+
softmax_lse.fill_(std::numeric_limits<float>::infinity());
|
1143 |
+
}
|
1144 |
+
|
1145 |
+
// return {out, softmax_lse};
|
1146 |
+
return {out, softmax_lse, out_accum, softmax_lse_accum};
|
1147 |
+
}
|
1148 |
+
|
1149 |
+
#ifdef FLASHATTENTION_DISABLE_BACKWARD
|
1150 |
+
void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
|
1151 |
+
TORCH_CHECK(false, "Flash-Attention was built with backward disabled");
|
1152 |
+
}
|
1153 |
+
#else
|
1154 |
+
template <int Arch, bool Has_softcap>
|
1155 |
+
void run_mha_bwd_constexpr(Flash_bwd_params &params, cudaStream_t stream) {
|
1156 |
+
if (!params.is_bf16) {
|
1157 |
+
#ifndef FLASHATTENTION_DISABLE_FP16
|
1158 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM64
|
1159 |
+
if (params.d_rounded == 64) { return run_mha_bwd_<Arch, cutlass::half_t, 64, Has_softcap>(params, stream); }
|
1160 |
+
#endif
|
1161 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM96
|
1162 |
+
if (params.d_rounded == 96) { return run_mha_bwd_<Arch, cutlass::half_t, 96, Has_softcap>(params, stream); }
|
1163 |
+
#endif
|
1164 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
1165 |
+
if (params.d_rounded == 128) { return run_mha_bwd_<Arch, cutlass::half_t, 128, Has_softcap>(params, stream); }
|
1166 |
+
#endif
|
1167 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
1168 |
+
if (params.d_rounded == 192) { return run_mha_bwd_<Arch, cutlass::half_t, 192, Has_softcap>(params, stream); }
|
1169 |
+
#endif
|
1170 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM256
|
1171 |
+
if (params.d_rounded == 256) { return run_mha_bwd_<Arch, cutlass::half_t, 256, Has_softcap>(params, stream); }
|
1172 |
+
#endif
|
1173 |
+
#else
|
1174 |
+
TORCH_CHECK(false, "This flash attention build does not support FP16.");
|
1175 |
+
#endif
|
1176 |
+
} else {
|
1177 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM64
|
1178 |
+
if (params.d_rounded == 64) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 64, Has_softcap>(params, stream); }
|
1179 |
+
#endif
|
1180 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM96
|
1181 |
+
if (params.d_rounded == 96) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 96, Has_softcap>(params, stream); }
|
1182 |
+
#endif
|
1183 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
1184 |
+
if (params.d_rounded == 128) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 128, Has_softcap>(params, stream); }
|
1185 |
+
#endif
|
1186 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
1187 |
+
if (params.d_rounded == 192) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 192, Has_softcap>(params, stream); }
|
1188 |
+
#endif
|
1189 |
+
#ifndef FLASHATTENTION_DISABLE_HDIM256
|
1190 |
+
if (params.d_rounded == 256) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 256, Has_softcap>(params, stream); }
|
1191 |
+
#endif
|
1192 |
+
}
|
1193 |
+
}
|
1194 |
+
|
1195 |
+
void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
|
1196 |
+
// FP16_SWITCH(!params.is_bf16, [&] {
|
1197 |
+
// HEADDIM_SWITCH(params.d, [&] {
|
1198 |
+
// run_mha_bwd_<elem_type, kHeadDim>(params, stream);
|
1199 |
+
// });
|
1200 |
+
// });
|
1201 |
+
ARCH_SWITCH(params.arch, Arch, [&] {
|
1202 |
+
SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
|
1203 |
+
run_mha_bwd_constexpr<Arch, Has_softcap>(params, stream);
|
1204 |
+
});
|
1205 |
+
});
|
1206 |
+
}
|
1207 |
+
#endif
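ARCH_SWITCH / SOFTCAP_SWITCH turn runtime values into template arguments so run_mha_bwd_constexpr can be instantiated per (arch, softcap) combination. A simplified stand-alone illustration of that dispatch pattern (not the actual macros):

#include <cstdio>
#include <type_traits>

template <class F>
void softcap_switch(bool has_softcap, F&& f) {
    if (has_softcap) { f(std::integral_constant<bool, true>{}); }
    else             { f(std::integral_constant<bool, false>{}); }
}

template <bool HasSoftcap>
void run_kernel() { std::printf("compiled with HasSoftcap=%d\n", int(HasSoftcap)); }

int main() {
    float softcap = 30.f;
    softcap_switch(softcap > 0.f, [](auto has_softcap) {
        // The runtime flag has become a compile-time constant inside this lambda.
        run_kernel<decltype(has_softcap)::value>();
    });
}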
|
1208 |
+
|
1209 |
+
|
1210 |
+
// b: batch_size
|
1211 |
+
// s_q: seqlen_q
|
1212 |
+
// s_k: seqlen_k
|
1213 |
+
// h: num_heads
|
1214 |
+
// h_k: num_heads_k
|
1215 |
+
// d: head_size
|
1216 |
+
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
|
1217 |
+
at::Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
|
1218 |
+
at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
|
1219 |
+
at::Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
|
1220 |
+
at::Tensor v, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k
|
1221 |
+
at::Tensor out, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
|
1222 |
+
at::Tensor softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q
|
1223 |
+
std::optional<at::Tensor> dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
|
1224 |
+
std::optional<at::Tensor> dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
|
1225 |
+
std::optional<at::Tensor> dv_, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k
|
1226 |
+
std::optional<at::Tensor> cu_seqlens_q_, // b+1
|
1227 |
+
std::optional<at::Tensor> cu_seqlens_k_, // b+1
|
1228 |
+
std::optional<at::Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
|
1229 |
+
std::optional<at::Tensor> seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
|
1230 |
+
std::optional<int64_t> max_seqlen_q_,
|
1231 |
+
std::optional<int64_t> max_seqlen_k_,
|
1232 |
+
std::optional<double> softmax_scale_,
|
1233 |
+
bool is_causal,
|
1234 |
+
int64_t window_size_left,
|
1235 |
+
int64_t window_size_right,
|
1236 |
+
double softcap,
|
1237 |
+
bool deterministic,
|
1238 |
+
int64_t sm_margin
|
1239 |
+
) {
|
1240 |
+
|
1241 |
+
#ifdef FLASHATTENTION_DISABLE_BACKWARD
|
1242 |
+
TORCH_CHECK(false, "This flash attention build does not support backward.");
|
1243 |
+
#endif
|
1244 |
+
|
1245 |
+
auto dprops = at::cuda::getCurrentDeviceProperties();
|
1246 |
+
bool is_sm8x = dprops->major >= 8;
|
1247 |
+
TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
|
1248 |
+
|
1249 |
+
auto q_type = q.dtype();
|
1250 |
+
TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16,
|
1251 |
+
"FlashAttention only support fp16 and bf16 data type");
|
1252 |
+
TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype");
|
1253 |
+
TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype");
|
1254 |
+
TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype");
|
1255 |
+
TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype");
|
1256 |
+
|
1257 |
+
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
|
1258 |
+
CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
|
1259 |
+
|
1260 |
+
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
1261 |
+
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
1262 |
+
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
1263 |
+
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
|
1264 |
+
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
|
1265 |
+
|
1266 |
+
at::Tensor cu_seqlens_q;
|
1267 |
+
bool const is_varlen_q = cu_seqlens_q_.has_value();
|
1268 |
+
if (is_varlen_q) {
|
1269 |
+
cu_seqlens_q = cu_seqlens_q_.value();
|
1270 |
+
CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
|
1271 |
+
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
|
1272 |
+
TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
|
1273 |
+
}
|
1274 |
+
at::Tensor cu_seqlens_k;
|
1275 |
+
bool const is_varlen_k = cu_seqlens_k_.has_value();
|
1276 |
+
if (is_varlen_k) {
|
1277 |
+
cu_seqlens_k = cu_seqlens_k_.value();
|
1278 |
+
CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
|
1279 |
+
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
|
1280 |
+
TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
|
1281 |
+
}
|
1282 |
+
// This is what we will template on
|
1283 |
+
bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value();
|
1284 |
+
#ifdef FLASHATTENTION_DISABLE_VARLEN
|
1285 |
+
TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
|
1286 |
+
#endif
|
1287 |
+
|
1288 |
+
auto const sizes = q.sizes();
|
1289 |
+
int const batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
|
1290 |
+
int const seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
|
1291 |
+
int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
|
1292 |
+
int const num_heads = q.size(-2);
|
1293 |
+
int const head_size = q.size(-1);
|
1294 |
+
int const head_size_v = v.size(-1);
|
1295 |
+
int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value();
|
1296 |
+
int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
|
1297 |
+
int const num_heads_k = k.size(-2);
|
1298 |
+
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
|
1299 |
+
TORCH_CHECK(head_size_v % 8 == 0, "head_size_v should be a multiple of 8");
|
1300 |
+
int const max_headdim = get_max_headdim();
|
1301 |
+
TORCH_CHECK(std::max(head_size, head_size_v) <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
|
1302 |
+
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
1303 |
+
double softmax_scale = 1.0 / sqrt(double(head_size));
|
1304 |
+
if (softmax_scale_.has_value()) {
|
1305 |
+
softmax_scale = softmax_scale_.value();
|
1306 |
+
}
|
1307 |
+
|
1308 |
+
// This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
|
1309 |
+
if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
|
1310 |
+
if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
|
1311 |
+
if (is_causal) { window_size_right = 0; }
|
1312 |
+
// There's a case where is_causal=false, window_size=(-1, 0). Then set_params_bprop will set params.is_causal=true.
|
1313 |
+
// If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA).
|
1314 |
+
is_causal = window_size_left < 0 && window_size_right == 0;
|
1315 |
+
|
1316 |
+
int const arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
|
1317 |
+
int const head_size_rounded = round_up_headdim(std::max(head_size, head_size_v));
|
1318 |
+
int const head_size_v_rounded = head_size_rounded;
|
1319 |
+
// Very important that these match the kernel configs
|
1320 |
+
bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal;
|
1321 |
+
int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128)
|
1322 |
+
: (head_size_rounded <= 96 ? 64
|
1323 |
+
: (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80)
|
1324 |
+
: 64));
|
1325 |
+
int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64;
|
1326 |
+
int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32;
|
1327 |
+
int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);
|
1328 |
+
int const kBlockN_sm90 = head_size_rounded <= 128
|
1329 |
+
? 128
|
1330 |
+
: (head_size_rounded <= 192 ? 96 : 80);
|
1331 |
+
int const kBlockN_sm80 = head_size_rounded <= 128
|
1332 |
+
? 128
|
1333 |
+
: (head_size_rounded <= 192 ? 80 : 64);
|
1334 |
+
int const kBlockN_sm86 = head_size_rounded <= 64 ? 128
|
1335 |
+
: (head_size_rounded <= 96 ? 128
|
1336 |
+
: (head_size_rounded <= 128 ? 96
|
1337 |
+
: (head_size_rounded <= 192 ? 64 : 64)));
|
1338 |
+
int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80);
|
1339 |
+
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
1340 |
+
int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
|
1341 |
+
int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN);
|
1342 |
+
int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM);
|
1343 |
+
int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN);
|
1344 |
+
|
1345 |
+
if (!is_varlen_q) {
|
1346 |
+
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
|
1347 |
+
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v);
|
1348 |
+
CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_v);
|
1349 |
+
} else {
|
1350 |
+
CHECK_SHAPE(q, total_q, num_heads, head_size);
|
1351 |
+
CHECK_SHAPE(out, total_q, num_heads, head_size_v);
|
1352 |
+
CHECK_SHAPE(dout, total_q, num_heads, head_size_v);
|
1353 |
+
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
|
1354 |
+
}
|
1355 |
+
if (!is_varlen_k) {
|
1356 |
+
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
|
1357 |
+
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_v);
|
1358 |
+
} else {
|
1359 |
+
CHECK_SHAPE(k, total_k, num_heads_k, head_size);
|
1360 |
+
CHECK_SHAPE(v, total_k, num_heads_k, head_size_v);
|
1361 |
+
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
|
1362 |
+
}
|
1363 |
+
|
1364 |
+
if (seqused_q_.has_value()){
|
1365 |
+
auto seqused_q = seqused_q_.value();
|
1366 |
+
TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
|
1367 |
+
CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
|
1368 |
+
CHECK_SHAPE(seqused_q, batch_size);
|
1369 |
+
}
|
1370 |
+
if (seqused_k_.has_value()){
|
1371 |
+
auto seqused_k = seqused_k_.value();
|
1372 |
+
TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
|
1373 |
+
CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
|
1374 |
+
CHECK_SHAPE(seqused_k, batch_size);
|
1375 |
+
}
|
1376 |
+
|
1377 |
+
at::Tensor dq, dk, dv;
|
1378 |
+
if (dq_.has_value()) {
|
1379 |
+
dq = dq_.value();
|
1380 |
+
TORCH_CHECK(dq.dtype() == q_type, "dq must have the same dtype as q");
|
1381 |
+
CHECK_DEVICE(dq);
|
1382 |
+
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
|
1383 |
+
if (!is_varlen_q) {
|
1384 |
+
CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
|
1385 |
+
} else {
|
1386 |
+
CHECK_SHAPE(dq, total_q, num_heads, head_size);
|
1387 |
+
}
|
1388 |
+
} else {
|
1389 |
+
dq = torch::empty_like(q);
|
1390 |
+
}
|
1391 |
+
if (dk_.has_value()) {
|
1392 |
+
dk = dk_.value();
|
1393 |
+
TORCH_CHECK(dk.dtype() == q_type, "dk must have the same dtype as q");
|
1394 |
+
CHECK_DEVICE(dk);
|
1395 |
+
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
|
1396 |
+
if (!is_varlen_k) {
|
1397 |
+
CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
|
1398 |
+
} else {
|
1399 |
+
CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
|
1400 |
+
}
|
1401 |
+
} else {
|
1402 |
+
dk = torch::empty_like(k);
|
1403 |
+
}
|
1404 |
+
if (dv_.has_value()) {
|
1405 |
+
dv = dv_.value();
|
1406 |
+
TORCH_CHECK(dv.dtype() == q_type, "dv must have the same dtype as q");
|
1407 |
+
CHECK_DEVICE(dv);
|
1408 |
+
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
|
1409 |
+
if (!is_varlen_k) {
|
1410 |
+
CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_v);
|
1411 |
+
} else {
|
1412 |
+
CHECK_SHAPE(dv, total_k, num_heads_k, head_size_v);
|
1413 |
+
}
|
1414 |
+
} else {
|
1415 |
+
dv = torch::empty_like(v);
|
1416 |
+
}
|
1417 |
+
|
1418 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
1419 |
+
// Cast to char to avoid compiler warning about narrowing
|
1420 |
+
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
1421 |
+
|
1422 |
+
auto opts = q.options();
|
1423 |
+
// Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
|
1424 |
+
at::Tensor softmax_d, softmax_lse_log2;
|
1425 |
+
if (!is_varlen) {
|
1426 |
+
// Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
|
1427 |
+
softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
|
1428 |
+
softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
|
1429 |
+
} else {
|
1430 |
+
softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
|
1431 |
+
softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
|
1432 |
+
}
|
1433 |
+
at::Tensor dq_accum, dk_accum, dv_accum;
|
1434 |
+
if (!is_varlen) {
|
1435 |
+
dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, opts.dtype(at::kFloat));
|
1436 |
+
} else {
|
1437 |
+
dq_accum = torch::empty({num_heads, total_q_padded_rounded * head_size_rounded}, opts.dtype(at::kFloat));
|
1438 |
+
}
|
1439 |
+
if (num_heads_k != num_heads) { // MQA / GQA
|
1440 |
+
if (!is_varlen) {
|
1441 |
+
dk_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat));
|
1442 |
+
dv_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_v_rounded}, opts.dtype(at::kFloat));
|
1443 |
+
} else {
|
1444 |
+
dk_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
|
1445 |
+
dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_v_rounded}, opts.dtype(at::kFloat));
|
1446 |
+
}
|
1447 |
+
}
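The FP32 dk_accum/dv_accum buffers exist because, under MQA/GQA, every query head in a group writes its K/V gradient into the same shared KV head, so the partial sums are accumulated in FP32 before conversion back to the input dtype. A toy sketch of that accumulation (illustrative only; one scalar stands in for a dK tile):

#include <cstdio>
#include <vector>

int main() {
    int num_heads = 8, num_heads_k = 2;
    int group = num_heads / num_heads_k;             // 4 query heads per KV head
    std::vector<float> dk_accum(num_heads_k, 0.f);
    for (int h_q = 0; h_q < num_heads; ++h_q) {
        float partial_dk = 0.1f * (h_q + 1);         // pretend per-query-head contribution
        dk_accum[h_q / group] += partial_dk;         // FP32 accumulation into the shared KV head
    }
    std::printf("dk_accum = {%.1f, %.1f}\n", dk_accum[0], dk_accum[1]);  // {1.0, 2.6}
}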
|
1448 |
+
|
1449 |
+
Flash_bwd_params params;
|
1450 |
+
set_params_dgrad(params,
|
1451 |
+
batch_size,
|
1452 |
+
seqlen_q, seqlen_k,
|
1453 |
+
seqlen_q_rounded, seqlen_k_rounded,
|
1454 |
+
num_heads, num_heads_k,
|
1455 |
+
head_size, head_size_rounded,
|
1456 |
+
q, k, v, out,
|
1457 |
+
dout, dq, dk, dv,
|
1458 |
+
!is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
|
1459 |
+
!is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
|
1460 |
+
seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
|
1461 |
+
seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
|
1462 |
+
dq_accum.data_ptr(),
|
1463 |
+
num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr,
|
1464 |
+
num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr,
|
1465 |
+
softmax_lse.data_ptr(),
|
1466 |
+
softmax_d.data_ptr(),
|
1467 |
+
/*p_dropout=*/0.f,
|
1468 |
+
softmax_scale,
|
1469 |
+
window_size_left,
|
1470 |
+
window_size_right,
|
1471 |
+
0, // attention_chunk
|
1472 |
+
softcap,
|
1473 |
+
deterministic,
|
1474 |
+
sm_margin);
|
1475 |
+
params.total_q = total_q;
|
1476 |
+
params.total_k = total_k;
|
1477 |
+
params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
|
1478 |
+
params.dv = head_size_v;
|
1479 |
+
params.dv_rounded = head_size_v_rounded;
|
1480 |
+
|
1481 |
+
// auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
|
1482 |
+
// params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
|
1483 |
+
// Will be zero'ed out in the backward preprocess kernel
|
1484 |
+
at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
|
1485 |
+
params.dq_semaphore = dq_semaphore.data_ptr<int>();
|
1486 |
+
if (num_heads_k != num_heads && params.deterministic) {
|
1487 |
+
// TODO: do we need to zero them out?
|
1488 |
+
at::Tensor dk_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
|
1489 |
+
at::Tensor dv_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
|
1490 |
+
params.dk_semaphore = dk_semaphore.data_ptr<int>();
|
1491 |
+
params.dv_semaphore = dv_semaphore.data_ptr<int>();
|
1492 |
+
}
|
1493 |
+
|
1494 |
+
#ifdef FLASHATTENTION_DISABLE_LOCAL
|
1495 |
+
TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
|
1496 |
+
#endif
|
1497 |
+
#ifdef FLASHATTENTION_DISABLE_SOFTCAP
|
1498 |
+
TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
|
1499 |
+
#endif
|
1500 |
+
|
1501 |
+
if (total_q > 0 && total_k > 0 && num_heads_k > 0) {
|
1502 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
1503 |
+
run_mha_bwd(params, stream);
|
1504 |
+
} else if (total_k > 0 && num_heads_k > 0) {
|
1505 |
+
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
|
1506 |
+
dk.zero_();
|
1507 |
+
dv.zero_();
|
1508 |
+
softmax_d.zero_();
|
1509 |
+
} else if (total_q > 0 && num_heads_k > 0) {
|
1510 |
+
dq.zero_();
|
1511 |
+
softmax_d.zero_();
|
1512 |
+
}
|
1513 |
+
|
1514 |
+
return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
|
1515 |
+
}
|
1516 |
+
|
1517 |
+
std::tuple<at::Tensor, at::Tensor>
|
1518 |
+
mha_combine(at::Tensor out_partial, // num_splits x batch_size x seqlen x num_heads x head_size
|
1519 |
+
at::Tensor lse_partial, // num_splits x batch_size x seqlen x num_heads
|
1520 |
+
std::optional<at::Tensor> out_, // batch_size x seqlen x num_heads x head_size
|
1521 |
+
std::optional<at::ScalarType> out_dtype_
|
1522 |
+
) {
|
1523 |
+
|
1524 |
+
auto dprops = at::cuda::getCurrentDeviceProperties();
|
1525 |
+
bool is_sm8x = dprops->major >= 8;
|
1526 |
+
TORCH_CHECK(is_sm8x, "Attention combine function only supports Ampere GPUs or newer.");
|
1527 |
+
|
1528 |
+
auto out_partial_type = out_partial.scalar_type();
|
1529 |
+
TORCH_CHECK(out_partial_type == at::ScalarType::Float, "Attention combine function only support fp32 data type");
|
1530 |
+
TORCH_CHECK(lse_partial.scalar_type() == at::ScalarType::Float, "Attention combine function only support fp32 data type");
|
1531 |
+
|
1532 |
+
CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial);
|
1533 |
+
|
1534 |
+
TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
1535 |
+
TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor must be contiguous in the seqlen dimension");
|
1536 |
+
|
1537 |
+
const auto sizes = out_partial.sizes();
|
1538 |
+
|
1539 |
+
const int num_splits = sizes[0];
|
1540 |
+
const int batch_size = sizes[1];
|
1541 |
+
const int seqlen = sizes[2];
|
1542 |
+
const int num_heads = sizes[3];
|
1543 |
+
const int head_size_og = sizes[4];
|
1544 |
+
TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256");
|
1545 |
+
|
1546 |
+
CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og);
|
1547 |
+
CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads);
|
1548 |
+
|
1549 |
+
int const alignment = 4;
|
1550 |
+
at::Tensor out_partial_padded;
|
1551 |
+
auto pad = [](at::Tensor x, int alignment) {
|
1552 |
+
return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment}));
|
1553 |
+
};
|
1554 |
+
out_partial_padded = pad(out_partial, alignment);
|
1555 |
+
|
1556 |
+
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
1557 |
+
const int head_size = round_multiple(head_size_og, alignment);
|
1558 |
+
|
1559 |
+
auto opts = out_partial.options();
|
1560 |
+
at::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type());
|
1561 |
+
TORCH_CHECK(out_type == at::ScalarType::Float || out_type == at::ScalarType::BFloat16 || out_type == at::ScalarType::Half, "Output type must be FP32, FP16 or BF16");
|
1562 |
+
at::Tensor out;
|
1563 |
+
if (out_.has_value()) {
|
1564 |
+
out = out_.value();
|
1565 |
+
TORCH_CHECK(out.scalar_type() == out_type);
|
1566 |
+
CHECK_DEVICE(out);
|
1567 |
+
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
1568 |
+
CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og);
|
1569 |
+
if (head_size_og % alignment != 0) {
|
1570 |
+
out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
|
1571 |
+
}
|
1572 |
+
} else {
|
1573 |
+
out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
|
1574 |
+
}
|
1575 |
+
|
1576 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
1577 |
+
// Cast to char to avoid compiler warning about narrowing
|
1578 |
+
at::cuda::CUDAGuard device_guard{(char)out_partial.get_device()};
|
1579 |
+
|
1580 |
+
auto softmax_lse = torch::empty({batch_size, num_heads, seqlen}, opts.dtype(at::kFloat)).transpose(1, 2);
|
1581 |
+
|
1582 |
+
Flash_fwd_params params {}; // Need to reset the params to set everything to zero
|
1583 |
+
params.is_fp32 = out_type == at::ScalarType::Float;
|
1584 |
+
params.is_bf16 = out_type == at::ScalarType::BFloat16;
|
1585 |
+
params.oaccum_ptr = out_partial_padded.data_ptr();
|
1586 |
+
params.softmax_lseaccum_ptr = lse_partial.data_ptr();
|
1587 |
+
params.o_ptr = out.data_ptr();
|
1588 |
+
params.softmax_lse_ptr = softmax_lse.data_ptr();
|
1589 |
+
params.b = batch_size;
|
1590 |
+
params.h = num_heads;
|
1591 |
+
params.seqlen_q = seqlen;
|
1592 |
+
params.dv = head_size;
|
1593 |
+
params.num_splits = num_splits;
|
1594 |
+
params.oaccum_split_stride = out_partial_padded.stride(0);
|
1595 |
+
params.oaccum_row_stride = out_partial_padded.stride(2);
|
1596 |
+
params.oaccum_head_stride = out_partial_padded.stride(3);
|
1597 |
+
params.oaccum_batch_stride = out_partial_padded.stride(1);
|
1598 |
+
params.lseaccum_split_stride = lse_partial.stride(0);
|
1599 |
+
params.lseaccum_head_stride = lse_partial.stride(3);
|
1600 |
+
params.lseaccum_batch_stride = lse_partial.stride(1);
|
1601 |
+
params.o_row_stride = out.stride(1);
|
1602 |
+
params.o_head_stride = out.stride(2);
|
1603 |
+
params.o_batch_stride = out.stride(0);
|
1604 |
+
params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
|
1605 |
+
|
1606 |
+
if (seqlen > 0 && batch_size > 0) {
|
1607 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
1608 |
+
run_mha_fwd_combine(params, stream, false /*enable_pdl*/);
|
1609 |
+
}
|
1610 |
+
|
1611 |
+
at::Tensor out_padded = out;
|
1612 |
+
if (head_size_og % alignment != 0) {
|
1613 |
+
out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
|
1614 |
+
// if (out_.has_value()) { out_.value().copy_(out); }
|
1615 |
+
}
|
1616 |
+
|
1617 |
+
return {out, softmax_lse};
|
1618 |
+
}
|
1619 |
+
|
1620 |
+
#ifdef false
|
1621 |
+
|
1622 |
+
TORCH_LIBRARY(flash_attn_3, m) {
|
1623 |
+
m.def("fwd("
|
1624 |
+
"Tensor q,"
|
1625 |
+
"Tensor k,"
|
1626 |
+
"Tensor v,"
|
1627 |
+
"Tensor(k_new!)? k_new = None,"
|
1628 |
+
"Tensor(v_new!)? v_new = None,"
|
1629 |
+
"Tensor? q_v = None,"
|
1630 |
+
"Tensor(out!)? out = None,"
|
1631 |
+
"Tensor? cu_seqlens_q = None,"
|
1632 |
+
"Tensor? cu_seqlens_k = None,"
|
1633 |
+
"Tensor? cu_seqlens_k_new = None,"
|
1634 |
+
"Tensor? seqused_q = None,"
|
1635 |
+
"Tensor? seqused_k = None,"
|
1636 |
+
"int? max_seqlen_q = None,"
|
1637 |
+
"int? max_seqlen_k = None,"
|
1638 |
+
"Tensor? page_table = None,"
|
1639 |
+
"Tensor? kv_batch_idx = None,"
|
1640 |
+
"Tensor? leftpad_k = None,"
|
1641 |
+
"Tensor? rotary_cos = None,"
|
1642 |
+
"Tensor? rotary_sin = None,"
|
1643 |
+
"Tensor? seqlens_rotary = None,"
|
1644 |
+
"Tensor? q_descale = None,"
|
1645 |
+
"Tensor? k_descale = None,"
|
1646 |
+
"Tensor? v_descale = None,"
|
1647 |
+
"float? softmax_scale = None,"
|
1648 |
+
"bool is_causal = False,"
|
1649 |
+
"int window_size_left = -1,"
|
1650 |
+
"int window_size_right = -1,"
|
1651 |
+
"int attention_chunk = 0,"
|
1652 |
+
"float softcap = 0.0,"
|
1653 |
+
"bool is_rotary_interleaved = False,"
|
1654 |
+
"Tensor? scheduler_metadata = None,"
|
1655 |
+
"int num_splits = 0,"
|
1656 |
+
"bool? pack_gqa = None,"
|
1657 |
+
"int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)");
|
1658 |
+
m.def("bwd("
|
1659 |
+
"Tensor dout,"
|
1660 |
+
"Tensor q,"
|
1661 |
+
"Tensor k,"
|
1662 |
+
"Tensor v,"
|
1663 |
+
"Tensor out,"
|
1664 |
+
"Tensor softmax_lse,"
|
1665 |
+
"Tensor(dq!)? dq = None,"
|
1666 |
+
"Tensor(dk!)? dk = None,"
|
1667 |
+
"Tensor(dv!)? dv = None,"
|
1668 |
+
"Tensor? cu_seqlens_q = None,"
|
1669 |
+
"Tensor? cu_seqlens_k = None,"
|
1670 |
+
"Tensor? seqused_q = None,"
|
1671 |
+
"Tensor? seqused_k = None,"
|
1672 |
+
"int? max_seqlen_q = None,"
|
1673 |
+
"int? max_seqlen_k = None,"
|
1674 |
+
"float? softmax_scale = None,"
|
1675 |
+
"bool is_causal = False,"
|
1676 |
+
"int window_size_left = -1,"
|
1677 |
+
"int window_size_right = -1,"
|
1678 |
+
"float softcap = 0.0,"
|
1679 |
+
"bool deterministic = False,"
|
1680 |
+
"int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)");
|
1681 |
+
m.def("fwd_combine("
|
1682 |
+
"Tensor out_partial,"
|
1683 |
+
"Tensor lse_partial,"
|
1684 |
+
"Tensor(out!)? out = None,"
|
1685 |
+
"ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)");
|
1686 |
+
m.def("get_scheduler_metadata("
|
1687 |
+
"int batch_size,"
|
1688 |
+
"int max_seqlen_q,"
|
1689 |
+
"int max_seqlen_k,"
|
1690 |
+
"int num_heads,"
|
1691 |
+
"int num_heads_k,"
|
1692 |
+
"int headdim,"
|
1693 |
+
"int headdim_v,"
|
1694 |
+
"ScalarType qkv_dtype,"
|
1695 |
+
"Tensor seqused_k,"
|
1696 |
+
"Tensor? cu_seqlens_q = None,"
|
1697 |
+
"Tensor? cu_seqlens_k = None,"
|
1698 |
+
"Tensor? cu_seqlens_k_new = None,"
|
1699 |
+
"Tensor? seqused_q = None,"
|
1700 |
+
"Tensor? leftpad_k = None,"
|
1701 |
+
"int? page_size = None,"
|
1702 |
+
"int max_seqlen_k_new = 0,"
|
1703 |
+
"bool is_causal = False,"
|
1704 |
+
"int window_size_left = -1,"
|
1705 |
+
"int window_size_right = -1,"
|
1706 |
+
"int attention_chunk = 0,"
|
1707 |
+
"bool has_softcap = False,"
|
1708 |
+
"int num_splits = 0,"
|
1709 |
+
"bool? pack_gqa = None,"
|
1710 |
+
"int sm_margin = 0) -> Tensor");
|
1711 |
+
}
|
1712 |
+
|
1713 |
+
TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) {
|
1714 |
+
m.impl("fwd", &mha_fwd);
|
1715 |
+
m.impl("bwd", &mha_bwd);
|
1716 |
+
m.impl("fwd_combine", &mha_combine);
|
1717 |
+
m.impl("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata);
|
1718 |
+
}
|
1719 |
+
|
1720 |
+
#endif
|
flash-attn/flash_bwd_kernel_sm80.h
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2024, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
#include "cute/tensor.hpp"
|
8 |
+
|
9 |
+
#include <cutlass/cutlass.h>
|
10 |
+
#include <cutlass/array.h>
|
11 |
+
#include <cutlass/numeric_types.h>
|
12 |
+
#include <cutlass/kernel_hardware_info.h>
|
13 |
+
|
14 |
+
#include "utils.h"
|
15 |
+
|
16 |
+
namespace flash {
|
17 |
+
|
18 |
+
using namespace cute;
|
19 |
+
|
20 |
+
template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
|
21 |
+
class FlashAttnBwdSm80 {
|
22 |
+
|
23 |
+
public:
|
24 |
+
|
25 |
+
// Type Aliases
|
26 |
+
static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
|
27 |
+
static constexpr bool Is_local = CollectiveMainloop_::Is_local;
|
28 |
+
static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
|
29 |
+
static constexpr bool Varlen = CollectiveMainloop_::Varlen;
|
30 |
+
|
31 |
+
// Mainloop derived types
|
32 |
+
using CollectiveMainloop = CollectiveMainloop_;
|
33 |
+
using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
|
34 |
+
using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;
|
35 |
+
using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;
|
36 |
+
using ArchTag = typename CollectiveMainloop::ArchTag;
|
37 |
+
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
38 |
+
using MainloopParams = typename CollectiveMainloop::Params;
|
39 |
+
static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;
|
40 |
+
|
41 |
+
// Epilogue derived types
|
42 |
+
using CollectiveEpilogue = CollectiveEpilogue_;
|
43 |
+
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
44 |
+
using EpilogueParams = typename CollectiveEpilogue::Params;
|
45 |
+
|
46 |
+
static_assert(ArchTag::kMinComputeCapability >= 80);
|
47 |
+
|
48 |
+
using TileScheduler = TileScheduler_;
|
49 |
+
using TileSchedulerArguments = typename flash::TileSchedulerArguments;
|
50 |
+
using TileSchedulerParams = typename TileScheduler::Params;
|
51 |
+
|
52 |
+
static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMmaSdP{}));
|
53 |
+
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{}));
|
54 |
+
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
55 |
+
|
56 |
+
// Kernel level shared memory storage
|
57 |
+
struct SharedStorage {
|
58 |
+
struct TensorStorage : cute::aligned_struct<128> {
|
59 |
+
union {
|
60 |
+
typename CollectiveMainloop::TensorStorage mainloop;
|
61 |
+
typename CollectiveEpilogue::TensorStorage epilogue;
|
62 |
+
};
|
63 |
+
} tensors;
|
64 |
+
|
65 |
+
alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
|
66 |
+
|
67 |
+
};
|
68 |
+
|
69 |
+
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
70 |
+
|
71 |
+
// Device side arguments
|
72 |
+
struct Arguments {
|
73 |
+
MainloopArguments mainloop{};
|
74 |
+
EpilogueArguments epilogue{};
|
75 |
+
cutlass::KernelHardwareInfo hw_info{};
|
76 |
+
TileSchedulerArguments scheduler{};
|
77 |
+
};
|
78 |
+
|
79 |
+
// Kernel entry point API
|
80 |
+
struct Params {
|
81 |
+
MainloopParams mainloop{};
|
82 |
+
EpilogueParams epilogue{};
|
83 |
+
cutlass::KernelHardwareInfo hw_info{};
|
84 |
+
TileSchedulerParams scheduler{};
|
85 |
+
};
|
86 |
+
|
87 |
+
//
|
88 |
+
// Methods
|
89 |
+
//
|
90 |
+
|
91 |
+
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
92 |
+
static
|
93 |
+
Params
|
94 |
+
to_underlying_arguments(Arguments const& args) {
|
95 |
+
CUTLASS_TRACE_HOST("to_underlying_arguments():");
|
96 |
+
|
97 |
+
// Get SM count if needed, otherwise use user supplied SM count
|
98 |
+
int sm_count = args.hw_info.sm_count;
|
99 |
+
if (sm_count <= 0) {
|
100 |
+
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
101 |
+
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
102 |
+
sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
|
103 |
+
}
|
104 |
+
|
105 |
+
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
106 |
+
|
107 |
+
cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
|
108 |
+
return {
|
109 |
+
CollectiveMainloop::to_underlying_arguments(args.mainloop),
|
110 |
+
CollectiveEpilogue::to_underlying_arguments(args.epilogue),
|
111 |
+
hw_info,
|
112 |
+
TileScheduler::to_underlying_arguments(args.scheduler)
|
113 |
+
};
|
114 |
+
}
|
115 |
+
|
116 |
+
// Computes the kernel launch grid shape based on runtime parameters
|
117 |
+
static dim3
|
118 |
+
get_grid_shape(Params const& params) {
|
119 |
+
return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
|
120 |
+
}
|
121 |
+
|
122 |
+
static dim3
|
123 |
+
get_block_shape() {
|
124 |
+
return dim3(MaxThreadsPerBlock, 1, 1);
|
125 |
+
}
|
126 |
+
|
127 |
+
CUTLASS_DEVICE
|
128 |
+
void
|
129 |
+
operator()(Params const& params, char* smem_buf) {
|
130 |
+
|
131 |
+
static constexpr int kBlockM = get<0>(TileShape_MNK{});
|
132 |
+
static constexpr int kBlockN = get<1>(TileShape_MNK{});
|
133 |
+
|
134 |
+
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
135 |
+
|
136 |
+
CollectiveMainloop mainloop;
|
137 |
+
CollectiveEpilogue epilogue;
|
138 |
+
|
139 |
+
TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));
|
140 |
+
// Initialize matmul objects.
|
141 |
+
TiledMmadKV tiled_mma_dKV;
|
142 |
+
|
143 |
+
scheduler.init_consumer();
|
144 |
+
|
145 |
+
int warp_idx = cutlass::canonical_warp_idx_sync();
|
146 |
+
CUTLASS_PRAGMA_NO_UNROLL
|
147 |
+
for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
|
148 |
+
work_tile_info.is_valid(params.scheduler);
|
149 |
+
work_tile_info = warp_idx == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
|
150 |
+
|
151 |
+
auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
|
152 |
+
auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
|
153 |
+
cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
|
154 |
+
|
155 |
+
// dK and dV output accumulator.
|
156 |
+
Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
|
157 |
+
Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
|
158 |
+
bool tile_valid = mainloop.mma(params.mainloop, tdKrdK, tdVrdV, threadIdx.x,
|
159 |
+
block_coord, shared_storage);
|
160 |
+
scheduler.prefetch_next_work(params.scheduler, work_tile_info);
|
161 |
+
if (tile_valid) {
|
162 |
+
epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,
|
163 |
+
threadIdx.x, block_coord);
|
164 |
+
} else {
|
165 |
+
epilogue.store_zero(params.epilogue, threadIdx.x, block_coord);
|
166 |
+
}
|
167 |
+
}
|
168 |
+
|
169 |
+
}
|
170 |
+
|
171 |
+
};
|
172 |
+
|
173 |
+
} // namespace flash
|
flash-attn/flash_bwd_kernel_sm90.h
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
/******************************************************************************
|
3 |
+
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
4 |
+
******************************************************************************/
|
5 |
+
|
6 |
+
#pragma once
|
7 |
+
|
8 |
+
#include "cute/tensor.hpp"
|
9 |
+
|
10 |
+
#include <cutlass/cutlass.h>
|
11 |
+
#include <cutlass/arch/reg_reconfig.h>
|
12 |
+
#include <cutlass/array.h>
|
13 |
+
#include <cutlass/numeric_types.h>
|
14 |
+
#include <cutlass/numeric_conversion.h>
|
15 |
+
#include <cutlass/kernel_hardware_info.h>
|
16 |
+
#include "cutlass/pipeline/pipeline.hpp"
|
17 |
+
|
18 |
+
#include "utils.h"
|
19 |
+
|
20 |
+
namespace flash {
|
21 |
+
|
22 |
+
using namespace cute;
|
23 |
+
|
24 |
+
template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
|
25 |
+
class FlashAttnBwdSm90 {
|
26 |
+
|
27 |
+
public:
|
28 |
+
|
29 |
+
// Type Aliases
|
30 |
+
static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
|
31 |
+
static constexpr bool Is_local = CollectiveMainloop_::Is_local;
|
32 |
+
static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
|
33 |
+
static constexpr bool Varlen = CollectiveMainloop_::Varlen;
|
34 |
+
|
35 |
+
// Mainloop derived types
|
36 |
+
using CollectiveMainloop = CollectiveMainloop_;
|
37 |
+
using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
|
38 |
+
using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;
|
39 |
+
using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;
|
40 |
+
using ArchTag = typename CollectiveMainloop::ArchTag;
|
41 |
+
using ClusterShape = typename CollectiveMainloop::ClusterShape;
|
42 |
+
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
43 |
+
using MainloopParams = typename CollectiveMainloop::Params;
|
44 |
+
static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;
|
45 |
+
|
46 |
+
// Epilogue derived types
|
47 |
+
using CollectiveEpilogue = CollectiveEpilogue_;
|
48 |
+
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
49 |
+
using EpilogueParams = typename CollectiveEpilogue::Params;
|
50 |
+
|
51 |
+
static_assert(ArchTag::kMinComputeCapability >= 90);
|
52 |
+
|
53 |
+
using TileScheduler = TileScheduler_;
|
54 |
+
using TileSchedulerArguments = typename flash::TileSchedulerArguments;
|
55 |
+
using TileSchedulerParams = typename TileScheduler::Params;
|
56 |
+
|
57 |
+
static constexpr uint32_t NumLoadWarpGroups = 1;
|
58 |
+
static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaSdP{})) / cutlass::NumThreadsPerWarpGroup;
|
59 |
+
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
|
60 |
+
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
61 |
+
static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
|
62 |
+
|
63 |
+
/// Register requirement for Load and Math WGs
|
64 |
+
static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 2 ? 24 : 32;
|
65 |
+
static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 240 : 160;
|
66 |
+
// If you want to print from the producer warp, you'd need to increase the number of registers
|
67 |
+
// Otherwise you'll get CUDA error.
|
68 |
+
// static constexpr uint32_t LoadRegisterRequirement = 40;
|
69 |
+
// static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;
|
70 |
+
|
71 |
+
// Kernel level shared memory storage
|
72 |
+
struct SharedStorage {
|
73 |
+
struct TensorStorage : cute::aligned_struct<128> {
|
74 |
+
union {
|
75 |
+
typename CollectiveMainloop::TensorStorage mainloop;
|
76 |
+
typename CollectiveEpilogue::TensorStorage epilogue;
|
77 |
+
};
|
78 |
+
} tensors;
|
79 |
+
|
80 |
+
struct PipelineStorage : cute::aligned_struct<16> {
|
81 |
+
alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_KV;
|
82 |
+
alignas(16) typename CollectiveMainloop::MainloopPipeline::SharedStorage pipeline_q;
|
83 |
+
alignas(16) typename CollectiveMainloop::MainloopPipeline_dO::SharedStorage pipeline_do;
|
84 |
+
alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
|
85 |
+
} pipelines;
|
86 |
+
|
87 |
+
};
|
88 |
+
|
89 |
+
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
90 |
+
|
91 |
+
// Device side arguments
|
92 |
+
struct Arguments {
|
93 |
+
MainloopArguments mainloop{};
|
94 |
+
EpilogueArguments epilogue{};
|
95 |
+
cutlass::KernelHardwareInfo hw_info{};
|
96 |
+
TileSchedulerArguments scheduler{};
|
97 |
+
};
|
98 |
+
|
99 |
+
// Kernel entry point API
|
100 |
+
struct Params {
|
101 |
+
MainloopParams mainloop{};
|
102 |
+
EpilogueParams epilogue{};
|
103 |
+
cutlass::KernelHardwareInfo hw_info{};
|
104 |
+
TileSchedulerParams scheduler{};
|
105 |
+
};
|
106 |
+
|
107 |
+
//
|
108 |
+
// Methods
|
109 |
+
//
|
110 |
+
|
111 |
+
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
112 |
+
static
|
113 |
+
Params
|
114 |
+
to_underlying_arguments(Arguments const& args) {
|
115 |
+
CUTLASS_TRACE_HOST("to_underlying_arguments():");
|
116 |
+
|
117 |
+
// Get SM count if needed, otherwise use user supplied SM count
|
118 |
+
int sm_count = args.hw_info.sm_count;
|
119 |
+
if (sm_count <= 0) {
|
120 |
+
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
121 |
+
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
122 |
+
sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
|
123 |
+
}
|
124 |
+
|
125 |
+
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
126 |
+
|
127 |
+
cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
|
128 |
+
return {
|
129 |
+
CollectiveMainloop::to_underlying_arguments(args.mainloop),
|
130 |
+
CollectiveEpilogue::to_underlying_arguments(args.epilogue),
|
131 |
+
hw_info,
|
132 |
+
TileScheduler::to_underlying_arguments(args.scheduler)
|
133 |
+
};
|
134 |
+
}
|
135 |
+
|
136 |
+
// Computes the kernel launch grid shape based on runtime parameters
|
137 |
+
static dim3
|
138 |
+
get_grid_shape(Params const& params) {
|
139 |
+
return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
|
140 |
+
}
|
141 |
+
|
142 |
+
static dim3
|
143 |
+
get_block_shape() {
|
144 |
+
return dim3(MaxThreadsPerBlock, 1, 1);
|
145 |
+
}
|
146 |
+
|
147 |
+
CUTLASS_DEVICE
|
148 |
+
void
|
149 |
+
operator()(Params const& params, char* smem_buf) {
|
150 |
+
|
151 |
+
static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
|
152 |
+
static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
|
153 |
+
static constexpr int kBlockM = get<0>(TileShape_MNK{});
|
154 |
+
static constexpr int kBlockN = get<1>(TileShape_MNK{});
|
155 |
+
|
156 |
+
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
157 |
+
using PipelineParams = typename MainloopPipeline::Params;
|
158 |
+
using PipelineState = typename MainloopPipeline::PipelineState;
|
159 |
+
using MainloopPipeline_dO = typename CollectiveMainloop::MainloopPipeline_dO;
|
160 |
+
using PipelineParams_dO = typename MainloopPipeline_dO::Params;
|
161 |
+
using PipelineState_dO = typename MainloopPipeline_dO::PipelineState;
|
162 |
+
static constexpr bool Q_dO_same_stages = std::is_same_v<MainloopPipeline, MainloopPipeline_dO>;
|
163 |
+
|
164 |
+
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
165 |
+
|
166 |
+
int const lane_predicate = cute::elect_one_sync();
|
167 |
+
int const warp_idx = cutlass::canonical_warp_idx_sync();
|
168 |
+
|
169 |
+
// Issue Tma Descriptor Prefetch from a single thread
|
170 |
+
if (warp_idx == 0 && lane_predicate) {
|
171 |
+
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
172 |
+
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
173 |
+
}
|
174 |
+
|
175 |
+
// Obtain warp index
|
176 |
+
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
177 |
+
|
178 |
+
PipelineParams pipeline_params;
|
179 |
+
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesQ + CollectiveMainloop::TmaTransactionBytesLSE;
|
180 |
+
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
181 |
+
pipeline_params.role = warp_group_idx == 0
|
182 |
+
? MainloopPipeline::ThreadCategory::Producer
|
183 |
+
: MainloopPipeline::ThreadCategory::Consumer;
|
184 |
+
pipeline_params.is_leader = warp_group_thread_idx == 0;
|
185 |
+
pipeline_params.num_consumers = NumMmaThreads;
|
186 |
+
|
187 |
+
if (warp_idx == 0 && lane_predicate) {
|
188 |
+
shared_storage.pipelines.barrier_KV.init(1 /*numThreads*/);
|
189 |
+
}
|
190 |
+
// We're counting on pipeline_q to call cutlass::arch::fence_barrier_init();
|
191 |
+
MainloopPipeline pipeline_q(shared_storage.pipelines.pipeline_q, pipeline_params, ClusterShape{});
|
192 |
+
auto role_dO = warp_group_idx == 0
|
193 |
+
? MainloopPipeline_dO::ThreadCategory::Producer
|
194 |
+
: MainloopPipeline_dO::ThreadCategory::Consumer;
|
195 |
+
PipelineParams_dO pipeline_params_dO {pipeline_params.transaction_bytes, role_dO, pipeline_params.is_leader, pipeline_params.num_consumers};
|
196 |
+
MainloopPipeline_dO pipeline_do(shared_storage.pipelines.pipeline_do, cute::conditional_return<Q_dO_same_stages>(pipeline_params, pipeline_params_dO), ClusterShape{});
|
197 |
+
|
198 |
+
CollectiveMainloop mainloop;
|
199 |
+
CollectiveEpilogue epilogue;
|
200 |
+
|
201 |
+
// We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
|
202 |
+
if constexpr (size(ClusterShape{}) > 1) {
|
203 |
+
cute::cluster_arrive_relaxed();
|
204 |
+
cute::cluster_wait();
|
205 |
+
} else {
|
206 |
+
__syncthreads();
|
207 |
+
}
|
208 |
+
|
209 |
+
TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
|
210 |
+
|
211 |
+
if (warp_group_idx == 0) { // Producer
|
212 |
+
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
|
213 |
+
|
214 |
+
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
|
215 |
+
if (warp_idx_in_warpgroup == 0) { // Load K, V, and do TMA on Q and dO
|
216 |
+
PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
|
217 |
+
PipelineState_dO smem_pipe_write_do = cutlass::make_producer_start_state<MainloopPipeline_dO>();
|
218 |
+
for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler);
|
219 |
+
work_tile_info.is_valid(params.scheduler);
|
220 |
+
work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info)) {
|
221 |
+
auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
|
222 |
+
auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
|
223 |
+
cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
|
224 |
+
auto scheduler_prefetch = [&scheduler, ¶ms, &work_tile_info]() {
|
225 |
+
scheduler.prefetch_next_work(params.scheduler, work_tile_info);
|
226 |
+
};
|
227 |
+
mainloop.load(params.mainloop, pipeline_q, pipeline_do, smem_pipe_write,
|
228 |
+
smem_pipe_write_do, shared_storage, scheduler_prefetch, block_coord);
|
229 |
+
}
|
230 |
+
mainloop.load_tail(pipeline_q, pipeline_do, smem_pipe_write, smem_pipe_write_do);
|
231 |
+
} else if (warp_idx_in_warpgroup == 1) {
|
232 |
+
for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
|
233 |
+
work_tile_info.is_valid(params.scheduler);
|
234 |
+
work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
|
235 |
+
auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
|
236 |
+
auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
|
237 |
+
cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
|
238 |
+
mainloop.store_dq(params.mainloop, shared_storage, block_coord);
|
239 |
+
}
|
240 |
+
}
|
241 |
+
} else { // Consumer
|
242 |
+
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
|
243 |
+
// Initialize matmul objects.
|
244 |
+
TiledMmadKV tiled_mma_dKV;
|
245 |
+
|
246 |
+
PipelineState smem_pipe_read;
|
247 |
+
PipelineState_dO smem_pipe_read_do;
|
248 |
+
|
249 |
+
mainloop.mma_init();
|
250 |
+
scheduler.init_consumer();
|
251 |
+
|
252 |
+
int work_idx = 0;
|
253 |
+
CUTLASS_PRAGMA_NO_UNROLL
|
254 |
+
for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
|
255 |
+
work_tile_info.is_valid(params.scheduler);
|
256 |
+
work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
|
257 |
+
auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
|
258 |
+
auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
|
259 |
+
cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
|
260 |
+
|
261 |
+
// dK and dV output accumulator.
|
262 |
+
Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
|
263 |
+
Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
|
264 |
+
bool tile_valid = mainloop.mma(
|
265 |
+
params.mainloop, pipeline_q, pipeline_do, smem_pipe_read, smem_pipe_read_do,
|
266 |
+
tdKrdK, tdVrdV, threadIdx.x - NumCopyThreads, work_idx, block_coord, shared_storage);
|
267 |
+
if (tile_valid) {
|
268 |
+
epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,
|
269 |
+
threadIdx.x - NumCopyThreads, block_coord);
|
270 |
+
} else {
|
271 |
+
epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord);
|
272 |
+
}
|
273 |
+
|
274 |
+
}
|
275 |
+
epilogue.store_tail();
|
276 |
+
}
|
277 |
+
|
278 |
+
}
|
279 |
+
|
280 |
+
};
|
281 |
+
|
282 |
+
} // namespace flash
|
flash-attn/flash_bwd_launch_template.h
ADDED
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
#include "cute/tensor.hpp"
|
8 |
+
|
9 |
+
#include "cutlass/device_kernel.h" // For device_kernel
|
10 |
+
#include "cutlass/kernel_launch.h" // For kernel_launch
|
11 |
+
#include "cutlass/cluster_launch.hpp" // For ClusterLauncher
|
12 |
+
|
13 |
+
#include "static_switch.h"
|
14 |
+
#include "flash.h"
|
15 |
+
#include "flash_bwd_preprocess_kernel.h"
|
16 |
+
#include "flash_bwd_postprocess_kernel.h"
|
17 |
+
#include "tile_scheduler.hpp"
|
18 |
+
#include "mainloop_bwd_sm90_tma_gmma_ws.hpp"
|
19 |
+
#include "mainloop_bwd_sm80.hpp"
|
20 |
+
#include "epilogue_bwd.hpp"
|
21 |
+
#include "flash_bwd_kernel_sm90.h"
|
22 |
+
#include "flash_bwd_kernel_sm80.h"
|
23 |
+
|
24 |
+
using namespace cute;
|
25 |
+
|
26 |
+
template <int Arch, int kHeadDim, int kBlockM, int kBlockN, typename Element,
|
27 |
+
bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool Deterministic, bool GQA,
|
28 |
+
int Stages_dO=2, int Stages_dS_or_QSm80=2,
|
29 |
+
bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
|
30 |
+
int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
|
31 |
+
bool V_in_regs=false>
|
32 |
+
void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
33 |
+
static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
|
34 |
+
using ElementAccum = float;
|
35 |
+
using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
|
36 |
+
|
37 |
+
int const total_q_padded_rounded = cute::round_up(params.total_q + params.b * kBlockM, kBlockM);
|
38 |
+
int const total_k_padded_rounded = cute::round_up(params.total_k + params.b * kBlockN, kBlockN);
|
39 |
+
bool const is_varlen_q = params.cu_seqlens_q;
|
40 |
+
bool const is_varlen_k = params.cu_seqlens_k;
|
41 |
+
int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
|
42 |
+
int seqlen_k = !is_varlen_k ? params.seqlen_k : params.total_k;
|
43 |
+
int seqlen_q_rounded = !is_varlen_q ? params.seqlen_q_rounded : total_q_padded_rounded;
|
44 |
+
int seqlen_k_rounded = !is_varlen_k ? params.seqlen_k_rounded : total_k_padded_rounded;
|
45 |
+
int batch_q = !is_varlen_q ? params.b : 1;
|
46 |
+
int batch_k = !is_varlen_k ? params.b : 1;
|
47 |
+
|
48 |
+
using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
|
49 |
+
using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, ArchTag, /*Clear_dQaccum=*/true, Varlen>;
|
50 |
+
typename PreprocessKernel::Arguments preprocess_args {
|
51 |
+
static_cast<Element const*>(params.o_ptr),
|
52 |
+
{seqlen_q, params.dv, params.h, batch_q}, // shape_O
|
53 |
+
{params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0}, // stride_O
|
54 |
+
static_cast<Element const*>(params.do_ptr),
|
55 |
+
{params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO
|
56 |
+
static_cast<float*>(params.dsoftmax_sum),
|
57 |
+
{seqlen_q_rounded, params.h, batch_q}, // shape_dPsum
|
58 |
+
{_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum
|
59 |
+
static_cast<float*>(params.softmax_lse_ptr),
|
60 |
+
{_1{}, seqlen_q, !is_varlen_q ? params.h * params.seqlen_q : 0}, // stride_LSE
|
61 |
+
static_cast<float*>(params.softmax_lse_log2_ptr),
|
62 |
+
{_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2
|
63 |
+
static_cast<ElementAccum*>(params.dq_accum_ptr),
|
64 |
+
{seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
|
65 |
+
{_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * seqlen_q_rounded * params.h : 0}, // stride_dQaccum
|
66 |
+
params.b,
|
67 |
+
params.dq_semaphore,
|
68 |
+
params.cu_seqlens_q,
|
69 |
+
params.seqused_q
|
70 |
+
};
|
71 |
+
typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args);
|
72 |
+
int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM);
|
73 |
+
dim3 grid_m(num_m_block, params.h, params.b);
|
74 |
+
cutlass::kernel_launch<PreprocessKernel>(grid_m, PreprocessKernel::MaxThreadsPerBlock, PreprocessKernel::SharedStorageSize, stream, preprocess_params, false /*launch_with_pdl*/);
|
75 |
+
CHECK_CUDA_KERNEL_LAUNCH();
|
76 |
+
|
77 |
+
using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
|
78 |
+
using ClusterShape = cute::Shape<_1, Int<1>, _1>; // Currently doesn't not support cluster
|
79 |
+
// Stages_dS_or_QSm80 is Stages_dS if Sm90 and Stages if Sm80
|
80 |
+
static constexpr int Stages = Arch >= 90 ? 2 : Stages_dS_or_QSm80;
|
81 |
+
static constexpr int Stages_dS = Arch >= 90 ? Stages_dS_or_QSm80 : 1;
|
82 |
+
using CollectiveMainloop = std::conditional_t<
|
83 |
+
Arch >= 90,
|
84 |
+
flash::CollectiveMainloopBwdSm90<Stages, Stages_dO, Stages_dS, ClusterShape, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm90,
|
85 |
+
Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
|
86 |
+
SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>,
|
87 |
+
flash::CollectiveMainloopBwdSm80<Stages, Stages_dO, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm80,
|
88 |
+
Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
|
89 |
+
SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>
|
90 |
+
>;
|
91 |
+
using CollectiveEpilogue = std::conditional_t<
|
92 |
+
!GQA,
|
93 |
+
flash::CollectiveEpilogueBwd<TileShape_MNK, Element, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, dKV_swapAB, NumMmaWarpGroups * (Arch >= 90 ? 1 : cutlass::NumWarpsPerWarpGroup) / AtomLayoutNdKV>,
|
94 |
+
flash::CollectiveEpilogueBwdGQA<TileShape_MNK, ElementAccum, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, Deterministic>
|
95 |
+
>;
|
96 |
+
using Scheduler = std::conditional_t<
|
97 |
+
Is_causal && !Varlen,
|
98 |
+
flash::SingleTileBwdLPTScheduler,
|
99 |
+
flash::SingleTileScheduler<Varlen, false /*Split*/, false /*PackGQA*/, kBlockN>
|
100 |
+
>;
|
101 |
+
using AttnKernel = std::conditional_t<
|
102 |
+
Arch >= 90,
|
103 |
+
flash::enable_sm90_or_later<flash::FlashAttnBwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,
|
104 |
+
flash::enable_sm80_to_sm89<flash::FlashAttnBwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
|
105 |
+
>;
|
106 |
+
|
107 |
+
typename CollectiveMainloop::Arguments mainloop_args {
|
108 |
+
static_cast<Element const*>(params.q_ptr),
|
109 |
+
{seqlen_q, params.d, params.h, batch_q}, // shape_Q
|
110 |
+
{params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q
|
111 |
+
static_cast<Element const*>(params.k_ptr),
|
112 |
+
{seqlen_k, params.d, params.h_k, batch_k}, // shape_K
|
113 |
+
{params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K
|
114 |
+
static_cast<Element const*>(params.v_ptr),
|
115 |
+
{seqlen_k, params.dv, params.h_k, batch_k}, // shape_V
|
116 |
+
{params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0}, // stride_V
|
117 |
+
static_cast<Element const*>(params.do_ptr),
|
118 |
+
{seqlen_q, params.dv, params.h, batch_q}, // shape_dO
|
119 |
+
{params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO
|
120 |
+
static_cast<ElementAccum*>(params.dq_accum_ptr),
|
121 |
+
{seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
|
122 |
+
{_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum
|
123 |
+
static_cast<float*>(params.softmax_lse_log2_ptr),
|
124 |
+
{seqlen_q_rounded, params.h, batch_q}, // shape_LSE
|
125 |
+
{_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2
|
126 |
+
static_cast<float*>(params.dsoftmax_sum),
|
127 |
+
{_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum
|
128 |
+
params.scale_softmax,
|
129 |
+
params.window_size_left, params.window_size_right, 0 /*attention_chunk*/,
|
130 |
+
params.softcap,
|
131 |
+
params.b,
|
132 |
+
params.dq_semaphore,
|
133 |
+
params.cu_seqlens_q, params.cu_seqlens_k,
|
134 |
+
params.seqused_q, params.seqused_k
|
135 |
+
};
|
136 |
+
// The case work with GQA is ugly but idk how to fix it.
|
137 |
+
typename CollectiveEpilogue::Arguments epilogue_args {
|
138 |
+
static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dk_ptr : params.dk_accum_ptr),
|
139 |
+
[&] {
|
140 |
+
if constexpr (!GQA) {
|
141 |
+
return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.d, params.h, batch_k}; // shape_dK
|
142 |
+
} else {
|
143 |
+
return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}; // shape_dKaccum
|
144 |
+
}
|
145 |
+
}(),
|
146 |
+
[&] {
|
147 |
+
if constexpr (!GQA) {
|
148 |
+
return typename CollectiveEpilogue::StridedKV {params.dk_row_stride, _1{}, params.dk_head_stride, !is_varlen_k ? params.dk_batch_stride : 0}; // stride_dK
|
149 |
+
} else {
|
150 |
+
return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dKaccum
|
151 |
+
}
|
152 |
+
}(),
|
153 |
+
static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dv_ptr : params.dv_accum_ptr),
|
154 |
+
[&] {
|
155 |
+
if constexpr (!GQA) {
|
156 |
+
return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.dv, params.h, batch_k}; // shape_dV
|
157 |
+
} else {
|
158 |
+
return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}; // shape_dVaccum
|
159 |
+
}
|
160 |
+
}(),
|
161 |
+
[&] {
|
162 |
+
if constexpr (!GQA) {
|
163 |
+
return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0}; // stride_dV
|
164 |
+
} else {
|
165 |
+
return typename CollectiveEpilogue::StridedKV {_1{}, params.dv_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.dv_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum
|
166 |
+
}
|
167 |
+
}(),
|
168 |
+
params.h,
|
169 |
+
params.dk_semaphore,
|
170 |
+
params.dv_semaphore,
|
171 |
+
params.cu_seqlens_k,
|
172 |
+
params.seqused_k,
|
173 |
+
};
|
174 |
+
|
175 |
+
int num_blocks_n = cutlass::ceil_div(params.seqlen_k, get<1>(TileShape_MNK{}));
|
176 |
+
num_blocks_n = cutlass::round_up(num_blocks_n, size<1>(ClusterShape{}));
|
177 |
+
typename flash::TileSchedulerArguments scheduler_args {
|
178 |
+
num_blocks_n, params.h, params.b, 1 /*num_splits*/,
|
179 |
+
params.h / params.h_k,
|
180 |
+
params.seqlen_k,
|
181 |
+
params.seqlen_q, params.d, params.dv, sizeof(Element),
|
182 |
+
params.tile_count_semaphore, params.cu_seqlens_k, params.seqused_k
|
183 |
+
};
|
184 |
+
|
185 |
+
int device;
|
186 |
+
cudaGetDevice(&device);
|
187 |
+
typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
|
188 |
+
mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args
|
189 |
+
});
|
190 |
+
|
191 |
+
dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
|
192 |
+
dim3 block_dims = AttnKernel::get_block_shape();
|
193 |
+
int smem_size = AttnKernel::SharedStorageSize;
|
194 |
+
// int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q));
|
195 |
+
// int smem_size_do = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_do));
|
196 |
+
// int smem_size_ds = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_ds));
|
197 |
+
// int smem_size_dqacc = [&] {
|
198 |
+
// if constexpr (Arch >= 90) {
|
199 |
+
// return sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dqacc));
|
200 |
+
// } else {
|
201 |
+
// return 0;
|
202 |
+
// }
|
203 |
+
// }();
|
204 |
+
// int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));
|
205 |
+
// int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
|
206 |
+
// int smem_size_lse = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_lse));
|
207 |
+
// int smem_size_dpsum = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dpsum));
|
208 |
+
// printf("smem_size = %d, q = %d, k = %d, v = %d, do = %d, ds = %d, dqacc = %d, lse = %d, dpsum = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_do, smem_size_ds, smem_size_dqacc, smem_size_lse, smem_size_dpsum);
|
209 |
+
if constexpr (size(ClusterShape{}) > 1) {
|
210 |
+
void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
|
211 |
+
if (smem_size >= 48 * 1024) {
|
212 |
+
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
213 |
+
}
|
214 |
+
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
|
215 |
+
cutlass::ClusterLauncher::launch(
|
216 |
+
grid_dims, cluster_dims, block_dims, smem_size, stream, kernel, kernel_params, false /*launch_with_pdl*/);
|
217 |
+
} else {
|
218 |
+
if (smem_size >= 48 * 1024) {
|
219 |
+
CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<AttnKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
220 |
+
}
|
221 |
+
cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params, false /*launch_with_pdl*/);
|
222 |
+
}
|
223 |
+
CHECK_CUDA_KERNEL_LAUNCH();
|
224 |
+
|
225 |
+
using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_MK, Element, ElementAccum, ArchTag,
|
226 |
+
AttnKernel::CollectiveMainloop::NumMmaThreads,
|
227 |
+
typename AttnKernel::CollectiveMainloop::TiledMmadQ,
|
228 |
+
AttnKernel::CollectiveMainloop::dQ_swapAB
|
229 |
+
>;
|
230 |
+
typename PostprocessKernel::Arguments postprocess_args {
|
231 |
+
static_cast<ElementAccum const*>(params.dq_accum_ptr),
|
232 |
+
{seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
|
233 |
+
{_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum
|
234 |
+
static_cast<Element*>(params.dq_ptr),
|
235 |
+
{seqlen_q, params.d, params.h, batch_q}, // shape_dQ
|
236 |
+
{params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride}, // stride_dQ
|
237 |
+
params.scale_softmax,
|
238 |
+
params.cu_seqlens_q,
|
239 |
+
params.seqused_q
|
240 |
+
};
|
241 |
+
typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args);
|
242 |
+
int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{}));
|
243 |
+
dim3 grid_m_postprocess(num_m_block_postprocess, params.h, params.b);
|
244 |
+
int smem_size_postprocess = PostprocessKernel::SharedStorageSize;
|
245 |
+
if (smem_size_postprocess >= 48 * 1024) {
|
246 |
+
CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
|
247 |
+
}
|
248 |
+
cutlass::kernel_launch<PostprocessKernel>(grid_m_postprocess, PostprocessKernel::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_params, false /*launch_with_pdl*/);
|
249 |
+
CHECK_CUDA_KERNEL_LAUNCH();
|
250 |
+
|
251 |
+
if constexpr (GQA) {
|
252 |
+
using TileShape_NK = cute::Shape<Int<kBlockN>, Int<kHeadDim>>;
|
253 |
+
using PostprocessKerneldKV = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_NK, Element, ElementAccum, ArchTag,
|
254 |
+
AttnKernel::CollectiveEpilogue::NumEpilogueThreads,
|
255 |
+
typename AttnKernel::CollectiveMainloop::TiledMmadKV,
|
256 |
+
AttnKernel::CollectiveMainloop::dKV_swapAB
|
257 |
+
>;
|
258 |
+
typename PostprocessKerneldKV::Arguments postprocess_dK_args {
|
259 |
+
static_cast<ElementAccum const*>(params.dk_accum_ptr),
|
260 |
+
{seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dKaccum
|
261 |
+
{_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dKaccum
|
262 |
+
static_cast<Element*>(params.dk_ptr),
|
263 |
+
{seqlen_k, params.d, params.h_k, batch_k}, // shape_dK
|
264 |
+
{params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride}, // stride_dK
|
265 |
+
1.f,
|
266 |
+
params.cu_seqlens_k,
|
267 |
+
params.seqused_k
|
268 |
+
};
|
269 |
+
typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args);
|
270 |
+
typename PostprocessKerneldKV::Arguments postprocess_dV_args {
|
271 |
+
static_cast<ElementAccum const*>(params.dv_accum_ptr),
|
272 |
+
{seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}, // shape_dVaccum
|
273 |
+
{_1{}, seqlen_k_rounded * params.dv_rounded, !is_varlen_k ? params.dv_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum
|
274 |
+
static_cast<Element*>(params.dv_ptr),
|
275 |
+
{seqlen_k, params.dv, params.h_k, batch_k}, // shape_dV
|
276 |
+
{params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride}, // stride_dV
|
277 |
+
1.f,
|
278 |
+
params.cu_seqlens_k,
|
279 |
+
params.seqused_k
|
280 |
+
};
|
281 |
+
typename PostprocessKerneldKV::Params postprocess_dV_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dV_args);
|
282 |
+
int num_n_block_postprocess = cute::ceil_div(params.seqlen_k, get<0>(TileShape_NK{}));
|
283 |
+
dim3 grid_n_postprocess(num_n_block_postprocess, params.h_k, params.b);
|
284 |
+
int smem_size_postprocess = PostprocessKerneldKV::SharedStorageSize;
|
285 |
+
if (smem_size_postprocess >= 48 * 1024) {
|
286 |
+
CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKerneldKV>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
|
287 |
+
}
|
288 |
+
cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dK_params, false /*launch_with_pdl*/);
|
289 |
+
CHECK_CUDA_KERNEL_LAUNCH();
|
290 |
+
cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dV_params, false /*launch_with_pdl*/);
|
291 |
+
CHECK_CUDA_KERNEL_LAUNCH();
|
292 |
+
}
|
293 |
+
|
294 |
+
}
|
295 |
+
|
296 |
+
template<int Arch, typename T, int kBlockM, int kBlockN, int kHeadDim, bool Is_causal, bool Is_local, bool Has_softcap,
|
297 |
+
int Stages_dO=2, int Stages_dS_or_QSm80=2,
|
298 |
+
bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
|
299 |
+
int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
|
300 |
+
bool V_in_regs=false>
|
301 |
+
void run_mha_bwd_dispatch(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
302 |
+
VARLEN_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
|
303 |
+
BOOL_SWITCH(params.h != params.h_k, GQA, [&] {
|
304 |
+
// BOOL_SWITCH(params.deterministic, Deterministic, [&] {
|
305 |
+
// run_flash_bwd<kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen, false, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>(params, stream);
|
306 |
+
run_flash_bwd<Arch, kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen /*Varlen*/, false /*Deterministic*/, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>(params, stream);
|
307 |
+
// });
|
308 |
+
});
|
309 |
+
});
|
310 |
+
}
|
311 |
+
|
312 |
+
|
313 |
+
template<int Arch, typename T, bool Has_softcap>
|
314 |
+
void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
315 |
+
CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
|
316 |
+
if constexpr (Arch >= 90) {
|
317 |
+
if constexpr (Is_causal && Has_softcap) {
|
318 |
+
// register spill with 128 x 128
|
319 |
+
run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 2, false>(params, stream);
|
320 |
+
} else {
|
321 |
+
// With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block.
|
322 |
+
run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 2, false>(params, stream);
|
323 |
+
}
|
324 |
+
} else if constexpr (Arch == 86 || Arch == 89) {
|
325 |
+
run_mha_bwd_dispatch<Arch, T, 64, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
|
326 |
+
// run_mha_bwd_dispatch<Arch, T, 96, 96, 64, Is_causal, Is_local, Has_softcap, 1, 2, false, true, true, 2, 2, 4, 4, false>(params, stream);
|
327 |
+
// run_mha_bwd_dispatch<Arch, T, 80, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 2, 4, 2, true>(params, stream);
|
328 |
+
// run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 1, 8, 4, false>(params, stream);
|
329 |
+
} else {
|
330 |
+
run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 4, 4, 4, false>(params, stream);
|
331 |
+
}
|
332 |
+
});
|
333 |
+
}
|
334 |
+
|
335 |
+
template<int Arch, typename T, bool Has_softcap>
|
336 |
+
void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
337 |
+
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        if constexpr (Arch >= 90) {
            run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, true>(params, stream);
        } else if constexpr (Arch == 86 || Arch == 89) {
            run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
        } else {
            run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, false>(params, stream);
        }
    });
}

template<int Arch, typename T, bool Has_softcap>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        if constexpr (Arch >= 90) {
            if constexpr (Is_causal || Is_local || Has_softcap) {
                run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, false>(params, stream);
            } else {
                run_mha_bwd_dispatch<Arch, T, 80, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 1, false>(params, stream);
            }
        } else if constexpr (Arch == 86 || Arch == 89) {
            run_mha_bwd_dispatch<Arch, T, 64, 96, 128, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 2, 2, true>(params, stream);
        } else {
            run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 2, 2, false>(params, stream);
        }
    });
}

template<int Arch, typename T, bool Has_softcap>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        if constexpr (Arch >= 90) {
            run_mha_bwd_dispatch<Arch, T, 64, 96, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, true, false, 3, 1, 1, 1, false>(params, stream);
        } else if constexpr (Arch == 86 || Arch == 89) {
            run_mha_bwd_dispatch<Arch, T, 64, 64, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 2, true>(params, stream);
        } else {
            run_mha_bwd_dispatch<Arch, T, 64, 80, 192, Is_causal, Is_local, Has_softcap, 1, 2, false, true, false, 2, 4, 2, 2, false>(params, stream);
        }
    });
}

template<int Arch, typename T, bool Has_softcap>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        if constexpr (Arch >= 90) {
            run_mha_bwd_dispatch<Arch, T, 64, 80, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, true, true, 2, 1, 1, 1, false>(params, stream);
        } else if constexpr (Arch == 86 || Arch == 89) {
            run_mha_bwd_dispatch<Arch, T, 32, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 1, true>(params, stream);
            // run_mha_bwd_dispatch<Arch, T, 64, 32, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 1, 2, true>(params, stream);
        } else {
            run_mha_bwd_dispatch<Arch, T, 64, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 2, 2, false>(params, stream);
        }
    });
}
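The per-head-dimension functions above only differ in their tile sizes and scheduling flags; something higher up still has to pick one of them based on the head dimension in the params struct. A minimal sketch of such a selector is shown below, assuming params.d is the head-dimension field from flash.h and that a run_mha_bwd_hdim96 variant is defined just above this excerpt; the actual selection in this diff lives elsewhere (flash_api.cpp / the launch template) and also switches on dtype and architecture.

// Hypothetical illustration only, not the dispatcher used by this diff.
template<int Arch, typename T, bool Has_softcap>
void run_mha_bwd_example_dispatch(Flash_bwd_params &params, cudaStream_t stream) {
    // Pick the smallest specialized head-dimension kernel that fits params.d.
    if (params.d <= 96) {
        run_mha_bwd_hdim96<Arch, T, Has_softcap>(params, stream);
    } else if (params.d <= 128) {
        run_mha_bwd_hdim128<Arch, T, Has_softcap>(params, stream);
    } else if (params.d <= 192) {
        run_mha_bwd_hdim192<Arch, T, Has_softcap>(params, stream);
    } else {
        run_mha_bwd_hdim256<Arch, T, Has_softcap>(params, stream);
    }
}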
flash-attn/flash_bwd_postprocess_kernel.h
ADDED
@@ -0,0 +1,256 @@
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
#include "cute/tensor.hpp"
|
8 |
+
|
9 |
+
#include <cutlass/cutlass.h>
|
10 |
+
#include <cutlass/array.h>
|
11 |
+
#include <cutlass/numeric_types.h>
|
12 |
+
#include <cutlass/numeric_conversion.h>
|
13 |
+
#include "cutlass/arch/barrier.h"
|
14 |
+
|
15 |
+
#include "seqlen.h"
|
16 |
+
#include "utils.h"
|
17 |
+
|
18 |
+
namespace flash {
|
19 |
+
|
20 |
+
using namespace cute;
|
21 |
+
|
22 |
+
template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, int kNThreads, class TiledMma, bool dQ_swapAB>
|
23 |
+
class FlashAttnBwdPostprocessConvertdQ {
|
24 |
+
|
25 |
+
public:
|
26 |
+
|
27 |
+
// Type Aliases
|
28 |
+
using TileShape_MK = TileShape_MK_;
|
29 |
+
using ArchTag = ArchTag_;
|
30 |
+
|
31 |
+
static_assert(ArchTag::kMinComputeCapability >= 75);
|
32 |
+
static constexpr bool IsSm90 = ArchTag::kMinComputeCapability >= 90;
|
33 |
+
|
34 |
+
static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
|
35 |
+
static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
|
36 |
+
|
37 |
+
static constexpr int kBlockM = get<0>(TileShape_MK{});
|
38 |
+
static constexpr int kHeadDim = get<1>(TileShape_MK{});
|
39 |
+
static_assert(!IsSm90 || kNThreads % cutlass::NumThreadsPerWarpGroup == 0, "kNThreads must be a multiple of NumThreadsPerWarpGroup");
|
40 |
+
static constexpr int NumdQWarpGgroups = kNThreads / cutlass::NumThreadsPerWarpGroup;
|
41 |
+
using R2SLayoutAtomdQaccum = std::conditional_t<
|
42 |
+
IsSm90,
|
43 |
+
Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumdQWarpGgroups>>>,
|
44 |
+
Layout<Shape<Int<kNThreads>>>
|
45 |
+
>;
|
46 |
+
using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdQaccum{},
|
47 |
+
Layout<Shape<Int<IsSm90 ? 4 : 1>>>{})); // Val layout, 1 or 4 vals per read
|
48 |
+
using G2SLayoutAtomdQaccum = Layout<Shape<Int<kNThreads>>>;
|
49 |
+
// UniversalCopy instead of AutoVectorizingCopyWithAssumedAlignment as the latter generates cp.async instructions
|
50 |
+
using G2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, ElementAccum>{}, G2SLayoutAtomdQaccum{},
|
51 |
+
Layout<Shape<_4>>{})); // Val layout, 4 vals per read
|
52 |
+
// We don't do bound checking for the gmem -> smem load so we just assert here.
|
53 |
+
static_assert(IsSm90 || (kBlockM * kHeadDim) % (kNThreads * 4) == 0);
|
54 |
+
static constexpr int SmemdQaccumSize = size(TileShape_MK{});
|
55 |
+
using SmemLayoutdQaccumFlat = Layout<Shape<Int<SmemdQaccumSize>>>;
|
56 |
+
using SmemLayoutdQaccum = std::conditional_t<
|
57 |
+
IsSm90,
|
58 |
+
Layout<Shape<Int<kBlockM * kHeadDim / NumdQWarpGgroups>, Int<NumdQWarpGgroups>>>,
|
59 |
+
Layout<Shape<Int<kBlockM * kHeadDim>>>
|
60 |
+
>;
|
61 |
+
|
62 |
+
// We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs,
|
63 |
+
// then setting kBlockKSmem to 32 will cause "Static shape_div failure".
|
64 |
+
// We want to treat it as 64 x 48, so kBlockKSmem should be 16.
|
65 |
+
static constexpr int MmaShapeN = get<1>(typename TiledMma::AtomShape_MNK{});
|
66 |
+
static constexpr int kBlockKSmem = MmaShapeN % 64 == 0 ? 64 : (MmaShapeN % 32 == 0 ? 32 : 16);
|
67 |
+
static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
|
68 |
+
using SmemLayoutAtomdQ =
|
69 |
+
decltype(composition(Swizzle<kSwizzle, 3, 3>{},
|
70 |
+
Layout<Shape<Int<8>, Int<kBlockKSmem>>,
|
71 |
+
Stride<Int<kBlockKSmem>, _1>>{}));
|
72 |
+
using SmemLayoutdQ = decltype(tile_to_shape(SmemLayoutAtomdQ{}, TileShape_MK{}));
|
73 |
+
using SmemLayoutdQt =
|
74 |
+
decltype(cute::composition(SmemLayoutdQ{},
|
75 |
+
make_layout(make_shape(get<1>(TileShape_MK{}), get<0>(TileShape_MK{})),
|
76 |
+
make_stride(Int<get<0>(TileShape_MK{})>{}, _1{}))));
|
77 |
+
|
78 |
+
using SmemCopyAtomdQ = Copy_Atom<
|
79 |
+
std::conditional_t<
|
80 |
+
IsSm90,
|
81 |
+
std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
|
82 |
+
AutoVectorizingCopyWithAssumedAlignment<128>
|
83 |
+
>,
|
84 |
+
Element>;
|
85 |
+
|
86 |
+
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
87 |
+
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
|
88 |
+
static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, int(MaxThreadsPerBlock));
|
89 |
+
static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
|
90 |
+
using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
91 |
+
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
92 |
+
using GmemTiledCopy = decltype(
|
93 |
+
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
94 |
+
GmemLayoutAtom{},
|
95 |
+
Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per load
|
96 |
+
|
97 |
+
struct SharedStorage : cute::aligned_struct<128> {
|
98 |
+
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQaccum>> smem_dqacc;
|
99 |
+
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
|
100 |
+
alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_dQaccum;
|
101 |
+
};
|
102 |
+
|
103 |
+
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
104 |
+
|
105 |
+
using ShapedQ = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
|
106 |
+
using StridedQ = cute::Stride<int64_t, _1, int64_t, int64_t>;
|
107 |
+
using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q * d, head, batch)
|
108 |
+
using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
|
109 |
+
|
110 |
+
// Device side arguments
|
111 |
+
struct Arguments {
|
112 |
+
ElementAccum const* ptr_dQaccum;
|
113 |
+
ShapedQaccum const shape_dQaccum;
|
114 |
+
StridedQaccum const stride_dQaccum;
|
115 |
+
Element* ptr_dQ;
|
116 |
+
ShapedQ const shape_dQ;
|
117 |
+
StridedQ const stride_dQ;
|
118 |
+
float const softmax_scale;
|
119 |
+
int const* cu_seqlens = nullptr;
|
120 |
+
int const* seqused = nullptr;
|
121 |
+
};
|
122 |
+
|
123 |
+
// Kernel entry point API
|
124 |
+
struct Params {
|
125 |
+
ElementAccum const* ptr_dQaccum;
|
126 |
+
ShapedQaccum const shape_dQaccum;
|
127 |
+
StridedQaccum const stride_dQaccum;
|
128 |
+
Element* ptr_dQ;
|
129 |
+
ShapedQ const shape_dQ;
|
130 |
+
StridedQ const stride_dQ;
|
131 |
+
float const softmax_scale;
|
132 |
+
int const* cu_seqlens = nullptr;
|
133 |
+
int const* seqused = nullptr;
|
134 |
+
};
|
135 |
+
|
136 |
+
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
137 |
+
static
|
138 |
+
Params
|
139 |
+
to_underlying_arguments(Arguments const& args) {
|
140 |
+
return {
|
141 |
+
args.ptr_dQaccum,
|
142 |
+
args.shape_dQaccum,
|
143 |
+
args.stride_dQaccum,
|
144 |
+
args.ptr_dQ,
|
145 |
+
args.shape_dQ,
|
146 |
+
args.stride_dQ,
|
147 |
+
args.softmax_scale,
|
148 |
+
args.cu_seqlens,
|
149 |
+
args.seqused
|
150 |
+
};
|
151 |
+
}
|
152 |
+
|
153 |
+
CUTLASS_DEVICE
|
154 |
+
void
|
155 |
+
operator()(Params const& params, char* smem_buf) {
|
156 |
+
|
157 |
+
static constexpr int kBlockM = get<0>(TileShape_MK{});
|
158 |
+
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
159 |
+
|
160 |
+
Tensor sdQaccum = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccum{});
|
161 |
+
Tensor sdQaccum_flat = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumFlat{});
|
162 |
+
Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQ{});
|
163 |
+
Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQt{});
|
164 |
+
|
165 |
+
int const thread_idx = threadIdx.x;
|
166 |
+
int const m_block = blockIdx.x;
|
167 |
+
int const bidh = blockIdx.y;
|
168 |
+
int const bidb = blockIdx.z;
|
169 |
+
|
170 |
+
flash::SeqlenInfo<true /*Varlen*/, kBlockM> seqlen_info(bidb, size<0>(params.shape_dQ), params.cu_seqlens, params.seqused);
|
171 |
+
bool const is_varlen = params.cu_seqlens;
|
172 |
+
if (is_varlen && m_block * kBlockM >= seqlen_info.seqlen) { return; }
|
173 |
+
|
174 |
+
// Step 1: load dQaccum from gmem to smem
|
175 |
+
Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum const*>(params.ptr_dQaccum)),
|
176 |
+
params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
|
177 |
+
Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block)); // (M * K)
|
178 |
+
if constexpr (IsSm90) { // Use BulkCopy
|
179 |
+
static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast<uint32_t>(size(SmemLayoutdQaccumFlat{}) * cute::sizeof_bits_v<ElementAccum> / 8);
|
180 |
+
auto bulk_copy = Copy_Traits<SM90_BULK_COPY_AUTO>{};
|
181 |
+
// if (thread0()) { print(gdQaccum); printf("\n"); print(sdQaccum_flat); printf("\n"); }
|
182 |
+
if (thread_idx == 0) {
|
183 |
+
shared_storage.barrier_dQaccum.init(1 /*numThreads*/);
|
184 |
+
shared_storage.barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum);
|
185 |
+
copy(bulk_copy.with(*reinterpret_cast<uint64_t*>(&shared_storage.barrier_dQaccum)), gdQaccum, sdQaccum_flat);
|
186 |
+
}
|
187 |
+
__syncthreads();
|
188 |
+
shared_storage.barrier_dQaccum.wait(0);
|
189 |
+
} else {
|
190 |
+
G2STiledCopydQaccum g2s_tiled_copy_dQaccum;
|
191 |
+
auto g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);
|
192 |
+
Tensor tdQgdQaccumg2s = g2s_thr_copy_dQaccum.partition_S(gdQaccum);
|
193 |
+
Tensor tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum);
|
194 |
+
cute::copy(g2s_tiled_copy_dQaccum, tdQgdQaccumg2s, tdQsdQaccumg2s);
|
195 |
+
__syncthreads();
|
196 |
+
}
|
197 |
+
|
198 |
+
// __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccum); }
|
199 |
+
|
200 |
+
// Step 2: Load dQaccum from smem to register, then convert fp32 -> fp16/bf16
|
201 |
+
R2STiledCopydQaccum s2r_tiled_copy_dQaccum;
|
202 |
+
auto s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_thread_slice(thread_idx);
|
203 |
+
Tensor tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum);
|
204 |
+
TiledMma tiled_mma_dQ;
|
205 |
+
Tensor taccdQrdQaccum = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 1, !dQ_swapAB ? 1 : 0>(TileShape_MK{}));
|
206 |
+
// if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tiled_mma_dQ); printf("\n"); }
|
207 |
+
// if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tdQsdQaccum); }
|
208 |
+
// if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(taccdQrdQaccum); }
|
209 |
+
CUTE_STATIC_ASSERT_V(size(taccdQrdQaccum) == size(tdQsdQaccum));
|
210 |
+
Tensor tdQrdQaccum = s2r_thr_copy_dQaccum.retile_D(taccdQrdQaccum);
|
211 |
+
cute::copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum);
|
212 |
+
#pragma unroll
|
213 |
+
for (int i = 0; i < size(taccdQrdQaccum); ++i) { taccdQrdQaccum(i) *= params.softmax_scale; }
|
214 |
+
// Convert tdQrdQ from fp32 to fp16
|
215 |
+
Tensor rdQ = make_tensor_like<Element>(taccdQrdQaccum);
|
216 |
+
flash::convert_type_out(taccdQrdQaccum, rdQ);
|
217 |
+
|
218 |
+
// Step 3: Copy dQ from register to smem
|
219 |
+
auto smem_tiled_copy_dQ = make_tiled_copy_C(SmemCopyAtomdQ{}, tiled_mma_dQ);
|
220 |
+
auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(thread_idx);
|
221 |
+
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
|
222 |
+
// if (cute::thread0()) { print(smem_tiled_copy_dQ); }
|
223 |
+
// if (cute::thread0()) { print(smem_thr_copy_dQ); }
|
224 |
+
// if (cute::thread0()) { print(sdQ); }
|
225 |
+
Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(cute::conditional_return<!dQ_swapAB>(sdQ, sdQt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
226 |
+
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
|
227 |
+
__syncthreads();
|
228 |
+
|
229 |
+
// Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem
|
230 |
+
Tensor mdQ = make_tensor(make_gmem_ptr(params.ptr_dQ), params.shape_dQ, params.stride_dQ)(_, _, bidh, !is_varlen ? bidb : 0);
|
231 |
+
Tensor gdQ = local_tile(domain_offset(make_coord(seqlen_info.offset, _0{}), mdQ), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
|
232 |
+
GmemTiledCopy gmem_tiled_copy_dQ;
|
233 |
+
auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(thread_idx);
|
234 |
+
Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
235 |
+
Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
|
236 |
+
|
237 |
+
Tensor tdQrdQ = make_fragment_like(tdQsdQ);
|
238 |
+
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cute::make_identity_tensor(TileShape_MK{}));
|
239 |
+
Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
|
240 |
+
#pragma unroll
|
241 |
+
for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(_0{}, _0{}, k)) < get<1>(params.shape_dQ); }
|
242 |
+
// Need to check OOB when reading from smem if kBlockM isn't evenly tiled
|
243 |
+
static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;
|
244 |
+
flash::copy</*Is_even_MN=*/EvenM, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
|
245 |
+
gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ, tdQcdQ, tdQpdQ, kBlockM);
|
246 |
+
|
247 |
+
// Step 5: Copy dQ from register to gmem
|
248 |
+
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
249 |
+
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
250 |
+
gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, std::min(seqlen_info.seqlen - m_block * kBlockM, kBlockM)
|
251 |
+
);
|
252 |
+
}
|
253 |
+
|
254 |
+
};
|
255 |
+
|
256 |
+
} // namespace flash
|
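The postprocess kernel above loads the fp32 dQ accumulator, applies the softmax scale, narrows to fp16/bf16 and writes dQ, with shared-memory staging, STSM copies and varlen bounds checks layered on top. Stripped of all of that, the arithmetic it performs is just an element-wise scale-and-convert; the standalone CUDA sketch below (an assumed simplification, not part of this diff) shows that core step.

// Minimal sketch only: element-wise scale + fp32 -> fp16 narrowing of dQaccum.
#include <cuda_fp16.h>

__global__ void convert_dq_naive(const float* __restrict__ dq_accum,
                                 __half* __restrict__ dq,
                                 float softmax_scale, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) { dq[i] = __float2half(dq_accum[i] * softmax_scale); }
}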
flash-attn/flash_bwd_preprocess_kernel.h
ADDED
@@ -0,0 +1,252 @@
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
#include "cute/tensor.hpp"
|
8 |
+
|
9 |
+
#include <cutlass/cutlass.h>
|
10 |
+
#include <cutlass/array.h>
|
11 |
+
#include <cutlass/numeric_types.h>
|
12 |
+
#include <cutlass/numeric_conversion.h>
|
13 |
+
|
14 |
+
#include "seqlen.h"
|
15 |
+
#include "utils.h"
|
16 |
+
|
17 |
+
namespace flash {
|
18 |
+
|
19 |
+
using namespace cute;
|
20 |
+
|
21 |
+
template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, bool Clear_dQaccum, bool Varlen>
|
22 |
+
class FlashAttnBwdPreprocess {
|
23 |
+
|
24 |
+
public:
|
25 |
+
|
26 |
+
// Type Aliases
|
27 |
+
using TileShape_MK = TileShape_MK_;
|
28 |
+
using ArchTag = ArchTag_;
|
29 |
+
|
30 |
+
static_assert(std::is_same_v<Element, cutlass::half_t> && ArchTag::kMinComputeCapability >= 75 ||
|
31 |
+
std::is_same_v<Element, cutlass::bfloat16_t> && ArchTag::kMinComputeCapability >= 80 ||
|
32 |
+
std::is_same_v<Element, cutlass::float_e4m3_t> && ArchTag::kMinComputeCapability >= 89);
|
33 |
+
|
34 |
+
static constexpr uint32_t MaxThreadsPerBlock = 256;
|
35 |
+
static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
|
36 |
+
static constexpr int SharedStorageSize = 0;
|
37 |
+
|
38 |
+
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
39 |
+
static_assert(get<1>(TileShape_MK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
|
40 |
+
static constexpr int kBlockM = get<0>(TileShape_MK{});
|
41 |
+
static constexpr int kHeadDim = get<1>(TileShape_MK{});
|
42 |
+
// We want kBlockKGmem to be a power of 2 so that when we do the summing,
|
43 |
+
// it's just between threads in the same warp
|
44 |
+
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
|
45 |
+
static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
|
46 |
+
static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
|
47 |
+
using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
48 |
+
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
49 |
+
using GmemTiledCopy = decltype(
|
50 |
+
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
51 |
+
GmemLayoutAtom{},
|
52 |
+
Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per load
|
53 |
+
|
54 |
+
static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
|
55 |
+
static_assert((kBlockM * kHeadDim / kGmemElemsPerLoadAccum) % MaxThreadsPerBlock == 0, "MaxThreadsPerBlock must divide kBlockM * kHeadDim / kGmemElemsPerLoadAccum");
|
56 |
+
using GmemLayoutAtomAccum = Layout<Shape<Int<MaxThreadsPerBlock>>>;
|
57 |
+
using GmemTiledCopyAccum = decltype(
|
58 |
+
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
|
59 |
+
GmemLayoutAtomAccum{},
|
60 |
+
Layout<Shape<Int<kGmemElemsPerLoadAccum>>>{})); // Val layout, 4 vals per store
|
61 |
+
|
62 |
+
using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
|
63 |
+
using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
|
64 |
+
using ShapedPsum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q, head, batch)
|
65 |
+
using StridedPsum = cute::Stride<_1, int64_t, int64_t>;
|
66 |
+
using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q * d, head, batch)
|
67 |
+
using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
|
68 |
+
|
69 |
+
// Device side arguments
|
70 |
+
struct Arguments {
|
71 |
+
Element const* ptr_O;
|
72 |
+
ShapeO const shape_O;
|
73 |
+
StrideO const stride_O;
|
74 |
+
Element const* ptr_dO;
|
75 |
+
StrideO const stride_dO;
|
76 |
+
float* ptr_dPsum;
|
77 |
+
ShapedPsum const shape_dPsum;
|
78 |
+
StridedPsum const stride_dPsum;
|
79 |
+
float const* ptr_LSE;
|
80 |
+
StridedPsum const stride_LSE;
|
81 |
+
float *ptr_LSE_log2;
|
82 |
+
StridedPsum const stride_LSE_log2;
|
83 |
+
ElementAccum* ptr_dQaccum;
|
84 |
+
ShapedQaccum const shape_dQaccum;
|
85 |
+
StridedQaccum const stride_dQaccum;
|
86 |
+
int num_batch; // We need this to know the size of dq_semaphore in case of varlen
|
87 |
+
int* dq_semaphore;
|
88 |
+
int const* cu_seqlens = nullptr;
|
89 |
+
int const* seqused = nullptr;
|
90 |
+
};
|
91 |
+
|
92 |
+
// Kernel entry point API
|
93 |
+
struct Params {
|
94 |
+
Element const* ptr_O;
|
95 |
+
ShapeO const shape_O;
|
96 |
+
StrideO const stride_O;
|
97 |
+
Element const* ptr_dO;
|
98 |
+
StrideO const stride_dO;
|
99 |
+
float* ptr_dPsum;
|
100 |
+
ShapedPsum const shape_dPsum;
|
101 |
+
StridedPsum const stride_dPsum;
|
102 |
+
float const* ptr_LSE;
|
103 |
+
StridedPsum const stride_LSE;
|
104 |
+
float* ptr_LSE_log2;
|
105 |
+
StridedPsum const stride_LSE_log2;
|
106 |
+
ElementAccum* ptr_dQaccum;
|
107 |
+
ShapedQaccum const shape_dQaccum;
|
108 |
+
StridedQaccum const stride_dQaccum;
|
109 |
+
int num_batch;
|
110 |
+
int* dq_semaphore;
|
111 |
+
int const* cu_seqlens = nullptr;
|
112 |
+
int const* seqused = nullptr;
|
113 |
+
};
|
114 |
+
|
115 |
+
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
116 |
+
static
|
117 |
+
Params
|
118 |
+
to_underlying_arguments(Arguments const& args) {
|
119 |
+
return {
|
120 |
+
args.ptr_O,
|
121 |
+
args.shape_O,
|
122 |
+
args.stride_O,
|
123 |
+
args.ptr_dO,
|
124 |
+
args.stride_dO,
|
125 |
+
args.ptr_dPsum,
|
126 |
+
args.shape_dPsum,
|
127 |
+
args.stride_dPsum,
|
128 |
+
args.ptr_LSE,
|
129 |
+
args.stride_LSE,
|
130 |
+
args.ptr_LSE_log2,
|
131 |
+
args.stride_LSE_log2,
|
132 |
+
args.ptr_dQaccum,
|
133 |
+
args.shape_dQaccum,
|
134 |
+
args.stride_dQaccum,
|
135 |
+
args.num_batch,
|
136 |
+
args.dq_semaphore,
|
137 |
+
args.cu_seqlens,
|
138 |
+
args.seqused
|
139 |
+
};
|
140 |
+
}
|
141 |
+
|
142 |
+
CUTLASS_DEVICE
|
143 |
+
void
|
144 |
+
operator()(Params const& params, [[maybe_unused]] char* smem_buf) {
|
145 |
+
|
146 |
+
static constexpr int kBlockM = get<0>(TileShape_MK{});
|
147 |
+
|
148 |
+
int const thread_idx = threadIdx.x;
|
149 |
+
int const m_block = blockIdx.x;
|
150 |
+
int const bidh = blockIdx.y;
|
151 |
+
int const bidb = blockIdx.z;
|
152 |
+
|
153 |
+
flash::SeqlenInfo<Varlen, kBlockM> seqlen_info(bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused);
|
154 |
+
bool const is_varlen = Varlen && params.cu_seqlens;
|
155 |
+
int const seqlen_o = seqlen_info.seqlen;
|
156 |
+
if (is_varlen && m_block * kBlockM >= seqlen_o) { return; }
|
157 |
+
|
158 |
+
Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0);
|
159 |
+
Tensor gO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
|
160 |
+
Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_O, params.stride_dO)(_, _, bidh, !is_varlen ? bidb : 0);
|
161 |
+
Tensor gdO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
|
162 |
+
|
163 |
+
auto shape_LSE = select<0, 2, 3>(params.shape_O);
|
164 |
+
Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), shape_LSE, params.stride_LSE)(_, bidh, !is_varlen ? bidb : 0);
|
165 |
+
Tensor gLSE = local_tile(cute::domain_offset(make_coord(seqlen_info.offset), mLSE), Shape<Int<kBlockM>>{}, make_coord(m_block));
|
166 |
+
static_assert(kBlockM <= MaxThreadsPerBlock);
|
167 |
+
float lse = thread_idx < seqlen_o - m_block * kBlockM && thread_idx < kBlockM ? gLSE(thread_idx) : INFINITY;
|
168 |
+
|
169 |
+
GmemTiledCopy gmem_tiled_copy_O;
|
170 |
+
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
|
171 |
+
|
172 |
+
Tensor tOgO = gmem_thr_copy_O.partition_S(gO);
|
173 |
+
Tensor tOgdO = gmem_thr_copy_O.partition_S(gdO);
|
174 |
+
// Construct identity layout for gO
|
175 |
+
Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
176 |
+
// Repeat the partitioning with identity layouts
|
177 |
+
Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
|
178 |
+
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
|
179 |
+
#pragma unroll
|
180 |
+
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
|
181 |
+
|
182 |
+
// (8, kBlockM / 32, kHeadDim / 64) or (8, kBlockM / 16, kHeadDim / 128)
|
183 |
+
Tensor tOrO = make_fragment_like(tOgO);
|
184 |
+
Tensor tOrdO = make_fragment_like(tOgdO);
|
185 |
+
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clearn_OOB_K=*/true>(
|
186 |
+
gmem_tiled_copy_O, tOgO, tOrO, tOcO, tOpO, seqlen_o - m_block * kBlockM
|
187 |
+
);
|
188 |
+
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clearn_OOB_K=*/true>(
|
189 |
+
gmem_tiled_copy_O, tOgdO, tOrdO, tOcO, tOpO, seqlen_o - m_block * kBlockM
|
190 |
+
);
|
191 |
+
// if (threadIdx.x == 222) { printf("bidx = %d, bidy = %d, bidz = %d, seqlen_o = %d, m_block = %d, seqlen_o - m_block * kBlockM = %d, tOgO addr = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_o, m_block, seqlen_o - m_block * kBlockM, &tOgO(0));}
|
192 |
+
|
193 |
+
// Reshape from e.g. (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, (8, kHeadDim / 64))
|
194 |
+
Layout l = make_layout(get<1>(tOrO.layout()), make_layout(get<0>(tOrO.layout()), get<2>(tOrO.layout())));
|
195 |
+
Tensor tOrO_l = make_tensor(tOrO.data(), l);
|
196 |
+
Tensor o_fp32 = make_tensor_like<float>(tOrO_l);
|
197 |
+
flash::convert_type_out(tOrO_l, o_fp32);
|
198 |
+
Tensor tOrdO_l = make_tensor(tOrdO.data(), l);
|
199 |
+
Tensor do_fp32 = make_tensor_like<float>(tOrdO_l);
|
200 |
+
flash::convert_type_out(tOrdO_l, do_fp32);
|
201 |
+
// Sum across the last dimension
|
202 |
+
Tensor dP_sum = make_tensor<float>(make_shape(size<0>(o_fp32)));
|
203 |
+
#pragma unroll
|
204 |
+
for (int mi = 0; mi < size<0>(o_fp32); ++mi) {
|
205 |
+
float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
|
206 |
+
#pragma unroll
|
207 |
+
for (int ni = 1; ni < size<1>(o_fp32); ni++) {
|
208 |
+
dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
|
209 |
+
}
|
210 |
+
flash::SumOp<float> sum_op;
|
211 |
+
dP_sum(mi) = flash::Allreduce<kGmemThreadsPerRow>::run(dP_sum_cur, sum_op);
|
212 |
+
}
|
213 |
+
|
214 |
+
Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_dPsum, params.stride_dPsum)(_, bidh, !is_varlen ? bidb : 0);
|
215 |
+
Tensor gdPsum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mdPsum), Shape<Int<kBlockM>>{}, make_coord(m_block));
|
216 |
+
if (get<1>(tOcO(_0{}, _0{}, _0{})) == 0) {
|
217 |
+
#pragma unroll
|
218 |
+
for (int mi = 0; mi < size(dP_sum); ++mi) {
|
219 |
+
int const row = get<0>(tOcO(_0{}, mi, _0{}));
|
220 |
+
gdPsum(row) = row < seqlen_o - m_block * kBlockM ? dP_sum(mi) : 0;
|
221 |
+
}
|
222 |
+
}
|
223 |
+
|
224 |
+
int const seqlen_rounded = cute::round_up(seqlen_o, kBlockM);
|
225 |
+
Tensor mLSElog2 = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_dPsum, params.stride_LSE_log2)(_, bidh, !is_varlen ? bidb : 0);
|
226 |
+
Tensor gLSElog2 = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mLSElog2), Shape<Int<kBlockM>>{}, make_coord(m_block));
|
227 |
+
if (thread_idx < seqlen_rounded - m_block * kBlockM && thread_idx < kBlockM) {
|
228 |
+
gLSElog2(thread_idx) = lse == -INFINITY ? 0.f : lse * float(M_LOG2E);
|
229 |
+
}
|
230 |
+
|
231 |
+
if constexpr (Clear_dQaccum) {
|
232 |
+
Tensor mdQaccum = make_tensor(make_gmem_ptr(params.ptr_dQaccum), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
|
233 |
+
Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block));
|
234 |
+
GmemTiledCopyAccum gmem_tiled_copy_dQaccum;
|
235 |
+
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(thread_idx);
|
236 |
+
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
|
237 |
+
Tensor zero = make_fragment_like(tdQgdQaccum);
|
238 |
+
clear(zero);
|
239 |
+
cute::copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, zero, tdQgdQaccum);
|
240 |
+
}
|
241 |
+
|
242 |
+
if (params.dq_semaphore != nullptr && thread_idx == 0) {
|
243 |
+
int const num_batch = params.num_batch;
|
244 |
+
int const num_head = get<2>(params.shape_O);
|
245 |
+
params.dq_semaphore[bidh + bidb * num_head + m_block * num_head * num_batch] = 0;
|
246 |
+
}
|
247 |
+
|
248 |
+
}
|
249 |
+
|
250 |
+
};
|
251 |
+
|
252 |
+
} // namespace flash
|
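The preprocess kernel above computes dP_sum as the row-wise dot product of O and dO (done in registers with a warp Allreduce), writes LSE scaled by log2(e), and optionally clears dQaccum and the dq semaphore. As a reference for the first of those steps, here is a scalar C++ sketch of the dP_sum reduction under the assumption of row-major (seqlen, headdim) layouts; it is an illustration of the math only, not code from this diff.

// Reference sketch: dP_sum(m) = sum_k dO(m, k) * O(m, k).
#include <cstddef>
#include <vector>

std::vector<float> dP_sum_reference(const std::vector<float>& O,
                                    const std::vector<float>& dO,
                                    std::size_t seqlen, std::size_t headdim) {
    std::vector<float> dP_sum(seqlen, 0.f);
    for (std::size_t m = 0; m < seqlen; ++m) {
        float acc = 0.f;
        for (std::size_t k = 0; k < headdim; ++k) {
            acc += dO[m * headdim + k] * O[m * headdim + k];
        }
        dP_sum[m] = acc;
    }
    return dP_sum;
}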
flash-attn/flash_fwd_combine.cu
ADDED
@@ -0,0 +1,13 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_combine_launch_template.h"

template void run_mha_fwd_combine_<float, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
template void run_mha_fwd_combine_<float, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);

template void run_mha_fwd_combine_<cutlass::half_t, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
template void run_mha_fwd_combine_<cutlass::half_t, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);

template void run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
template void run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
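These explicit instantiations cover output dtypes fp32/fp16/bf16 with fp32 partials, and two head-dimension tile widths (the trailing 64/128, presumably the combine kernel's kBlockK). A hedged sketch of how a caller might pick among them is below; the function name and the exact selection rule are assumptions for illustration, since the real dispatch is not part of this excerpt.

// Hypothetical caller-side selection, illustration only.
void combine_example_bf16(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl) {
    // Assumes params.dv is the value head dimension and that the wider tile is
    // preferred once dv exceeds 64.
    if (params.dv <= 64) {
        run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(params, stream, enable_pdl);
    } else {
        run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(params, stream, enable_pdl);
    }
}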
flash-attn/flash_fwd_combine_kernel.h
ADDED
@@ -0,0 +1,482 @@
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
#include "cute/tensor.hpp"
|
8 |
+
|
9 |
+
#include <cutlass/cutlass.h>
|
10 |
+
#include <cutlass/arch/memory.h>
|
11 |
+
#include <cutlass/array.h>
|
12 |
+
#include <cutlass/numeric_types.h>
|
13 |
+
#include <cutlass/numeric_conversion.h>
|
14 |
+
|
15 |
+
#include "cutlass/arch/grid_dependency_control.h"
|
16 |
+
|
17 |
+
#include "seqlen.h"
|
18 |
+
#include "utils.h"
|
19 |
+
|
20 |
+
namespace flash {
|
21 |
+
|
22 |
+
using namespace cute;
|
23 |
+
|
24 |
+
template <class TileShape_MK_, int kLogMaxSplits_, int kNThreads, int AlignmentLSE_,
|
25 |
+
bool Is_even_K, bool Varlen, class Element, class ElementPartial, class ArchTag_>
|
26 |
+
class FlashAttnFwdCombine {
|
27 |
+
|
28 |
+
public:
|
29 |
+
|
30 |
+
// Type Aliases
|
31 |
+
using TileShape_MK = TileShape_MK_;
|
32 |
+
using ArchTag = ArchTag_;
|
33 |
+
static constexpr int kMaxSplits = 1 << kLogMaxSplits_;
|
34 |
+
static constexpr int AlignmentLSE = std::min(AlignmentLSE_, int(128 / 8 / sizeof(float)));
|
35 |
+
static_assert(AlignmentLSE >= 1);
|
36 |
+
static constexpr int kStages = 4;
|
37 |
+
|
38 |
+
static_assert(ArchTag::kMinComputeCapability >= 75);
|
39 |
+
static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80;
|
40 |
+
|
41 |
+
static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
|
42 |
+
static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
|
43 |
+
|
44 |
+
static constexpr int kBlockM = get<0>(TileShape_MK{});
|
45 |
+
static constexpr int kBlockK = get<1>(TileShape_MK{});
|
46 |
+
|
47 |
+
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementPartial);
|
48 |
+
static_assert(kBlockK % kGmemElemsPerLoad == 0, "kBlockK must be a multiple of kGmemElemsPerLoad");
|
49 |
+
static constexpr int kBlockKGmem = kBlockK % 128 == 0 ? 128 : (kBlockK % 64 == 0 ? 64 : 32);
|
50 |
+
static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
|
51 |
+
static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
|
52 |
+
using GmemCopyAtom = std::conditional_t<
|
53 |
+
Has_cp_async,
|
54 |
+
cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>, ElementPartial>,
|
55 |
+
cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial>
|
56 |
+
>;
|
57 |
+
using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
58 |
+
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
59 |
+
static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0);
|
60 |
+
using GmemTiledCopyAccum = decltype(
|
61 |
+
make_tiled_copy(GmemCopyAtom{},
|
62 |
+
GmemLayoutAtom{},
|
63 |
+
Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 4 vals per load
|
64 |
+
using GmemTiledCopy = decltype(
|
65 |
+
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
66 |
+
GmemLayoutAtom{},
|
67 |
+
Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 4 vals per load
|
68 |
+
|
69 |
+
using AlignmentTypeLSE = cute::uint_byte_t<static_cast<int>(sizeof(float)) * AlignmentLSE>;
|
70 |
+
static constexpr int kGmemElemsPerLoadLSE = sizeof(AlignmentTypeLSE) / sizeof(float);
|
71 |
+
static_assert(kBlockM % kGmemElemsPerLoadLSE == 0, "kBlockM must be a multiple of kGmemElemsPerLoadLSE");
|
72 |
+
static_assert(kBlockM % 8 == 0, "kBlockM must be a multiple of 8");
|
73 |
+
static constexpr int kBlockMSmem = kBlockM % 128 == 0 ? 128 : (kBlockM % 64 == 0 ? 64 : (kBlockM % 32 == 0 ? 32 : (kBlockM % 16 == 0 ? 16 : 8)));
|
74 |
+
static constexpr int kGmemThreadsPerRowLSE = kBlockMSmem / kGmemElemsPerLoadLSE;
|
75 |
+
static_assert(MaxThreadsPerBlock % kGmemThreadsPerRowLSE == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRowLSE");
|
76 |
+
using GmemLayoutAtomLSE = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRowLSE>, Int<kGmemThreadsPerRowLSE>>,
|
77 |
+
Stride<Int<kGmemThreadsPerRowLSE>, _1>>;
|
78 |
+
static_assert(kMaxSplits % CUTE_STATIC_V(shape<0>(GmemLayoutAtomLSE{})) == 0);
|
79 |
+
using GmemCopyAtomLSE = std::conditional_t<
|
80 |
+
Has_cp_async,
|
81 |
+
cute::Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<AlignmentTypeLSE>, float>,
|
82 |
+
cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<AlignmentLSE * sizeof(float) * 8>, float>
|
83 |
+
>;
|
84 |
+
using GmemTiledCopyLSE = decltype(
|
85 |
+
make_tiled_copy(GmemCopyAtomLSE{},
|
86 |
+
GmemLayoutAtomLSE{},
|
87 |
+
Layout<Shape<_1, Int<kGmemElemsPerLoadLSE>>>{})); // Val layout, 4 vals per load
|
88 |
+
|
89 |
+
// Otherwise we get IMA when some threads access sLSE, as we're not doing any masking
|
90 |
+
static_assert((kBlockM * kMaxSplits * AlignmentLSE) % kNThreads == 0, "kNThreads must divide kBlockM * kMaxSplits * AlignmentLSE");
|
91 |
+
// This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts
|
92 |
+
using SmemLSESwizzle = std::conditional_t<
|
93 |
+
kBlockMSmem == 8,
|
94 |
+
Swizzle<5, 0, 5>,
|
95 |
+
std::conditional_t<kBlockMSmem == 16, Swizzle<4, 0, 4>, Swizzle<3, 2, 3>>
|
96 |
+
>;
|
97 |
+
using SmemLayoutAtomLSE =
|
98 |
+
decltype(composition(SmemLSESwizzle{},
|
99 |
+
Layout<Shape<Int<8>, Int<kBlockMSmem>>,
|
100 |
+
Stride<Int<kBlockMSmem>, _1>>{}));
|
101 |
+
using SmemLayoutLSE = decltype(tile_to_shape(SmemLayoutAtomLSE{}, Shape<Int<kMaxSplits>, Int<kBlockM>>{}));
|
102 |
+
|
103 |
+
using SmemLayoutO = Layout<Shape<Int<kBlockM>, Int<kBlockK>, Int<kStages>>,
|
104 |
+
Stride<Int<kBlockK>, _1, Int<kBlockM * kBlockK>>>;
|
105 |
+
|
106 |
+
// We want each column (kMaxSplits) to be processed by threads in the same warp.
|
107 |
+
// To reduce the number of shuffles, we want as few threads on the same column as possible.
|
108 |
+
// E.g., if kBlockM is divisible by 64, and there are 256 threads, we want 4 threads (0, 1, 2, 4) per column
|
109 |
+
// have have 64 such quads.
|
110 |
+
static_assert(MaxThreadsPerBlock % kBlockMSmem == 0, "MaxThreadsPerBlock must be a multiple of kBlockMSmem");
|
111 |
+
static constexpr int kSmemThreadsPerColLSEt = MaxThreadsPerBlock / kBlockMSmem;
|
112 |
+
static_assert(cutlass::NumThreadsPerWarp % kSmemThreadsPerColLSEt == 0, "kSmemThreadsPerColLSEt must divide NumThreadsPerWarp");
|
113 |
+
using S2RLayoutAtomLSE = Layout<Shape<Int<kSmemThreadsPerColLSEt>, Int<MaxThreadsPerBlock / kSmemThreadsPerColLSEt>>>;
|
114 |
+
using S2RTiledCopyLSE = decltype(make_tiled_copy(cute::Copy_Atom<cute::DefaultCopy, float>{}, S2RLayoutAtomLSE{}, Layout<_1>{}));
|
115 |
+
|
116 |
+
using ShapeOPartial = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, num_splits, head, batch)
|
117 |
+
using StrideOPartial = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
|
118 |
+
using ShapeLSEPartial = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, num_splits, head, batch)
|
119 |
+
using StrideLSEPartial = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen, num_splits, head, batch)
|
120 |
+
using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, head, batch)
|
121 |
+
using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
|
122 |
+
using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen, head, batch)
|
123 |
+
using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch)
|
124 |
+
|
125 |
+
struct SharedStorage : cute::aligned_struct<128> {
|
126 |
+
cute::array_aligned<float, cute::cosize_v<SmemLayoutLSE>> smem_lse_partial;
|
127 |
+
cute::array_aligned<int, kBlockM> smem_max_valid_split;
|
128 |
+
cute::array_aligned<ElementPartial, cute::cosize_v<SmemLayoutO>> smem_o_partial;
|
129 |
+
};
|
130 |
+
|
131 |
+
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
132 |
+
|
133 |
+
// Device side arguments
|
134 |
+
struct Arguments {
|
135 |
+
ElementPartial const* const ptr_O_partial;
|
136 |
+
ShapeOPartial const shape_O_partial;
|
137 |
+
StrideOPartial const stride_O_partial;
|
138 |
+
float const* const ptr_LSE_partial;
|
139 |
+
ShapeLSEPartial const shape_LSE_partial;
|
140 |
+
StrideLSEPartial const stride_LSE_partial;
|
141 |
+
Element* const ptr_O;
|
142 |
+
StrideO const stride_O;
|
143 |
+
float* const ptr_LSE;
|
144 |
+
StrideLSE const stride_LSE;
|
145 |
+
int const* const cu_seqlens = nullptr;
|
146 |
+
int const* const seqused = nullptr;
|
147 |
+
int const* const num_splits_dynamic_ptr = nullptr;
|
148 |
+
int* const semaphore_to_reset = nullptr;
|
149 |
+
};
|
150 |
+
|
151 |
+
// Kernel entry point API
|
152 |
+
struct Params {
|
153 |
+
ElementPartial const* const ptr_O_partial;
|
154 |
+
ShapeOPartial const shape_O_partial;
|
155 |
+
StrideOPartial const stride_O_partial;
|
156 |
+
float const* const ptr_LSE_partial;
|
157 |
+
ShapeLSEPartial const shape_LSE_partial;
|
158 |
+
StrideLSEPartial const stride_LSE_partial;
|
159 |
+
Element* const ptr_O;
|
160 |
+
StrideO const stride_O;
|
161 |
+
float* const ptr_LSE;
|
162 |
+
StrideLSE const stride_LSE;
|
163 |
+
cutlass::FastDivmod seqlen_divmod, head_divmod;
|
164 |
+
int const* const cu_seqlens = nullptr;
|
165 |
+
int const* const seqused = nullptr;
|
166 |
+
int const* const num_splits_dynamic_ptr = nullptr;
|
167 |
+
int* const semaphore_to_reset = nullptr;
|
168 |
+
};
|
169 |
+
|
170 |
+
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
171 |
+
static
|
172 |
+
Params
|
173 |
+
to_underlying_arguments(Arguments const& args) {
|
174 |
+
assert(get<1>(args.shape_LSE_partial) <= kMaxSplits);
|
175 |
+
return {
|
176 |
+
args.ptr_O_partial,
|
177 |
+
args.shape_O_partial,
|
178 |
+
args.stride_O_partial,
|
179 |
+
args.ptr_LSE_partial,
|
180 |
+
args.shape_LSE_partial,
|
181 |
+
args.stride_LSE_partial,
|
182 |
+
args.ptr_O,
|
183 |
+
args.stride_O,
|
184 |
+
args.ptr_LSE,
|
185 |
+
args.stride_LSE,
|
186 |
+
cutlass::FastDivmod(get<0>(args.shape_LSE_partial)), cutlass::FastDivmod(get<2>(args.shape_LSE_partial)),
|
187 |
+
args.cu_seqlens,
|
188 |
+
args.seqused,
|
189 |
+
args.num_splits_dynamic_ptr,
|
190 |
+
args.semaphore_to_reset
|
191 |
+
};
|
192 |
+
}
|
193 |
+
|
194 |
+
CUTLASS_DEVICE
|
195 |
+
void
|
196 |
+
operator()(Params const& params, char* smem_buf) {
|
197 |
+
|
198 |
+
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
199 |
+
Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse_partial.data()), SmemLayoutLSE{});
|
200 |
+
Tensor sMaxValidSplit = make_tensor(make_smem_ptr(shared_storage.smem_max_valid_split.data()), Shape<Int<kBlockM>>{});
|
201 |
+
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{});
|
202 |
+
|
203 |
+
int const thread_idx = threadIdx.x;
|
204 |
+
int const m_block = blockIdx.x;
|
205 |
+
int const k_block = blockIdx.y;
|
206 |
+
int const batch = blockIdx.z;
|
207 |
+
int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial);
|
208 |
+
|
209 |
+
if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) {
|
210 |
+
cutlass::arch::wait_on_dependent_grids();
|
211 |
+
*params.semaphore_to_reset = 0;
|
212 |
+
}
|
213 |
+
if (num_splits <= 1) { return; }
|
214 |
+
flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused};
|
215 |
+
int const offset = seqlen_info.offset;
|
216 |
+
int const seqlen = seqlen_info.seqlen;
|
217 |
+
int max_idx = seqlen * get<2>(params.shape_LSE_partial);
|
218 |
+
if constexpr (Varlen) {
|
219 |
+
if (m_block * kBlockM >= max_idx) { return; }
|
220 |
+
}
|
221 |
+
|
222 |
+
cutlass::FastDivmod seqlen_divmod_dynamic(seqlen);
|
223 |
+
|
224 |
+
// Step 1: load LSE_partial from gmem -> smem
|
225 |
+
Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)),
|
226 |
+
select<1, 0, 2, 3>(params.shape_LSE_partial),
|
227 |
+
select<1, 0, 2, 3>(params.stride_LSE_partial))(_, _, _, !Varlen ? batch : 0); // (num_splits, seqlen, head)
|
228 |
+
Tensor mLSEpartial_copy = cute::tiled_divide(mLSEpartial, Shape<_1, Int<kGmemElemsPerLoadLSE>>{});
|
229 |
+
GmemTiledCopyLSE gmem_tiled_copy_LSE;
|
230 |
+
auto gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_thread_slice(thread_idx);
|
231 |
+
Tensor tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE);
|
232 |
+
|
233 |
+
// Construct identity layout for sLSE
|
234 |
+
Tensor cLSE = make_identity_tensor(make_shape(size<0>(sLSE), size<1>(sLSE))); // (NUM_SPLITS, BLK_M) -> (num_splits, blk_m)
|
235 |
+
// Repeat the partitioning with identity layouts
|
236 |
+
Tensor tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE);
|
237 |
+
|
238 |
+
cutlass::arch::wait_on_dependent_grids();
|
239 |
+
|
240 |
+
#pragma unroll
|
241 |
+
for (int m = 0; m < size<2>(tLSEcLSE); ++m) {
|
242 |
+
int mi = int(get<1>(tLSEcLSE(_0{}, _0{}, m)));
|
243 |
+
int idx = m_block * kBlockM + mi;
|
244 |
+
if (idx < max_idx) {
|
245 |
+
int m_idx, bidh;
|
246 |
+
if constexpr (!Varlen) {
|
247 |
+
bidh = params.seqlen_divmod.divmod(m_idx, idx);
|
248 |
+
} else {
|
249 |
+
bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);
|
250 |
+
}
|
251 |
+
Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh);
|
252 |
+
#pragma unroll
|
253 |
+
for (int s = 0; s < size<1>(tLSEcLSE); ++s) {
|
254 |
+
int si = get<0>(tLSEcLSE(_0{}, s, _0{}));
|
255 |
+
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast<float *>(&(tLSEsLSE(_0{}, s, m))), reinterpret_cast<int>(&(tLSEsLSE(_0{}, s, m))) / 4 % 32);}
|
256 |
+
if (si < num_splits) {
|
257 |
+
cute::copy(gmem_tiled_copy_LSE, mLSEpartial_cur_copy(_, si), tLSEsLSE(_, s, m));
|
258 |
+
} else {
|
259 |
+
cute::fill(tLSEsLSE(_, s, m), -INFINITY);
|
260 |
+
}
|
261 |
+
}
|
262 |
+
} else {
|
263 |
+
// We don't need to zero out the rest of the LSEs, as we will not write the output to gmem
|
264 |
+
// cute::fill(tLSEsLSE(_, _, m), -INFINITY);
|
265 |
+
}
|
266 |
+
}
|
267 |
+
if constexpr (Has_cp_async) { cute::cp_async_fence(); }
|
268 |
+
|
269 |
+
// Step 2: Load O_partial from gmem -> smem for split = 0, 1, ..., kStages - 2.
|
270 |
+
// We want these async loads to be in flight as we compute the LSE.
|
271 |
+
GmemTiledCopyAccum gmem_tiled_copy_O_partial;
|
272 |
+
auto gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_thread_slice(thread_idx);
|
273 |
+
// Construct identity layout for gO
|
274 |
+
Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
275 |
+
// Repeat the partitioning with identity layouts
|
276 |
+
Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO);
|
277 |
+
Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)),
|
278 |
+
params.shape_O_partial, params.stride_O_partial)(_, _, _, _, !Varlen ? batch : 0); // (seqlen, d, num_splits, head)
|
279 |
+
|
280 |
+
// Precompute these values to avoid recomputing them in the loop
|
281 |
+
Tensor tOmidx = make_tensor<int>(make_shape(size<1>(tOcO)));
|
282 |
+
Tensor tObidh = make_tensor<int>(make_shape(size<1>(tOcO)));
|
283 |
+
Tensor tOrOptr = make_tensor<ElementPartial const*>(make_shape(size<1>(tOcO)));
|
284 |
+
#pragma unroll
|
285 |
+
for (int m = 0; m < size<1>(tOcO); ++m) {
|
286 |
+
int mi = get<0>(tOcO(_0{}, m, _0{}));
|
287 |
+
int idx = m_block * kBlockM + mi;
|
288 |
+
if constexpr (!Varlen) {
|
289 |
+
tObidh(m) = params.seqlen_divmod.divmod(tOmidx(m), idx);
|
290 |
+
} else {
|
291 |
+
tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx);
|
292 |
+
}
|
293 |
+
tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m));
|
294 |
+
if (idx >= max_idx) {
|
295 |
+
tObidh[m] = -1;
|
296 |
+
}
|
297 |
+
}
|
298 |
+
|
299 |
+
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
|
300 |
+
if constexpr (!(Is_even_K)) {
|
301 |
+
#pragma unroll
|
302 |
+
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial) - k_block * kBlockK; }
|
303 |
+
}
|
304 |
+
|
305 |
+
Tensor tOsOpartial = gmem_thr_copy_O_partial.partition_D(sO);
|
306 |
+
|
307 |
+
auto load_O_partial = [&] (int split, int stage) {
|
308 |
+
Tensor tOsOpartial_cur = tOsOpartial(_, _, _, stage);
|
309 |
+
#pragma unroll
|
310 |
+
for (int m = 0; m < size<1>(tOcO); ++m) {
|
311 |
+
if (tObidh(m) >= 0) {
|
312 |
+
Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}).layout());
|
313 |
+
Tensor mOpartial_cur_copy = cute::tiled_divide(mOpartial_cur, Shape<Int<kGmemElemsPerLoad>>{});
|
314 |
+
#pragma unroll
|
315 |
+
for (int k = 0; k < size<2>(tOcO); ++k) {
|
316 |
+
int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;
|
317 |
+
if (Is_even_K || tOpO(k)) {
|
318 |
+
cute::copy(gmem_tiled_copy_O_partial, mOpartial_cur_copy(_, k_idx, split), tOsOpartial_cur(_, m, k));
|
319 |
+
}
|
320 |
+
}
|
321 |
+
}
|
322 |
+
}
|
323 |
+
};
|
324 |
+
|
325 |
+
for (int s = 0; s < kStages - 1; ++s) {
|
326 |
+
if (s < num_splits) { load_O_partial(s, s); }
|
327 |
+
if constexpr (Has_cp_async) { cute::cp_async_fence(); }
|
328 |
+
}
|
329 |
+
|
330 |
+
// Step 3: load and transpose LSE_partial from smem -> rmem
|
331 |
+
if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait<kStages - 1>(); }
|
332 |
+
__syncthreads();
|
333 |
+
|
334 |
+
S2RTiledCopyLSE s2r_tiled_copy_LSE;
|
335 |
+
auto s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_thread_slice(thread_idx);
|
336 |
+
Tensor ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE);
|
337 |
+
Tensor ts2rrLSE = make_fragment_like(ts2rsLSE);
|
338 |
+
cute::copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE);
|
339 |
+
|
340 |
+
// Step 4: compute the final LSE along the split dimension
|
341 |
+
Tensor lse_sum = make_tensor<float>(make_shape(size<2>(ts2rrLSE)));
|
342 |
+
Tensor ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE);
|
343 |
+
// We compute the max valid split for each row to short-circuit the computation later
|
344 |
+
Tensor max_valid_split = make_tensor<int>(make_shape(size<2>(ts2rrLSE)));
|
345 |
+
static_assert(CUTE_STATIC_V(size<0>(ts2rrLSE)) == 1);
|
346 |
+
#pragma unroll
|
347 |
+
for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
|
348 |
+
float lse_max = ts2rrLSE(_0{}, _0{}, m);
|
349 |
+
#pragma unroll
|
350 |
+
for (int s = 1; s < size<1>(ts2rrLSE); ++s) { lse_max = max(lse_max, ts2rrLSE(_0{}, s, m)); }
|
351 |
+
MaxOp<float> max_op;
|
352 |
+
lse_max = Allreduce<kSmemThreadsPerColLSEt>::run(lse_max, max_op);
|
353 |
+
int max_valid_idx = -1;
|
354 |
+
#pragma unroll
|
355 |
+
for (int s = 0; s < size<1>(ts2rrLSE); ++s) {
|
356 |
+
if (ts2rrLSE(_0{}, s, m) != -INFINITY) { max_valid_idx = get<0>(ts2rcLSE(_0{}, s, _0{})); }
|
357 |
+
}
|
358 |
+
MaxOp<int> max_int_op;
|
359 |
+
max_valid_split[m] = Allreduce<kSmemThreadsPerColLSEt>::run(max_valid_idx, max_int_op);
|
360 |
+
float lse_max_cur = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf
|
361 |
+
float lse_sum_cur = 0.f;
|
362 |
+
#pragma unroll
|
363 |
+
for (int s = 0; s < size<1>(ts2rrLSE); ++s) {
|
364 |
+
float scale = expf(ts2rrLSE(_0{}, s, m) - lse_max_cur);
|
365 |
+
lse_sum_cur += scale;
|
366 |
+
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast<float *>(&(ts2rsLSE(_0{}, s, m))), reinterpret_cast<int>(&(ts2rsLSE(_0{}, s, m))) / 4 % 32);}
|
367 |
+
// ts2rsLSE(_0{}, m, s) = scale;
|
368 |
+
ts2rrLSE(_0{}, s, m) = scale;
|
369 |
+
}
|
370 |
+
SumOp<float> sum_op;
|
371 |
+
lse_sum_cur = Allreduce<kSmemThreadsPerColLSEt>::run(lse_sum_cur, sum_op);
|
372 |
+
lse_sum(m) = logf(lse_sum_cur) + lse_max;
|
373 |
+
float inv_sum = (lse_sum_cur == 0.f || lse_sum_cur != lse_sum_cur) ? 0.f : 1.f / lse_sum_cur;
|
374 |
+
#pragma unroll
|
375 |
+
for (int s = 0; s < size<1>(ts2rrLSE); ++s) { ts2rrLSE(_0{}, s, m) *= inv_sum; }
|
376 |
+
}
|
377 |
+
// Store the scales exp(lse - lse_logsum) back to smem
|
378 |
+
cute::copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE);
|
379 |
+
|
380 |
+
// Store max_valid_split to smem
|
381 |
+
#pragma unroll
|
382 |
+
for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
|
383 |
+
            if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) {  // Only the thread responsible for s=0 writes to smem
                int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m)));
                if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; }
            }
        }

        // Step 5: store final LSE back to gmem
        if (k_block == 0) {
            auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial);
            Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE)(_, _, !Varlen ? batch : 0);
            #pragma unroll
            for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
                if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) {  // Only the thread responsible for s=0 writes to gmem
                    int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m)));
                    int idx = m_block * kBlockM + mi;
                    if (idx < max_idx) {
                        int m_idx, bidh;
                        if constexpr (!Varlen) {
                            bidh = params.seqlen_divmod.divmod(m_idx, idx);
                        } else {
                            bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);
                        }
                        // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m));
                        mLSE(m_idx, bidh) = lse_sum(m);
                    }
                }
            }
        }

        // Step 6: read O_partial from gmem -> smem -> rmem and accumulate the final O
        __syncthreads();
        int thr_max_valid_split = sMaxValidSplit[get<0>(tOcO(_0{}, _0{}, _0{}))];
        #pragma unroll
        for (int m = 1; m < size<1>(tOcO); ++m) { thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[get<0>(tOcO(_0{}, m, _0{}))]); }
        Layout tOrOpartial_layout = gmem_thr_copy_O_partial.partition_S(make_tensor<ElementPartial>(TileShape_MK{})).layout();
        Tensor tOrOpartial = make_fragment_like<ElementPartial>(tOrOpartial_layout);
        Tensor tOrO = make_fragment_like<float>(tOrOpartial);
        clear(tOrO);
        int stage_load = kStages - 1, stage_compute = 0;
        #pragma unroll 4  // Already tuned for speed
        for (int s = 0; s <= thr_max_valid_split; ++s) {
            Tensor scale = make_tensor<float>(make_shape(size<1>(tOrOpartial)));
            #pragma unroll
            for (int m = 0; m < size<1>(tOrOpartial); ++m) { scale(m) = sLSE(s, get<0>(tOcO(_0{}, m, _0{}))); }

            if (s + kStages - 1 <= thr_max_valid_split) { load_O_partial(s + kStages - 1, stage_load); }
            if constexpr (Has_cp_async) { cute::cp_async_fence(); }
            stage_load = stage_load < kStages - 1 ? stage_load + 1 : 0;
            if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait<kStages - 1>(); }
            // We don't need __syncthreads() because each thread is just reading its own data from smem
            cute::copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial>{},
                       tOsOpartial(_, _, _, stage_compute), tOrOpartial);
            stage_compute = stage_compute < kStages - 1 ? stage_compute + 1 : 0;

            #pragma unroll
            for (int m = 0; m < size<1>(tOrOpartial); ++m) {
                if (tObidh(m) >= 0 && scale(m) > 0.f) {
                    #pragma unroll
                    for (int k = 0; k < size<2>(tOrOpartial); ++k) {
                        if (Is_even_K || tOpO(k)) {
                            Tensor rOpartial = make_tensor_like<float>(tOrOpartial(_, m, k));
                            flash::convert_type_out(tOrOpartial(_, m, k), rOpartial);
                            #pragma unroll
                            for (int i = 0; i < size<0>(tOrOpartial); ++i) {
                                tOrO(i, m, k) += scale(m) * rOpartial[i];
                            }
                        }
                    }
                }
            }
        }

        // Step 7: Write the final O to gmem
        Tensor rO = make_tensor_like<Element>(tOrO);
        flash::convert_type_out(tOrO, rO);
        auto shape_O = make_shape(get<0>(params.shape_O_partial), get<1>(params.shape_O_partial) - k_block * kBlockK, get<3>(params.shape_O_partial), get<4>(params.shape_O_partial));
        Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O) + k_block * kBlockK * get<1>(params.stride_O)),
                                shape_O, params.stride_O)(_, _, _, !Varlen ? batch : 0);
        Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int<kGmemElemsPerLoad>>{});
        GmemTiledCopy gmem_tiled_copy_O;
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);

        #pragma unroll
        for (int m = 0; m < size<1>(tOcO); ++m) {
            if (tObidh(m) >= 0) {
                #pragma unroll
                for (int k = 0; k < size<2>(tOcO); ++k) {
                    int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;
                    if (Is_even_K || tOpO(k)) {
                        cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m)));
                    }
                }
            }
        }

    }

};

} // namespace flash
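The combine steps above implement the standard log-sum-exp merge of split attention results: each split's partial output is rescaled by exp(lse_s - lse_final) and accumulated. A minimal host-side sketch of that math, assuming plain std::vector inputs (the function and variable names here are illustrative only, not part of this kernel's API):

    // Illustrative only: numerically stable merge of per-split results for one query row,
    // mirroring Steps 4-7 above. lse[s] is the log-sum-exp of split s, o_partial[s] is
    // that split's partial output vector.
    #include <algorithm>
    #include <cmath>
    #include <vector>

    std::vector<float> combine_splits(const std::vector<float>& lse,
                                      const std::vector<std::vector<float>>& o_partial) {
        // Final LSE = log(sum_s exp(lse_s)), computed against the running max for stability.
        float lse_max = lse[0];
        for (float v : lse) lse_max = std::max(lse_max, v);
        float sum = 0.f;
        for (float v : lse) sum += std::exp(v - lse_max);
        float lse_final = lse_max + std::log(sum);

        // Each split is rescaled by exp(lse_s - lse_final) and accumulated, as in Step 6.
        std::vector<float> o(o_partial[0].size(), 0.f);
        for (size_t s = 0; s < lse.size(); ++s) {
            float scale = std::exp(lse[s] - lse_final);
            for (size_t i = 0; i < o.size(); ++i) o[i] += scale * o_partial[s][i];
        }
        return o;
    }

In the kernel the scale factors are precomputed into sLSE and read per row, but the arithmetic is the same.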
flash-attn/flash_fwd_combine_launch_template.h
ADDED
@@ -0,0 +1,80 @@
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"  // For cutlass::arch::Sm80
#include "cutlass/device_kernel.h"  // For device_kernel
#include "cutlass/kernel_launch.h"  // For kernel_launch

#include "static_switch.h"
#include "flash.h"
#include "flash_fwd_combine_kernel.h"

using namespace cute;

template <int Arch, int kBlockM, int kBlockK, int kLogMaxSplits, bool IsEvenK, bool Varlen, typename Element, typename ElementPartial>
void run_flash_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl) {
    using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
    using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kBlockK>>;
    using CombineKernel = flash::FlashAttnFwdCombine<TileShape_MK, kLogMaxSplits, 256 /*kNThreads*/, 1 /*AlignmentLSE*/,
                                                     IsEvenK, Varlen, Element, ElementPartial, ArchTag>;

    typename CombineKernel::Arguments args {
        static_cast<ElementPartial const*>(params.oaccum_ptr),
        {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1},  // shape_O_partial
        {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0},  // stride_O_partial
        static_cast<float*>(params.softmax_lseaccum_ptr),
        {!Varlen ? params.seqlen_q : params.total_q, params.num_splits, params.h, !Varlen ? params.b : 1},  // shape_LSE_partial
        {_1{}, params.lseaccum_split_stride, params.lseaccum_head_stride, !Varlen ? params.lseaccum_batch_stride : 0},  // stride_LSE_partial
        static_cast<Element*>(params.o_ptr),
        {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0},  // stride_O
        static_cast<float*>(params.softmax_lse_ptr),
        {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0},  // stride_LSE
        params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore
    };

    typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args);
    int num_blocks_k = cute::ceil_div(params.dv, kBlockK);
    int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h, kBlockM);
    dim3 grid_m(num_blocks_m, num_blocks_k, params.b);
    auto kernel = cutlass::device_kernel<CombineKernel>;
    int smem_size = CombineKernel::SharedStorageSize;
    if (smem_size >= 48 * 1024) {
        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }
    // kernel<<<grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream>>>(kernel_params);
    cutlass::kernel_launch<CombineKernel>(grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream, kernel_params, Arch >= 90 && enable_pdl /*launch_with_pdl*/);
    CHECK_CUDA_KERNEL_LAUNCH();
}

template<typename T, typename Tpartial, int kBlockK>
void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl) {
    // We want kBlockM to be as small as possible to maximize parallelism.
    // E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
    static_assert(kBlockK % 32 == 0, "kBlockK must be a multiple of 32");
    static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32);
    ARCH_SWITCH(params.arch, Arch, [&] {
        BOOL_SWITCH(params.cu_seqlens_q || params.seqused_q, Varlen, [&] {
            if constexpr (kBlockM >= 16) {  // If kBlockM == 8 then the minimum number of splits is 32.
                if (params.num_splits <= 16) {
                    run_flash_fwd_combine<Arch, kBlockM, kBlockK, 4, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
                    return;
                }
            }
            if (params.num_splits <= 32) {
                run_flash_fwd_combine<Arch, kBlockM, kBlockK, 5, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
            } else if (params.num_splits <= 64) {
                run_flash_fwd_combine<Arch, kBlockM, kBlockK, 6, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
            } else if (params.num_splits <= 128) {
                run_flash_fwd_combine<Arch, kBlockM, kBlockK, 7, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
            } else {
                run_flash_fwd_combine<Arch, kBlockM, kBlockK, 8, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
            }
        });
    });
}
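The dispatch above trades rows per CTA (kBlockM) against the supported split count (2^kLogMaxSplits). A small host-side sketch of the same selection heuristics, which can be useful for sanity-checking a configuration; the helper names are hypothetical and not part of this repository:

    // Mirrors run_mha_fwd_combine_: smaller kBlockM when the head-dimension block is
    // wide (more threads per row), and the smallest kLogMaxSplits that covers num_splits.
    #include <cstdio>

    constexpr int pick_kBlockM(int kBlockK) {
        return kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32);
    }

    constexpr int pick_kLogMaxSplits(int kBlockM, int num_splits) {
        if (kBlockM >= 16 && num_splits <= 16) return 4;
        if (num_splits <= 32) return 5;
        if (num_splits <= 64) return 6;
        if (num_splits <= 128) return 7;
        return 8;
    }

    int main() {
        int kBlockK = 128, num_splits = 48;
        int kBlockM = pick_kBlockM(kBlockK);
        std::printf("kBlockM=%d kLogMaxSplits=%d\n", kBlockM, pick_kLogMaxSplits(kBlockM, num_splits));
        return 0;
    }

With 256 threads per CTA, a smaller kBlockM leaves each thread a wider contiguous chunk of the head dimension to read, which is why wider kBlockK maps to smaller kBlockM.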
flash-attn/flash_fwd_kernel_sm80.h
ADDED
@@ -0,0 +1,215 @@
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2024, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
#include "cute/tensor.hpp"
|
8 |
+
|
9 |
+
#include <cutlass/cutlass.h>
|
10 |
+
#include <cutlass/array.h>
|
11 |
+
#include <cutlass/numeric_types.h>
|
12 |
+
#include <cutlass/kernel_hardware_info.h>
|
13 |
+
|
14 |
+
#include "seqlen.h"
|
15 |
+
#include "utils.h"
|
16 |
+
#include "softmax.h"
|
17 |
+
|
18 |
+
namespace flash {
|
19 |
+
|
20 |
+
using namespace cute;
|
21 |
+
|
22 |
+
template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
|
23 |
+
class FlashAttnFwdSm80 {
|
24 |
+
|
25 |
+
public:
|
26 |
+
|
27 |
+
// Type Aliases
|
28 |
+
using CollectiveMainloop = CollectiveMainloop_;
|
29 |
+
using CollectiveEpilogue = CollectiveEpilogue_;
|
30 |
+
static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
|
31 |
+
static constexpr bool Is_local = CollectiveMainloop::Is_local;
|
32 |
+
static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
|
33 |
+
static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
|
34 |
+
static constexpr bool Varlen = CollectiveMainloop::Varlen;
|
35 |
+
static constexpr bool PagedKV = CollectiveMainloop::PagedKV;
|
36 |
+
static constexpr bool Split = CollectiveMainloop::Split;
|
37 |
+
static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
|
38 |
+
static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
|
39 |
+
static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
|
40 |
+
static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
|
41 |
+
static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
|
42 |
+
using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
|
43 |
+
|
44 |
+
// Mainloop derived types
|
45 |
+
using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
|
46 |
+
using TiledMma = typename CollectiveMainloop::TiledMma;
|
47 |
+
using ArchTag = typename CollectiveMainloop::ArchTag;
|
48 |
+
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
49 |
+
using MainloopParams = typename CollectiveMainloop::Params;
|
50 |
+
|
51 |
+
// Epilogue derived types
|
52 |
+
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
53 |
+
using EpilogueParams = typename CollectiveEpilogue::Params;
|
54 |
+
|
55 |
+
static_assert(ArchTag::kMinComputeCapability >= 80);
|
56 |
+
|
57 |
+
using TileScheduler = TileScheduler_;
|
58 |
+
using TileSchedulerArguments = typename flash::TileSchedulerArguments;
|
59 |
+
using TileSchedulerParams = typename TileScheduler::Params;
|
60 |
+
|
61 |
+
static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMma{}));
|
62 |
+
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{}));
|
63 |
+
static constexpr uint32_t MinBlocksPerMultiprocessor = NumThreads == 128 ? 2 : 1;
|
64 |
+
|
65 |
+
// Kernel level shared memory storage
|
66 |
+
// We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v + smem_k and not smem_q
|
67 |
+
// and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v) + sizeof(smem_k).
|
68 |
+
static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage))
|
69 |
+
- int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)))
|
70 |
+
- int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k)));
|
71 |
+
static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
|
72 |
+
struct SharedStorage {
|
73 |
+
struct TensorStorage : cute::aligned_struct<128> {
|
74 |
+
union {
|
75 |
+
struct {
|
76 |
+
cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;
|
77 |
+
typename CollectiveMainloop::TensorStorage mainloop;
|
78 |
+
};
|
79 |
+
// We want smem_o to line up with the start of smem_v
|
80 |
+
typename CollectiveEpilogue::TensorStorage epilogue;
|
81 |
+
};
|
82 |
+
} tensors;
|
83 |
+
|
84 |
+
alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
|
85 |
+
|
86 |
+
};
|
87 |
+
|
88 |
+
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
89 |
+
|
90 |
+
// Device side arguments
|
91 |
+
struct Arguments {
|
92 |
+
MainloopArguments mainloop{};
|
93 |
+
EpilogueArguments epilogue{};
|
94 |
+
cutlass::KernelHardwareInfo hw_info{};
|
95 |
+
TileSchedulerArguments scheduler{};
|
96 |
+
};
|
97 |
+
|
98 |
+
// Kernel entry point API
|
99 |
+
struct Params {
|
100 |
+
MainloopParams mainloop{};
|
101 |
+
EpilogueParams epilogue{};
|
102 |
+
cutlass::KernelHardwareInfo hw_info{};
|
103 |
+
TileSchedulerParams scheduler{};
|
104 |
+
};
|
105 |
+
|
106 |
+
//
|
107 |
+
// Methods
|
108 |
+
//
|
109 |
+
|
110 |
+
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
111 |
+
static
|
112 |
+
Params
|
113 |
+
to_underlying_arguments(Arguments const& args) {
|
114 |
+
CUTLASS_TRACE_HOST("to_underlying_arguments():");
|
115 |
+
|
116 |
+
// Get SM count if needed, otherwise use user supplied SM count
|
117 |
+
int sm_count = args.hw_info.sm_count;
|
118 |
+
if (sm_count <= 0) {
|
119 |
+
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
120 |
+
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
121 |
+
sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
|
122 |
+
}
|
123 |
+
|
124 |
+
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
125 |
+
|
126 |
+
cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
|
127 |
+
return {
|
128 |
+
CollectiveMainloop::to_underlying_arguments(args.mainloop),
|
129 |
+
CollectiveEpilogue::to_underlying_arguments(args.epilogue),
|
130 |
+
hw_info,
|
131 |
+
TileScheduler::to_underlying_arguments(args.scheduler)
|
132 |
+
};
|
133 |
+
}
|
134 |
+
|
135 |
+
// Computes the kernel launch grid shape based on runtime parameters
|
136 |
+
static dim3
|
137 |
+
get_grid_shape(Params const& params) {
|
138 |
+
return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count * MinBlocksPerMultiprocessor);
|
139 |
+
}
|
140 |
+
|
141 |
+
static dim3
|
142 |
+
get_block_shape() {
|
143 |
+
return dim3(MaxThreadsPerBlock, 1, 1);
|
144 |
+
}
|
145 |
+
|
146 |
+
CUTLASS_DEVICE
|
147 |
+
void
|
148 |
+
operator()(Params const& params, char* smem_buf) {
|
149 |
+
|
150 |
+
static constexpr int kBlockM = get<0>(TileShape_MNK{});
|
151 |
+
|
152 |
+
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
153 |
+
|
154 |
+
CollectiveMainloop mainloop;
|
155 |
+
CollectiveEpilogue epilogue;
|
156 |
+
|
157 |
+
TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));
|
158 |
+
// Initialize matmul objects.
|
159 |
+
TiledMma tiled_mma;
|
160 |
+
|
161 |
+
scheduler.init_consumer();
|
162 |
+
|
163 |
+
int warp_idx = cutlass::canonical_warp_idx_sync();
|
164 |
+
CUTLASS_PRAGMA_NO_UNROLL
|
165 |
+
for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
|
166 |
+
work_tile_info.is_valid(params.scheduler);
|
167 |
+
work_tile_info = warp_idx == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
|
168 |
+
// Attention output (GEMM-II) accumulator.
|
169 |
+
Tensor tOrO = partition_fragment_C(tiled_mma, select<0, 2>(TileShape_MNK{}));
|
170 |
+
float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
|
171 |
+
// If there's tanh softcap, the scaling will be done before tanh.
|
172 |
+
auto block_coord = work_tile_info.get_block_coord(params.scheduler);
|
173 |
+
int const bidb = get<2>(block_coord);
|
174 |
+
if constexpr (Is_FP8 && !Has_softcap) {
|
175 |
+
int const bidh = get<1>(block_coord);
|
176 |
+
int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
|
177 |
+
float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
|
178 |
+
float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
|
179 |
+
softmax_scale_log2 *= q_descale * k_descale;
|
180 |
+
}
|
181 |
+
flash::Softmax<2 * (2 * kBlockM / NumThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);
|
182 |
+
|
183 |
+
SeqlenInfo_t seqlen_info{
|
184 |
+
bidb,
|
185 |
+
get<0>(params.mainloop.shape_Q),
|
186 |
+
!PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
|
187 |
+
get<0>(params.mainloop.shape_K_new),
|
188 |
+
params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
|
189 |
+
params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
|
190 |
+
params.mainloop.seqlens_rotary
|
191 |
+
};
|
192 |
+
if constexpr (AppendKV) {
|
193 |
+
bool tile_new_valid = mainloop.store_kv_new(
|
194 |
+
params.mainloop, threadIdx.x, shared_storage, seqlen_info, block_coord);
|
195 |
+
if (tile_new_valid) { __syncthreads(); }
|
196 |
+
}
|
197 |
+
bool tile_valid = mainloop.mma(
|
198 |
+
params.mainloop, tOrO, softmax, threadIdx.x, seqlen_info, block_coord,
|
199 |
+
shared_storage);
|
200 |
+
scheduler.prefetch_next_work(params.scheduler, work_tile_info);
|
201 |
+
if (tile_valid) {
|
202 |
+
// if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }
|
203 |
+
epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma,
|
204 |
+
threadIdx.x, block_coord);
|
205 |
+
} else {
|
206 |
+
// Write 0 to gO and -inf to gLSE.
|
207 |
+
epilogue.store_zero(params.epilogue, threadIdx.x, block_coord);
|
208 |
+
}
|
209 |
+
}
|
210 |
+
|
211 |
+
}
|
212 |
+
|
213 |
+
};
|
214 |
+
|
215 |
+
} // namespace flash
|
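Both forward kernels use a persistent work loop: a fixed grid of thread blocks repeatedly asks the tile scheduler for the next (m_block, head, batch) coordinate (via get_initial_work / get_next_work) instead of mapping blockIdx to a tile directly. A deliberately simplified host-side analogue of that pattern, using an atomic counter in place of the tile_count_semaphore (illustrative only, not the scheduler's actual implementation):

    // A fixed pool of workers pulls tile indices from a shared atomic counter until
    // all tiles are consumed, rather than launching one worker per tile.
    #include <atomic>
    #include <cstdio>
    #include <thread>
    #include <vector>

    int main() {
        const int num_tiles = 37, num_workers = 4;
        std::atomic<int> next_tile{0};
        std::vector<std::thread> pool;
        for (int w = 0; w < num_workers; ++w) {
            pool.emplace_back([&, w] {
                for (int tile = next_tile.fetch_add(1); tile < num_tiles; tile = next_tile.fetch_add(1)) {
                    // process_tile(tile) would handle one (m_block, head, batch) unit of work.
                    std::printf("worker %d -> tile %d\n", w, tile);
                }
            });
        }
        for (auto& t : pool) t.join();
        return 0;
    }

The benefit is the same in both settings: when tiles have very uneven cost (causal masking, variable sequence lengths), work is balanced dynamically and idle launches are avoided.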
flash-attn/flash_fwd_kernel_sm90.h
ADDED
@@ -0,0 +1,458 @@
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
#include "cute/tensor.hpp"
|
8 |
+
|
9 |
+
#include <cutlass/cutlass.h>
|
10 |
+
#include <cutlass/arch/reg_reconfig.h>
|
11 |
+
#include <cutlass/array.h>
|
12 |
+
#include <cutlass/numeric_types.h>
|
13 |
+
#include <cutlass/numeric_conversion.h>
|
14 |
+
#include <cutlass/kernel_hardware_info.h>
|
15 |
+
#include "cutlass/pipeline/pipeline.hpp"
|
16 |
+
|
17 |
+
#include "cutlass/arch/grid_dependency_control.h"
|
18 |
+
|
19 |
+
#include "seqlen.h"
|
20 |
+
#include "utils.h"
|
21 |
+
#include "softmax.h"
|
22 |
+
|
23 |
+
namespace flash {
|
24 |
+
|
25 |
+
using namespace cute;
|
26 |
+
|
27 |
+
template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
|
28 |
+
class FlashAttnFwdSm90 {
|
29 |
+
|
30 |
+
public:
|
31 |
+
|
32 |
+
// Type Aliases
|
33 |
+
using CollectiveMainloop = CollectiveMainloop_;
|
34 |
+
using CollectiveEpilogue = CollectiveEpilogue_;
|
35 |
+
static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
|
36 |
+
static constexpr bool Is_local = CollectiveMainloop::Is_local;
|
37 |
+
static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
|
38 |
+
static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
|
39 |
+
static constexpr bool Varlen = CollectiveMainloop::Varlen;
|
40 |
+
static constexpr bool Split = CollectiveMainloop::Split;
|
41 |
+
static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
|
42 |
+
static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
|
43 |
+
static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
|
44 |
+
static constexpr bool HasQv = CollectiveMainloop::HasQv;
|
45 |
+
static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q;
|
46 |
+
static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV;
|
47 |
+
static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O;
|
48 |
+
static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
|
49 |
+
static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
|
50 |
+
static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim;
|
51 |
+
static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV;
|
52 |
+
static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV);
|
53 |
+
using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
|
54 |
+
|
55 |
+
// Mainloop derived types
|
56 |
+
using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV;
|
57 |
+
using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV;
|
58 |
+
using ArchTag = typename CollectiveMainloop::ArchTag;
|
59 |
+
using ClusterShape = typename CollectiveMainloop::ClusterShape;
|
60 |
+
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
61 |
+
using MainloopParams = typename CollectiveMainloop::Params;
|
62 |
+
using BarrierQ = std::conditional_t<Use_TMA_Q, cutlass::arch::ClusterTransactionBarrier, cutlass::arch::ClusterBarrier>;
|
63 |
+
|
64 |
+
// Epilogue derived types
|
65 |
+
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
66 |
+
using EpilogueParams = typename CollectiveEpilogue::Params;
|
67 |
+
|
68 |
+
static_assert(ArchTag::kMinComputeCapability >= 90);
|
69 |
+
|
70 |
+
using TileScheduler = TileScheduler_;
|
71 |
+
using TileSchedulerArguments = typename flash::TileSchedulerArguments;
|
72 |
+
using TileSchedulerParams = typename TileScheduler::Params;
|
73 |
+
|
74 |
+
static constexpr uint32_t NumLoadWarpGroups = 1;
|
75 |
+
static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaPV{})) / cutlass::NumThreadsPerWarpGroup;
|
76 |
+
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaPV{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
|
77 |
+
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
78 |
+
static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
|
79 |
+
|
80 |
+
/// Register requirement for Load and Math WGs
|
81 |
+
// If we use cp.async to load K and V, we need more registers for the producer WG.
|
82 |
+
static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 24 : 40) : 32);
|
83 |
+
static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 240 : 232) : 160);
|
84 |
+
// If you want to print from the producer warp, you'd need to increase the number of registers
|
85 |
+
// Otherwise you'll get CUDA error.
|
86 |
+
// static constexpr uint32_t LoadRegisterRequirement = 40;
|
87 |
+
// static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;
|
88 |
+
|
89 |
+
// Kernel level shared memory storage
|
90 |
+
// We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v
|
91 |
+
// and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v).
|
92 |
+
static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)));
|
93 |
+
static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
|
94 |
+
struct SharedStorage {
|
95 |
+
struct TensorStorage : cute::aligned_struct<128, _1> {
|
96 |
+
union {
|
97 |
+
struct {
|
98 |
+
cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;
|
99 |
+
typename CollectiveMainloop::TensorStorage mainloop;
|
100 |
+
};
|
101 |
+
// We want smem_o to line up with the start of smem_v
|
102 |
+
typename CollectiveEpilogue::TensorStorage epilogue;
|
103 |
+
};
|
104 |
+
} tensors;
|
105 |
+
struct PipelineStorage : cute::aligned_struct<16, _1> {
|
106 |
+
alignas(16) BarrierQ barrier_Q;
|
107 |
+
alignas(16) BarrierQ barrier_Qv;
|
108 |
+
alignas(16) cutlass::arch::ClusterBarrier barrier_O;
|
109 |
+
alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k;
|
110 |
+
alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v;
|
111 |
+
alignas(16) typename CollectiveMainloop::MainloopPipelineVt::SharedStorage pipeline_vt;
|
112 |
+
alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_k_new;
|
113 |
+
alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_v_new;
|
114 |
+
alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
|
115 |
+
} pipelines;
|
116 |
+
|
117 |
+
};
|
118 |
+
|
119 |
+
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
120 |
+
|
121 |
+
// Device side arguments
|
122 |
+
struct Arguments {
|
123 |
+
MainloopArguments mainloop{};
|
124 |
+
EpilogueArguments epilogue{};
|
125 |
+
cutlass::KernelHardwareInfo hw_info{};
|
126 |
+
TileSchedulerArguments scheduler{};
|
127 |
+
};
|
128 |
+
|
129 |
+
// Kernel entry point API
|
130 |
+
struct Params {
|
131 |
+
MainloopParams mainloop{};
|
132 |
+
EpilogueParams epilogue{};
|
133 |
+
cutlass::KernelHardwareInfo hw_info{};
|
134 |
+
TileSchedulerParams scheduler{};
|
135 |
+
};
|
136 |
+
|
137 |
+
//
|
138 |
+
// Methods
|
139 |
+
//
|
140 |
+
|
141 |
+
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
142 |
+
static
|
143 |
+
Params
|
144 |
+
to_underlying_arguments(Arguments const& args) {
|
145 |
+
CUTLASS_TRACE_HOST("to_underlying_arguments():");
|
146 |
+
|
147 |
+
// Get SM count if needed, otherwise use user supplied SM count
|
148 |
+
int sm_count = args.hw_info.sm_count;
|
149 |
+
if (sm_count <= 0) {
|
150 |
+
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
151 |
+
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
152 |
+
sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
|
153 |
+
}
|
154 |
+
|
155 |
+
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
156 |
+
|
157 |
+
cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
|
158 |
+
return {
|
159 |
+
CollectiveMainloop::to_underlying_arguments(args.mainloop),
|
160 |
+
CollectiveEpilogue::to_underlying_arguments(args.epilogue),
|
161 |
+
hw_info,
|
162 |
+
TileScheduler::to_underlying_arguments(args.scheduler)
|
163 |
+
};
|
164 |
+
}
|
165 |
+
|
166 |
+
// Computes the kernel launch grid shape based on runtime parameters
|
167 |
+
static dim3
|
168 |
+
get_grid_shape(Params const& params) {
|
169 |
+
return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
|
170 |
+
}
|
171 |
+
|
172 |
+
static dim3
|
173 |
+
get_block_shape() {
|
174 |
+
return dim3(MaxThreadsPerBlock, 1, 1);
|
175 |
+
}
|
176 |
+
|
177 |
+
CUTLASS_DEVICE
|
178 |
+
void
|
179 |
+
operator()(Params const& params, char* smem_buf) {
|
180 |
+
|
181 |
+
static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
|
182 |
+
static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
|
183 |
+
static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
|
184 |
+
|
185 |
+
using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK;
|
186 |
+
using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV;
|
187 |
+
using MainloopPipelineVt = typename CollectiveMainloop::MainloopPipelineVt;
|
188 |
+
using MainloopPipelineKVNew = typename CollectiveMainloop::MainloopPipelineKVNew;
|
189 |
+
using PipelineState = typename CollectiveMainloop::PipelineState;
|
190 |
+
using PipelineParamsK = typename MainloopPipelineK::Params;
|
191 |
+
using PipelineParamsV = typename MainloopPipelineV::Params;
|
192 |
+
using PipelineParamsVt = typename MainloopPipelineVt::Params;
|
193 |
+
using PipelineParamsKVNew = typename MainloopPipelineKVNew::Params;
|
194 |
+
|
195 |
+
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
196 |
+
|
197 |
+
int const lane_predicate = cute::elect_one_sync();
|
198 |
+
int const warp_idx = cutlass::canonical_warp_idx_sync();
|
199 |
+
|
200 |
+
// Issue Tma Descriptor Prefetch from a single thread
|
201 |
+
if (warp_idx == 0 && lane_predicate) {
|
202 |
+
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
203 |
+
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
204 |
+
}
|
205 |
+
|
206 |
+
// Obtain warp index
|
207 |
+
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
208 |
+
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
209 |
+
|
210 |
+
if (warp_idx == 0 && lane_predicate) {
|
211 |
+
shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
|
212 |
+
if constexpr (HasQv) {
|
213 |
+
shared_storage.pipelines.barrier_Qv.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
|
214 |
+
}
|
215 |
+
shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/);
|
216 |
+
}
|
217 |
+
|
218 |
+
// We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
|
219 |
+
PipelineParamsK pipeline_params_k;
|
220 |
+
pipeline_params_k.role = warp_group_idx == 0
|
221 |
+
? MainloopPipelineK::ThreadCategory::Producer
|
222 |
+
: MainloopPipelineK::ThreadCategory::Consumer;
|
223 |
+
if constexpr (Use_TMA_KV) {
|
224 |
+
pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
|
225 |
+
pipeline_params_k.is_leader = warp_group_thread_idx == 0;
|
226 |
+
pipeline_params_k.num_consumers = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
|
227 |
+
} else {
|
228 |
+
pipeline_params_k.consumer_arv_count = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
|
229 |
+
pipeline_params_k.producer_arv_count = NumProducerThreads;
|
230 |
+
}
|
231 |
+
|
232 |
+
static_assert(is_same_v<PipelineParamsK, PipelineParamsVt>);
|
233 |
+
PipelineParamsVt pipeline_params_vt = pipeline_params_k;
|
234 |
+
if constexpr (Use_TMA_KV && !SameHeadDim) {
|
235 |
+
pipeline_params_vt.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
|
236 |
+
if constexpr (LargeHeadDimV) { pipeline_params_vt.num_consumers = NumMmaThreads; }
|
237 |
+
} else {
|
238 |
+
if constexpr (LargeHeadDimV) { pipeline_params_vt.consumer_arv_count = NumMmaThreads; }
|
239 |
+
}
|
240 |
+
|
241 |
+
MainloopPipelineK pipeline_k = [&] {
|
242 |
+
if constexpr (Use_TMA_KV) {
|
243 |
+
return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{});
|
244 |
+
} else {
|
245 |
+
return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k);
|
246 |
+
}
|
247 |
+
}();
|
248 |
+
// MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{});
|
249 |
+
MainloopPipelineV pipeline_v = [&] {
|
250 |
+
if constexpr (!Transpose_V) {
|
251 |
+
static_assert(is_same_v<PipelineParamsK, PipelineParamsV>);
|
252 |
+
if constexpr (Use_TMA_KV) {
|
253 |
+
return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt, ClusterShape{});
|
254 |
+
} else {
|
255 |
+
return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt);
|
256 |
+
}
|
257 |
+
} else {
|
258 |
+
PipelineParamsV pipeline_params_v;
|
259 |
+
pipeline_params_v.role = warp_group_idx == 0
|
260 |
+
? MainloopPipelineV::ThreadCategory::Producer
|
261 |
+
: MainloopPipelineV::ThreadCategory::Consumer;
|
262 |
+
pipeline_params_v.producer_arv_count = NumProducerThreads;
|
263 |
+
pipeline_params_v.consumer_arv_count = NumMmaThreads;
|
264 |
+
return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v);
|
265 |
+
}
|
266 |
+
}();
|
267 |
+
// If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then
|
268 |
+
// the producer WG will read from pipeline_vt and write to pipeline_v.
|
269 |
+
// If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used.
|
270 |
+
// Technically for pipeline_params_vt, warp0 of WG0 is the producer and all of WG0 are consumers.
|
271 |
+
// However, the thread role isn't used in the pipeline implementation.
|
272 |
+
MainloopPipelineVt pipeline_vt = [&] {
|
273 |
+
if constexpr (Use_TMA_KV) {
|
274 |
+
pipeline_params_vt.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG
|
275 |
+
return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt, ClusterShape{});
|
276 |
+
} else {
|
277 |
+
pipeline_params_vt.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG
|
278 |
+
return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt);
|
279 |
+
}
|
280 |
+
}();
|
281 |
+
|
282 |
+
PipelineParamsKVNew pipeline_params_kv_new;
|
283 |
+
pipeline_params_kv_new.role = warp_group_idx == 0
|
284 |
+
? MainloopPipelineKVNew::ThreadCategory::Producer
|
285 |
+
: MainloopPipelineKVNew::ThreadCategory::Consumer;
|
286 |
+
pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
|
287 |
+
pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0;
|
288 |
+
pipeline_params_kv_new.num_consumers = NumMmaThreads;
|
289 |
+
auto pipeline_k_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr);
|
290 |
+
if constexpr (!SameHeadDim) {
|
291 |
+
pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
|
292 |
+
}
|
293 |
+
auto pipeline_v_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr);
|
294 |
+
|
295 |
+
CollectiveMainloop mainloop;
|
296 |
+
CollectiveEpilogue epilogue;
|
297 |
+
|
298 |
+
// We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
|
299 |
+
if constexpr (size(ClusterShape{}) > 1) {
|
300 |
+
cute::cluster_arrive_relaxed();
|
301 |
+
cute::cluster_wait();
|
302 |
+
} else {
|
303 |
+
__syncthreads();
|
304 |
+
}
|
305 |
+
|
306 |
+
TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
|
307 |
+
|
308 |
+
if (warp_group_idx == 0) { // Producer
|
309 |
+
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
|
310 |
+
|
311 |
+
// The pipelines for AppendKV and main attention are different, since e.g. main attention
|
312 |
+
// might use cp.async to load KV (if PagedKVNonTMA) while AppendKV always uses TMA to load
|
313 |
+
// KV_new. Since the pipeline states are different, we have to manually sync to make
|
314 |
+
// sure the two pipelines don't race when accessing smem_k and smem_v.
|
315 |
+
PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipelineK>();
|
316 |
+
PipelineState smem_pipe_write_new = cutlass::make_producer_start_state<MainloopPipelineKVNew>();
|
317 |
+
int work_idx = 0;
|
318 |
+
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
|
319 |
+
static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp;
|
320 |
+
if constexpr (SingleProducerWarp) {
|
321 |
+
if (warp_idx_in_warpgroup != 0) { return; }
|
322 |
+
}
|
323 |
+
if (!SingleProducerWarp && warp_idx_in_warpgroup != 0) { scheduler.init_consumer(); }
|
324 |
+
|
325 |
+
cutlass::arch::wait_on_dependent_grids();
|
326 |
+
|
327 |
+
// Load Q, K, V
|
328 |
+
for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
|
329 |
+
work_tile_info.is_valid(params.scheduler);
|
330 |
+
work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
|
331 |
+
|
332 |
+
auto block_coord = work_tile_info.get_block_coord(params.scheduler);
|
333 |
+
SeqlenInfo_t seqlen_info{
|
334 |
+
get<2>(block_coord) /*bidb*/,
|
335 |
+
get<0>(params.mainloop.shape_Q),
|
336 |
+
!params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
|
337 |
+
get<0>(params.mainloop.shape_K_new),
|
338 |
+
params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
|
339 |
+
params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
|
340 |
+
params.mainloop.seqlens_rotary
|
341 |
+
};
|
342 |
+
if constexpr (AppendKV) {
|
343 |
+
bool tile_new_valid = mainloop.load_kv_new(
|
344 |
+
params.mainloop, pipeline_k_new, pipeline_v_new,
|
345 |
+
smem_pipe_write_new, shared_storage, seqlen_info, block_coord, work_idx);
|
346 |
+
if (tile_new_valid) {
|
347 |
+
// if (threadIdx.x == 0) { printf("Producer: Before sync\n"); }
|
348 |
+
cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);
|
349 |
+
// if (threadIdx.x == 0) { printf("Producer: After sync\n"); }
|
350 |
+
}
|
351 |
+
}
|
352 |
+
auto scheduler_prefetch = [&scheduler, ¶ms, &work_tile_info]() {
|
353 |
+
scheduler.prefetch_next_work(params.scheduler, work_tile_info);
|
354 |
+
};
|
355 |
+
// pipeline_vt won't be used if we don't need to transpose V.
|
356 |
+
mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write,
|
357 |
+
shared_storage, scheduler_prefetch, seqlen_info, block_coord, work_idx);
|
358 |
+
}
|
359 |
+
mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx);
|
360 |
+
} else { // Consumer
|
361 |
+
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
|
362 |
+
|
363 |
+
// Initialize matmul objects.
|
364 |
+
TiledMmaPV tiled_mma_pv;
|
365 |
+
|
366 |
+
PipelineState smem_pipe_read;
|
367 |
+
PipelineState smem_pipe_read_new;
|
368 |
+
// We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
|
369 |
+
// (like in Cutlass's gemm) because the read and release pipeline states are always the same.
|
370 |
+
|
371 |
+
scheduler.init_consumer();
|
372 |
+
mainloop.mma_init();
|
373 |
+
|
374 |
+
int work_idx = 0;
|
375 |
+
CUTLASS_PRAGMA_NO_UNROLL
|
376 |
+
for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
|
377 |
+
work_tile_info.is_valid(params.scheduler);
|
378 |
+
// get_next_work will be called before the epilogue
|
379 |
+
) {
|
380 |
+
auto block_coord = work_tile_info.get_block_coord(params.scheduler);
|
381 |
+
int const bidb = get<2>(block_coord);
|
382 |
+
SeqlenInfo_t seqlen_info{
|
383 |
+
bidb,
|
384 |
+
get<0>(params.mainloop.shape_Q),
|
385 |
+
!params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
|
386 |
+
get<0>(params.mainloop.shape_K_new),
|
387 |
+
params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
|
388 |
+
params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
|
389 |
+
params.mainloop.seqlens_rotary
|
390 |
+
};
|
391 |
+
if constexpr (AppendKV) {
|
392 |
+
bool tile_new_valid = mainloop.store_kv_new(
|
393 |
+
params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_read_new,
|
394 |
+
threadIdx.x - MmaThreadOffset, shared_storage, seqlen_info, block_coord);
|
395 |
+
if (tile_new_valid) {
|
396 |
+
// if (threadIdx.x == 128) { printf("Consumer: Before sync\n"); }
|
397 |
+
// We need this sync so that the gmem write from the consumers is visible to the producer
|
398 |
+
// that might do TMA read after that.
|
399 |
+
asm volatile ("fence.proxy.async.global;");
|
400 |
+
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);
|
401 |
+
// arrive is enough, we don't need sync. The producer will sync, which means
|
402 |
+
// after that sync we're guaranteed that the AppendKV pipeline have finished
|
403 |
+
// loading and consumer smem_k and smem_v.
|
404 |
+
// if (threadIdx.x == 128) { printf("Consumer: After sync\n"); }
|
405 |
+
}
|
406 |
+
}
|
407 |
+
// If there's tanh softcap, the scaling will be done before tanh.
|
408 |
+
float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
|
409 |
+
if constexpr (Is_FP8 && !Has_softcap) {
|
410 |
+
int const bidh = get<1>(block_coord);
|
411 |
+
int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
|
412 |
+
float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
|
413 |
+
float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
|
414 |
+
softmax_scale_log2 *= q_descale * k_descale;
|
415 |
+
}
|
416 |
+
flash::Softmax<!LargeHeadDimV ? 2 * (2 * kBlockM / NumMmaThreads) : 2, /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);
|
417 |
+
// Attention output (GEMM-II) accumulator.
|
418 |
+
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{}));
|
419 |
+
bool tile_valid;
|
420 |
+
if constexpr (!LargeHeadDimV) {
|
421 |
+
tile_valid = mainloop.mma(
|
422 |
+
params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
|
423 |
+
tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);
|
424 |
+
} else { // mma_pv might not compile if !LargeHeadDimV
|
425 |
+
if (warp_group_idx == 1) {
|
426 |
+
tile_valid = mainloop.mma(
|
427 |
+
params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
|
428 |
+
tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);
|
429 |
+
} else {
|
430 |
+
tile_valid = mainloop.mma_pv(
|
431 |
+
params.mainloop, pipeline_v, smem_pipe_read,
|
432 |
+
tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage);
|
433 |
+
}
|
434 |
+
}
|
435 |
+
// Do this here before the epilogue so that the next tile is ready to go.
|
436 |
+
work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info);
|
437 |
+
if constexpr (Split && Varlen) {
|
438 |
+
if (!work_tile_info.is_valid(params.scheduler)) { // Last tile
|
439 |
+
cutlass::arch::launch_dependent_grids();
|
440 |
+
}
|
441 |
+
}
|
442 |
+
if (tile_valid) {
|
443 |
+
// if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }
|
444 |
+
epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv,
|
445 |
+
threadIdx.x - MmaThreadOffset, block_coord);
|
446 |
+
} else {
|
447 |
+
// Write 0 to gO and -inf to gLSE.
|
448 |
+
epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord);
|
449 |
+
}
|
450 |
+
}
|
451 |
+
epilogue.store_tail();
|
452 |
+
}
|
453 |
+
|
454 |
+
}
|
455 |
+
|
456 |
+
};
|
457 |
+
|
458 |
+
} // namespace flash
|
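The Sm90 kernel above is warp-specialized: warp group 0 only stages K/V tiles (after shrinking its register budget with warpgroup_reg_dealloc), while the other warp groups run the MMAs and softmax (after warpgroup_reg_alloc), with CUTLASS pipelines coordinating the shared-memory stages. A deliberately stripped-down CUDA sketch of just the producer/consumer handshake, using a single cuda::barrier instead of the real multi-stage pipelines; all names and sizes here are illustrative, and the kernel assumes a 256-thread block (two warp groups):

    // Minimal producer/consumer split within one thread block. Not attention code:
    // warp group 0 stages a tile into shared memory, warp group 1 consumes it.
    #include <cuda/barrier>

    __global__ void producer_consumer_sketch(const float* in, float* out, int n) {
        __shared__ float tile[128];
        __shared__ cuda::barrier<cuda::thread_scope_block> ready;
        if (threadIdx.x == 0) { init(&ready, blockDim.x); }
        __syncthreads();

        int warp_group_idx = threadIdx.x / 128;
        if (warp_group_idx == 0) {
            // Producer warp group: stage the tile (the real kernel uses TMA or cp.async).
            int i = threadIdx.x;
            if (i < 128 && i < n) { tile[i] = in[i]; }
        }
        auto token = ready.arrive();   // every thread arrives...
        ready.wait(std::move(token));  // ...and waits until the tile is visible

        if (warp_group_idx != 0) {
            // Consumer warp group: use the staged tile (the real kernel runs the MMAs here).
            int i = threadIdx.x - 128;
            if (i < 128 && i < n) { out[i] = 2.f * tile[i]; }
        }
    }

The real kernel replaces the single barrier with per-buffer MainloopPipelineK/V states so that loads for the next tile overlap with compute on the current one.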
flash-attn/flash_fwd_launch_template.h
ADDED
@@ -0,0 +1,223 @@
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"  // For device_kernel
#include <cutlass/kernel_hardware_info.h>
#include "cutlass/cluster_launch.hpp"
#include "cutlass/kernel_launch.h"

#include "static_switch.h"
#include "flash.h"
#include "tile_size.h"
#include "tile_scheduler.hpp"
#include "flash_fwd_kernel_sm90.h"
#include "flash_fwd_kernel_sm80.h"
#include "mainloop_fwd_sm90_tma_gmma_ws.hpp"
#include "mainloop_fwd_sm80.hpp"
#include "epilogue_fwd.hpp"

using namespace cute;

template <int Arch, int kHeadDim, int kHeadDimV, int ClusterM, typename Element, typename ElementOut,
          bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool PagedKVNonTMA, bool AppendKV, bool HasQv,
          bool PackGQA, bool Split, bool V_colmajor>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time");
    static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time");
    static_assert(!(AppendKV && !Varlen), "AppendKV requires Varlen");
    static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;
    static constexpr bool FP8_TransposeV = Is_FP8 && !V_colmajor;
    using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;

    // Can't use structured binding since it's not compatible with constexpr
    static constexpr std::tuple<int, int, bool, bool> kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap);
    static constexpr std::tuple<int, int, int, int, bool> kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV);
    static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS);
    static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS);
    static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap);
    static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap);
    static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS);
    static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS);
    static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS);

    using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
    using TileShape_MNK_PV = cute::Shape<Int<kBlockM>, Int<kHeadDimV>, Int<kBlockN>>;
    using ClusterShape = cute::Shape<Int<ClusterM>, _1, _1>;
    using CollectiveMainloop = std::conditional_t<
        Arch >= 90,
        flash::CollectiveMainloopFwdSm90<kStages, ClusterShape, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm90, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, HasQv, MmaPV_is_RS, IntraWGOverlap, PackGQA, Split, V_colmajor>,
        flash::CollectiveMainloopFwdSm80<kNWarps, kStages, Q_in_regs, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm80, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, PackGQA, Split>
    >;
    using CollectiveEpilogue = flash::CollectiveEpilogueFwd<TileShape_MNK_PV, ClusterShape, ElementOut, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, PackGQA, Split, FP8_TransposeV>;

    static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads;
    using SchedulerPersistent = std::conditional_t<Varlen,
        flash::VarlenDynamicPersistentTileScheduler<kBlockM, CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/>,
        std::conditional_t<!Is_causal && !Is_local,
            flash::StaticPersistentTileScheduler<Split>,
            flash::DynamicPersistentTileScheduler<CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/>
        >
    >;
    using SchedulerSingleTile = flash::SingleTileScheduler<Varlen, Split, PackGQA, kBlockM>;
    // If Split then we probably don't have enough work for PersistentScheduler to be useful.
    // However, if Varlen (e.g., during decode where we have max_seqlens), using PersistentScheduler is better
    // since we'll avoid launching a bunch of thread blocks that immediately exit.
    // On Sm80, noncausal persistent seems a bit slower.
    static constexpr bool UsePersistentScheduler = Arch >= 90 ? !(Split && !Varlen) : ((Is_causal && !Varlen) || (Varlen && Split));
    using Scheduler = std::conditional_t<!UsePersistentScheduler, SchedulerSingleTile, SchedulerPersistent>;
    using AttnKernel = std::conditional_t<
        Arch >= 90,
        flash::enable_sm90_or_later<flash::FlashAttnFwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,
        flash::enable_sm80_to_sm89<flash::FlashAttnFwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
    >;

    bool const is_varlen_q = params.cu_seqlens_q;
    bool const is_varlen_k = params.cu_seqlens_k;
    bool const is_varlen_k_new = params.cu_seqlens_knew;
    int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
    int batch_q = !is_varlen_q ? params.b : 1;
    int batch_k = !is_varlen_k ? (params.kv_batch_idx ? params.b_k : params.b) : 1;
    typename CollectiveMainloop::StrideV v_strides =
        cute::conditional_return<!V_colmajor>(
            make_stride(params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0),
            make_stride(_1{}, params.v_dim_stride, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0));
    typename CollectiveMainloop::Arguments mainloop_args {
        static_cast<Element const*>(params.q_ptr),
        {seqlen_q, params.d, params.h, batch_q},  // shape_Q
        {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0},  // stride_Q
        static_cast<Element*>(params.k_ptr),
        {!params.page_table ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size,
         params.d, params.h_k, !params.page_table ? batch_k : params.num_pages},  // shape_K
        {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0},  // stride_K
        static_cast<Element*>(params.v_ptr),
        params.dv,  // headdim_v
        v_strides,  // stride_V
        static_cast<Element const*>(params.knew_ptr),
        {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1},  // shape_K_new
        {params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0},  // stride_K_new
        static_cast<Element const*>(params.vnew_ptr),
        {params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? params.vnew_batch_stride : 0},  // stride_V_new
        static_cast<Element const*>(params.qv_ptr),
        {params.qv_row_stride, _1{}, params.qv_head_stride, !is_varlen_q ? params.qv_batch_stride : 0},  // stride_Qv
        static_cast<Element const*>(params.rotary_cos_ptr),
        {params.seqlen_k, params.rotary_dim / 2},  // shape_rotary, the seqlen shape doesn't matter
        {params.rotary_dim / 2, _1{}},  // stride_rotary_cos
        static_cast<Element const*>(params.rotary_sin_ptr),
        {params.rotary_dim / 2, _1{}},  // stride_rotary_sin
        params.is_rotary_interleaved,
        params.page_table,
        // if page_size is not set, avoid dividing by zero
        {params.kv_batch_idx ? params.b_k : params.b, !params.page_table ? 0 : params.seqlen_k / params.page_size},  // shape_page_table
        {params.page_table_batch_stride, _1{}},  // stride_page_table
        params.scale_softmax,
        params.q_descale_ptr, params.k_descale_ptr, params.v_descale_ptr,
        {params.q_descale_batch_stride, params.q_descale_head_stride},
        {params.k_descale_batch_stride, params.k_descale_head_stride},
        {params.v_descale_batch_stride, params.v_descale_head_stride},
        params.window_size_left, params.window_size_right, params.attention_chunk,
        params.softcap,
        params.num_splits,
        params.kv_batch_idx,
        params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
        params.seqused_q, params.seqused_k,
        params.leftpad_k, params.seqlens_rotary
    };
    typename CollectiveEpilogue::Arguments epilogue_args {
        static_cast<ElementOut*>(params.o_ptr),
        {seqlen_q, params.dv, params.h, batch_q, params.num_splits},  // shape_O
        {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0, 0},  // stride_O
        static_cast<float*>(params.oaccum_ptr),
        {params.oaccum_row_stride, _1{}, params.oaccum_head_stride, !is_varlen_q ? params.oaccum_batch_stride : 0, params.oaccum_split_stride},  // stride_O_partial
        static_cast<float*>(params.softmax_lse_ptr),
        {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, 0},  // stride_LSE
        static_cast<float*>(params.softmax_lseaccum_ptr),
        {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, params.h * seqlen_q * batch_q},  // stride_LSE_partial
        params.h_k,
        params.cu_seqlens_q, params.seqused_q
    };

    int qhead_per_khead = !PackGQA ? 1 : cutlass::ceil_div(params.h, params.h_k);
    int num_blocks_m = cutlass::ceil_div(params.seqlen_q * qhead_per_khead, get<0>(TileShape_MNK{}));
    num_blocks_m = cutlass::round_up(num_blocks_m, size<0>(ClusterShape{}));
    typename flash::TileSchedulerArguments scheduler_args {
        num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits,
        params.h / params.h_k,
        params.seqlen_q,
        params.seqlen_k, params.d, params.dv, sizeof(Element),
        params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q,
        // params.num_m_blocks_ptr,
        params.num_splits_dynamic_ptr,
    };

    if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) {
        prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/);
        CHECK_CUDA_KERNEL_LAUNCH();
    }

    int device;
    CHECK_CUDA(cudaGetDevice(&device));
    typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
        mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args
    });

    dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
    dim3 block_dims = AttnKernel::get_block_shape();
    int smem_size = AttnKernel::SharedStorageSize;
    // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q));
    // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));
    // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
    // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
    // Get the ptr to kernel function.
    if constexpr (size(ClusterShape{}) > 1) {
        void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
        if (smem_size >= 48 * 1024) {
            CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
        dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
        cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
        cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params);
    } else {
        auto kernel = cutlass::device_kernel<AttnKernel>;
        if (smem_size >= 48 * 1024) {
            CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
189 |
+
}
|
190 |
+
// kernel<<<grid_dims, block_dims, smem_size, stream>>>(kernel_params);
|
191 |
+
cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params,
|
192 |
+
Arch >= 90 && Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/);
|
193 |
+
}
|
194 |
+
CHECK_CUDA_KERNEL_LAUNCH();
|
195 |
+
}
|
196 |
+
|
197 |
+
template<int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKVNonTMA, bool Has_softcap, bool PackGQA>
|
198 |
+
void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
199 |
+
static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported");
|
200 |
+
static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>;
|
201 |
+
using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>;
|
202 |
+
CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
|
203 |
+
VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] {
|
204 |
+
static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1;
|
205 |
+
VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] {
|
206 |
+
// Only needed here to decide if we should use cluster
|
207 |
+
static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128;
|
208 |
+
|
209 |
+
static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen;
|
210 |
+
BOOL_SWITCH(params.qv_ptr, HasQV_, [&] {
|
211 |
+
static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256;
|
212 |
+
APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] {
|
213 |
+
// Only use Cluster if number of tiles along seqlen_q is even and not varlen
|
214 |
+
CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] {
|
215 |
+
static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1;
|
216 |
+
run_flash_fwd<Arch, kHeadDim, kHeadDimV, ClusterM, T, T_out, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV && Varlen, HasQv, PackGQA, Split, V_colmajor>(params, stream);
|
217 |
+
});
|
218 |
+
});
|
219 |
+
});
|
220 |
+
});
|
221 |
+
});
|
222 |
+
});
|
223 |
+
}
|
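The launch path above raises the per-kernel dynamic shared-memory limit with cudaFuncSetAttribute whenever the tile needs more than the default 48 KB, before launching either through the cluster launcher or the plain launcher. A minimal standalone sketch of that opt-in pattern, with a hypothetical kernel and buffer names (not the FA3 kernel itself):

// smem_optin_sketch.cu -- illustrative only; names and sizes are made up.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void smem_demo_kernel(float* out) {
    extern __shared__ float smem[];            // dynamic shared memory
    smem[threadIdx.x] = float(threadIdx.x);
    __syncthreads();
    if (threadIdx.x == 0) { out[blockIdx.x] = smem[0]; }
}

int main() {
    int smem_size = 96 * 1024;                 // 96 KB, above the default 48 KB cap
    // Without this attribute, a launch requesting > 48 KB of dynamic smem fails.
    cudaFuncSetAttribute(smem_demo_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    float* out;
    cudaMalloc(&out, sizeof(float));
    smem_demo_kernel<<<1, 256, smem_size>>>(out);
    printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaFree(out);
    return 0;
}

On GPUs whose shared-memory capacity is below the requested size the launch still reports an error, which is why the real code queries AttnKernel::SharedStorageSize rather than hard-coding a value.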
flash-attn/flash_prepare_scheduler.cu
ADDED
@@ -0,0 +1,124 @@
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#include "cutlass/fast_math.h"
#include "cutlass/barrier.h"
#include "cutlass/arch/barrier.h"

#include "cutlass/arch/grid_dependency_control.h"

#include "flash.h"

namespace flash {

__global__ void prepare_varlen_num_blocks_kernel(
        int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static,
        int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new,
        int const* const seqused_q, int const* const seqused_k, int const* const leftpad_k_ptr,
        int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static,
        cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod,
        int* const tile_count_semaphore,
        // int* const num_m_blocks_ptr,
        int* const num_splits_dynamic_ptr,
        bool enable_pdl) {

    static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1;
    static constexpr int kSmemSize = 1;
    // Assume that there's only one block in the grid
    __shared__ int total_blocks_smem[kSmemSize];

    // There's only 1 block in the grid, so might as well start launching the main attn kernel
    if (enable_pdl) { cutlass::arch::launch_dependent_grids(); }

    if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; }
    __syncthreads();

    if (threadIdx.x == 0 && tile_count_semaphore) { *tile_count_semaphore = 0; }

    int lane = threadIdx.x % cutlass::NumThreadsPerWarp;

    auto get_num_m_blocks = [&](int bidb_start) {
        int batch_idx = lane + bidb_start;
        int seqlen;
        if (seqused_q) {
            seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0;
        } else if (cu_seqlens_q) {
            int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_q[batch_idx] : 0;
            int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);
            seqlen = next_cu_seqlen - cur_cu_seqlen;
        } else {
            seqlen = seqlen_q_static;
        }
        seqlen *= qhead_per_khead;
        return batch_idx < num_batch && lane < kNumBatchPerWarp
            ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0;
    };

    auto get_num_n_blocks = [&](int bidb_start) {
        int batch_idx = lane + bidb_start;
        int leftpad_k = batch_idx < num_batch && leftpad_k_ptr != nullptr ? leftpad_k_ptr[batch_idx] : 0;
        int seqlen;
        if (seqused_k) {
            seqlen = batch_idx < num_batch ? seqused_k[batch_idx] : 0;
        } else if (cu_seqlens_k) {
            int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_k[batch_idx] : 0;
            int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);
            seqlen = next_cu_seqlen - cur_cu_seqlen;
        } else {
            seqlen = seqlen_k_static;
        }
        int seqlen_new;
        if (cu_seqlens_k_new) {
            int cur_cu_seqlen_new = batch_idx <= num_batch ? cu_seqlens_k_new[batch_idx] : 0;
            int next_cu_seqlen_new = __shfl_down_sync(0xffffffff, cur_cu_seqlen_new, 1);
            seqlen_new = next_cu_seqlen_new - cur_cu_seqlen_new;
        } else {
            seqlen_new = seqlen_k_new_static;
        }
        // if (threadIdx.x == 0) { printf("seqlen = %d, seqlen_new = %d, leftpad_k = %d\n", seqlen, seqlen_new, leftpad_k); }
        seqlen = seqlen - leftpad_k + seqlen_new;
        return batch_idx < num_batch && lane < kNumBatchPerWarp
            ? blockn_divmod.div(seqlen + blockn_divmod.divisor - 1) : 0;
    };

    int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp;
    int bidb_start = kNumBatchPerWarp * warp_idx;
    int num_m_blocks = get_num_m_blocks(bidb_start);
    int num_n_blocks = get_num_n_blocks(bidb_start);

    int total_blocks = num_m_blocks * num_n_blocks;
    // Warp sum
    #pragma unroll
    for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) {
        total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i);
    }
    if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); }
    __syncthreads();
    total_blocks = total_blocks_smem[0];
    // 10% margin
    int blocks_per_sm = static_cast<int>(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm)));
    // blocks_per_sm = std::max(1, blocks_per_sm);  // 1 is the minimum number of blocks per SM
    int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1);
    if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) {
        num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic;
        // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic);
    }
}

} // flash

void prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bool packgqa,
                               int blockM, int blockN, bool enable_pdl) {
    // Only support batch <= 992 (32 warps, each with 31 batches)
    int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k);
    flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>(
        params.seqlen_q, params.seqlen_k, params.seqlen_knew,
        params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
        params.seqused_q, params.seqused_k, params.leftpad_k,
        params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits,
        cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN),
        params.tile_count_semaphore,
        // params.num_m_blocks_ptr,
        params.num_splits_dynamic_ptr, enable_pdl);
}
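Two warp idioms carry this kernel: each lane recovers its batch's length as the difference of adjacent cu_seqlens entries obtained via __shfl_down_sync (one global load per lane instead of two), and the per-batch tile counts are then tree-reduced inside the warp. A standalone toy sketch of those two idioms, with hypothetical names and sizes, not the scheduler itself:

// varlen_tiles_sketch.cu -- illustrative only.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void varlen_tiles_demo(const int* cu_seqlens, int num_batch, int block_n, int* total_tiles) {
    int lane = threadIdx.x % 32;
    // Lane i holds cu_seqlens[i]; shuffling down by 1 yields cu_seqlens[i + 1],
    // so seqlen_i = cu_seqlens[i + 1] - cu_seqlens[i] without a second load.
    int cur = lane <= num_batch ? cu_seqlens[lane] : 0;
    int next = __shfl_down_sync(0xffffffff, cur, 1);
    int seqlen = next - cur;
    int tiles = lane < num_batch ? (seqlen + block_n - 1) / block_n : 0;   // ceil-div to K tiles
    // Warp tree sum of per-batch tile counts; the total lands in lane 0.
    #pragma unroll
    for (int offset = 16; offset >= 1; offset /= 2) {
        tiles += __shfl_down_sync(0xffffffff, tiles, offset);
    }
    if (lane == 0) { *total_tiles = tiles; }
}

int main() {
    // Three variable-length sequences of 100, 250 and 7 tokens; block_n = 128.
    int h_cu[4] = {0, 100, 350, 357};
    int *d_cu, *d_total, h_total = 0;
    cudaMalloc(&d_cu, sizeof(h_cu));
    cudaMalloc(&d_total, sizeof(int));
    cudaMemcpy(d_cu, h_cu, sizeof(h_cu), cudaMemcpyHostToDevice);
    varlen_tiles_demo<<<1, 32>>>(d_cu, 3, 128, d_total);
    cudaMemcpy(&h_total, d_total, sizeof(int), cudaMemcpyDeviceToHost);
    printf("total K tiles = %d\n", h_total);   // 1 + 2 + 1 = 4
    cudaFree(d_cu); cudaFree(d_total);
    return 0;
}

The real kernel then converts the grid-wide tile total into blocks_per_sm with a 10% margin and clamps the per-batch split count to [1, num_splits_static], which is exactly the arithmetic in the last few lines above the write to num_splits_dynamic_ptr.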
flash-attn/heuristics.h
ADDED
@@ -0,0 +1,59 @@
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <vector>

inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) {
    // If varlen, we don't actually know seqlen_q but only max_seqlen_q.
    if (varlen_q) return true;
    // Heuristic: PackGQA is a bit slower but can help if seqlen_q is small or not near a multiple of kBlockM
    auto round_up = [](int a, int b) { return (a + b - 1) / b * b; };
    float nopack_gqa_efficiency = float(seqlen_q) / float(round_up(seqlen_q, blockM));
    float pack_gqa_efficiency = float(seqlen_q * qhead_per_khead) / float(round_up(seqlen_q * qhead_per_khead, blockM));
    return nopack_gqa_efficiency < 0.9 * pack_gqa_efficiency;
};

// Find the number of splits that maximizes the occupancy. For example, if we have
// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
// splits as that would incur more HBM reads/writes.
// So we find the best efficiency, then find the smallest number of splits that gets 85%
// of the best efficiency.
inline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) {
    // If we have enough to almost fill the SMs, then just use 1 split
    // However, in the case of super long seqlen where each head of KV doesn't even fit into
    // L2 (we assume that L2 size is 50MB), we want to split.
    if (total_mblocks >= 0.8f * num_SMs) {
        int const size_l2 = 50 * 1024 * 1024;
        // Only split if there are enough queries to go over the KV at least twice
        // Don't split if causal
        if (size_one_kv_head > size_l2 && num_m_blocks >= num_SMs * 2 && !is_causal_or_local) {
            return std::min((size_one_kv_head + size_l2 - 1) / size_l2, max_splits);
        } else {
            return 1;
        }
    }
    // If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512.
    if (num_n_blocks <= 4) { return 1; }
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
    float max_efficiency = 0.f;
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        float n_waves = float(total_mblocks * num_splits) / num_SMs;
        float eff = n_waves / ceil(n_waves);
        // printf("num_splits = %d, eff = %f\n", num_splits, eff);
        if (eff > max_efficiency) { max_efficiency = eff; }
        efficiency.push_back(eff);
    }
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
            // printf("num_splits chosen = %d\n", num_splits);
            return num_splits;
        }
    }
    return 1;
}
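To make the 48-tile / 108-SM example from the comment concrete, a small standalone program (values chosen only for illustration) that replays the efficiency loop of num_splits_heuristic:

// splits_worked_example.cpp -- illustrative arithmetic only.
#include <cmath>
#include <cstdio>

int main() {
    int total_mblocks = 48, num_SMs = 108, max_splits = 4;
    float max_eff = 0.f, eff[8] = {};
    for (int s = 1; s <= max_splits; ++s) {
        float n_waves = float(total_mblocks * s) / num_SMs;   // how many waves of work we produce
        eff[s] = n_waves / std::ceil(n_waves);                // how full the last wave is
        if (eff[s] > max_eff) max_eff = eff[s];
        printf("splits = %d -> efficiency = %.3f\n", s, eff[s]);
    }
    // Smallest split count within 85% of the best efficiency, mirroring the heuristic.
    for (int s = 1; s <= max_splits; ++s) {
        if (eff[s] >= 0.85f * max_eff) { printf("chosen num_splits = %d\n", s); break; }
    }
    return 0;
}

This prints efficiencies 0.444, 0.889, 0.667, 0.889 for 1 to 4 splits and chooses 2, matching the numbers quoted in the header comment; should_pack_gqa applies the same rounding-waste idea to the query tiles (e.g. seqlen_q = 1 with 8 query heads per KV head fills 8/128 of a 128-row tile when packed versus 1/128 unpacked, so PackGQA wins).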
flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu
ADDED
@@ -0,0 +1,18 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM128
template<>
void run_mha_bwd_<80, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<80, cutlass::bfloat16_t, false>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<86, cutlass::bfloat16_t, false>(params, stream);
}
#endif
#endif
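These generated files all follow the same shape: run_mha_bwd_ is only declared as a primary template, and each <arch, dtype, head-dim, softcap> combination gets an explicit specialization in its own translation unit, so the build can compile the configurations in parallel and the FLASHATTENTION_DISABLE_* macros can prune unwanted ones. A single-file sketch of that dispatch shape, with hypothetical names (not the FA3 sources):

// explicit_specialization_sketch.cpp -- illustrative only.
#include <cstdio>

struct Params { int n; };

// Primary template: declared but never defined; every supported tuple gets its own specialization.
template <int Arch, typename T, int HeadDim, bool Softcap>
void run_mha_bwd_demo(Params& params);

template <>
void run_mha_bwd_demo<90, float, 128, false>(Params& params) {
    std::printf("sm90, fp32, hdim128, no softcap, n = %d\n", params.n);
}

template <>
void run_mha_bwd_demo<80, float, 128, true>(Params& params) {
    std::printf("sm80, fp32, hdim128, softcap, n = %d\n", params.n);
}

int main() {
    Params p{42};
    run_mha_bwd_demo<90, float, 128, false>(p);  // resolves to the sm90 specialization
    run_mha_bwd_demo<80, float, 128, true>(p);
    return 0;
}

In the real tree the specializations live in separate .cu files (as below), which is what the "Splitting the different template instantiations to different files" comment refers to.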
flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu
ADDED
@@ -0,0 +1,12 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template<>
void run_mha_bwd_<90, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<90, cutlass::bfloat16_t, false>(params, stream);
}
#endif
flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu
ADDED
@@ -0,0 +1,18 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM128
template<>
void run_mha_bwd_<80, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<80, cutlass::bfloat16_t, true>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<86, cutlass::bfloat16_t, true>(params, stream);
}
#endif
#endif
flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu
ADDED
@@ -0,0 +1,12 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template<>
void run_mha_bwd_<90, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<90, cutlass::bfloat16_t, true>(params, stream);
}
#endif
flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu
ADDED
@@ -0,0 +1,6 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_hdim128_bf16_sm90.cu"
#include "flash_bwd_hdim128_bf16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu
ADDED
@@ -0,0 +1,18 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM128
template<>
void run_mha_bwd_<80, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<80, cutlass::half_t, false>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<86, cutlass::half_t, false>(params, stream);
}
#endif
#endif
flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu
ADDED
@@ -0,0 +1,12 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template<>
void run_mha_bwd_<90, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<90, cutlass::half_t, false>(params, stream);
}
#endif
flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu
ADDED
@@ -0,0 +1,18 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM128
template<>
void run_mha_bwd_<80, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<80, cutlass::half_t, true>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<86, cutlass::half_t, true>(params, stream);
}
#endif
#endif
flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu
ADDED
@@ -0,0 +1,12 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template<>
void run_mha_bwd_<90, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<90, cutlass::half_t, true>(params, stream);
}
#endif
flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu
ADDED
@@ -0,0 +1,6 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_hdim128_fp16_sm90.cu"
#include "flash_bwd_hdim128_fp16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu
ADDED
@@ -0,0 +1,18 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM192
template<>
void run_mha_bwd_<80, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<80, cutlass::bfloat16_t, false>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<86, cutlass::bfloat16_t, false>(params, stream);
}
#endif
#endif
flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu
ADDED
@@ -0,0 +1,12 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template<>
void run_mha_bwd_<90, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<90, cutlass::bfloat16_t, false>(params, stream);
}
#endif
flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu
ADDED
@@ -0,0 +1,18 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM192
template<>
void run_mha_bwd_<80, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<80, cutlass::bfloat16_t, true>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<86, cutlass::bfloat16_t, true>(params, stream);
}
#endif
#endif
flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu
ADDED
@@ -0,0 +1,12 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template<>
void run_mha_bwd_<90, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<90, cutlass::bfloat16_t, true>(params, stream);
}
#endif
flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu
ADDED
@@ -0,0 +1,6 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_hdim192_bf16_sm90.cu"
#include "flash_bwd_hdim192_bf16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu
ADDED
@@ -0,0 +1,18 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM192
template<>
void run_mha_bwd_<80, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<80, cutlass::half_t, false>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<86, cutlass::half_t, false>(params, stream);
}
#endif
#endif
flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu
ADDED
@@ -0,0 +1,12 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template<>
void run_mha_bwd_<90, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<90, cutlass::half_t, false>(params, stream);
}
#endif
flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu
ADDED
@@ -0,0 +1,18 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM192
template<>
void run_mha_bwd_<80, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<80, cutlass::half_t, true>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<86, cutlass::half_t, true>(params, stream);
}
#endif
#endif
flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu
ADDED
@@ -0,0 +1,12 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template<>
void run_mha_bwd_<90, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<90, cutlass::half_t, true>(params, stream);
}
#endif
flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu
ADDED
@@ -0,0 +1,6 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_hdim192_fp16_sm90.cu"
#include "flash_bwd_hdim192_fp16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu
ADDED
@@ -0,0 +1,18 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM256
template<>
void run_mha_bwd_<80, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<80, cutlass::bfloat16_t, false>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<86, cutlass::bfloat16_t, false>(params, stream);
}
#endif
#endif
flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu
ADDED
@@ -0,0 +1,12 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM256
template<>
void run_mha_bwd_<90, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<90, cutlass::bfloat16_t, false>(params, stream);
}
#endif
flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu
ADDED
@@ -0,0 +1,18 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM256
template<>
void run_mha_bwd_<80, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<80, cutlass::bfloat16_t, true>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<86, cutlass::bfloat16_t, true>(params, stream);
}
#endif
#endif
flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu
ADDED
@@ -0,0 +1,12 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM256
template<>
void run_mha_bwd_<90, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<90, cutlass::bfloat16_t, true>(params, stream);
}
#endif
flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu
ADDED
@@ -0,0 +1,6 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_hdim256_bf16_sm90.cu"
#include "flash_bwd_hdim256_bf16_softcap_sm90.cu"
flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu
ADDED
@@ -0,0 +1,18 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM256
template<>
void run_mha_bwd_<80, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<80, cutlass::half_t, false>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<86, cutlass::half_t, false>(params, stream);
}
#endif
#endif
flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu
ADDED
@@ -0,0 +1,12 @@
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM256
template<>
void run_mha_bwd_<90, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<90, cutlass::half_t, false>(params, stream);
}
#endif