Various small fixes
flake.lock
CHANGED
@@ -98,11 +98,11 @@
         ]
       },
       "locked": {
-        "lastModified":
-        "narHash": "sha256-
+        "lastModified": 1751014803,
+        "narHash": "sha256-9Xfq2k3uPfB602NwQF+zAY2GQZiKUN1G7Q6XiDCUR8Y=",
         "owner": "huggingface",
         "repo": "kernel-builder",
-        "rev": "
+        "rev": "bbc4e712ff2046e217818e97de2201e2b996756e",
         "type": "github"
       },
       "original": {
flake.nix
CHANGED
@@ -13,5 +13,37 @@
     kernel-builder.lib.genFlakeOutputs {
       path = ./.;
       rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+      # Building with CUDA later than 12.4 fails with:
+      #
+      #   error: 'ptxas' died due to signal 11 (Invalid memory reference)
+      #
+      # So, build for 12.4 only and copy to all the other build variants
+      # by hand (which works fine thanks to backward compat).
+      torchVersions = [
+        {
+          torchVersion = "2.6";
+          cudaVersion = "12.4";
+          cxx11Abi = false;
+          systems = [ "x86_64-linux" ];
+          upstreamVariant = true;
+        }
+        {
+          torchVersion = "2.6";
+          cudaVersion = "12.4";
+          cxx11Abi = true;
+          systems = [ "x86_64-linux" ];
+          upstreamVariant = true;
+        }
+        {
+          torchVersion = "2.7";
+          cudaVersion = "12.4";
+          cxx11Abi = true;
+          systems = [
+            "x86_64-linux"
+            "aarch64-linux"
+          ];
+          upstreamVariant = true;
+        }
+      ];
     };
 }
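The comment in the hunk above says the CUDA 12.4 build output is copied to the other build variants by hand. A minimal sketch of that step, assuming the usual build/<variant> directory layout produced by kernel-builder; the cu126/cu128 variant names below are illustrative, not something this change defines:

    import shutil
    from pathlib import Path

    # Variant actually produced by the CUDA 12.4 build.
    built = Path("build/torch27-cxx11-cu124-x86_64-linux")

    # Hypothetical newer-CUDA variant names; per the comment, the 12.4
    # binaries also work there thanks to CUDA backward compatibility.
    for target in ["torch27-cxx11-cu126-x86_64-linux",
                   "torch27-cxx11-cu128-x86_64-linux"]:
        shutil.copytree(built, Path("build") / target, dirs_exist_ok=True)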
torch-ext/{flash_attn → flash_attn3}/__init__.py
RENAMED
File without changes

torch-ext/{flash_attn → flash_attn3}/flash_attn_interface.py
RENAMED
File without changes
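Since these files only moved, the practical effect of the rename is the import path: the extension package is now flash_attn3 rather than flash_attn. A hedged sketch, assuming the built package is importable and that flash_attn_interface exposes a flash_attn_func entry point (an assumption, not shown in this diff):

    # was: from flash_attn.flash_attn_interface import flash_attn_func
    from flash_attn3.flash_attn_interface import flash_attn_func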
torch-ext/torch_binding.cpp
CHANGED
@@ -5,7 +5,7 @@
 #include "torch_binding.h"
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
-
+  ops.def("fwd("
           "Tensor q,"
           "Tensor k,"
           "Tensor v,"
@@ -40,7 +40,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
           "int num_splits = 0,"
           "bool? pack_gqa = None,"
           "int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)");
-
+  ops.def("bwd("
           "Tensor dout,"
           "Tensor q,"
           "Tensor k,"
@@ -63,12 +63,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
           "float softcap = 0.0,"
           "bool deterministic = False,"
           "int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)");
-
+  ops.def("fwd_combine("
           "Tensor out_partial,"
           "Tensor lse_partial,"
           "Tensor(out!)? out = None,"
           "ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)");
-
+  ops.def("get_scheduler_metadata("
           "int batch_size,"
           "int max_seqlen_q,"
           "int max_seqlen_k,"
@@ -94,10 +94,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
           "bool? pack_gqa = None,"
           "int sm_margin = 0) -> Tensor");
 
-
-
-
-
+  ops.impl("fwd", &mha_fwd);
+  ops.impl("bwd", &mha_bwd);
+  ops.impl("fwd_combine", &mha_combine);
+  ops.impl("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata);
 }
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
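The hunks above split each op into a schema declaration (ops.def) and an implementation binding (ops.impl) inside TORCH_LIBRARY_EXPAND. A self-contained Python sketch of the same def/impl split using torch.library; the "demo" namespace and the SDPA stand-in body are illustrative only, not the extension's real registration:

    import torch

    # Declare the schema, analogous to ops.def("fwd(...)") in the C++ hunk above.
    lib = torch.library.Library("demo", "DEF")
    lib.define("fwd(Tensor q, Tensor k, Tensor v) -> Tensor")

    # Bind an implementation, analogous to ops.impl("fwd", &mha_fwd).
    def fwd_impl(q, k, v):
        # Toy stand-in for the real attention kernel.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

    lib.impl("fwd", fwd_impl, "CompositeExplicitAutograd")

    # Once registered, the op is reachable under the library namespace.
    q = k = v = torch.randn(1, 2, 4, 8)
    out = torch.ops.demo.fwd(q, k, v)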