kernel: commit 2e75662 by danieldk (parent: a743610)

Various small fixes
flake.lock CHANGED
@@ -98,11 +98,11 @@
       ]
     },
     "locked": {
-      "lastModified": 1750275112,
-      "narHash": "sha256-gqAxmLLt0tYvuRYumOZHQgryMeEFdt6j3nEC8B5rT14=",
+      "lastModified": 1751014803,
+      "narHash": "sha256-9Xfq2k3uPfB602NwQF+zAY2GQZiKUN1G7Q6XiDCUR8Y=",
       "owner": "huggingface",
       "repo": "kernel-builder",
-      "rev": "1b63210b2a1fc3cda2e3a579e7aa8f8c8532626f",
+      "rev": "bbc4e712ff2046e217818e97de2201e2b996756e",
       "type": "github"
     },
     "original": {
flake.nix CHANGED
@@ -13,5 +13,37 @@
     kernel-builder.lib.genFlakeOutputs {
       path = ./.;
       rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+      # Building with CUDA later than 12.4 fails with:
+      #
+      #   error: 'ptxas' died due to signal 11 (Invalid memory reference)
+      #
+      # So, build for 12.4 only and copy to all the other build variants
+      # by hand (which works fine thanks to backward compat).
+      torchVersions = [
+        {
+          torchVersion = "2.6";
+          cudaVersion = "12.4";
+          cxx11Abi = false;
+          systems = [ "x86_64-linux" ];
+          upstreamVariant = true;
+        }
+        {
+          torchVersion = "2.6";
+          cudaVersion = "12.4";
+          cxx11Abi = true;
+          systems = [ "x86_64-linux" ];
+          upstreamVariant = true;
+        }
+        {
+          torchVersion = "2.7";
+          cudaVersion = "12.4";
+          cxx11Abi = true;
+          systems = [
+            "x86_64-linux"
+            "aarch64-linux"
+          ];
+          upstreamVariant = true;
+        }
+      ];
     };
 }
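The "copy to all the other build variants by hand" step mentioned in the comment is not part of this commit. As a rough illustration only, it amounts to duplicating the CUDA 12.4 build output into the directories of the other variants; the variant directory names below follow the usual kernel-builder build/torch<ver>-cxx11-cu<ver>-<arch> layout and are assumptions, not taken from this repository.

# Hypothetical sketch: replicate the CUDA 12.4 build output to other CUDA
# variants, relying on CUDA backward compatibility. Directory names are
# illustrative assumptions.
import shutil
from pathlib import Path

build = Path("build")
source = "torch27-cxx11-cu124-x86_64-linux"    # variant actually built
targets = [
    "torch27-cxx11-cu126-x86_64-linux",        # assumed extra variants
    "torch27-cxx11-cu128-x86_64-linux",
]

for target in targets:
    shutil.copytree(build / source, build / target, dirs_exist_ok=True)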
torch-ext/{flash_attn → flash_attn3}/__init__.py RENAMED
File without changes
torch-ext/{flash_attn → flash_attn3}/flash_attn_interface.py RENAMED
File without changes
torch-ext/torch_binding.cpp CHANGED
@@ -5,7 +5,7 @@
 #include "torch_binding.h"
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
-  m.def("fwd("
+  ops.def("fwd("
         "Tensor q,"
         "Tensor k,"
         "Tensor v,"
@@ -40,7 +40,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
         "int num_splits = 0,"
         "bool? pack_gqa = None,"
         "int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)");
-  m.def("bwd("
+  ops.def("bwd("
         "Tensor dout,"
         "Tensor q,"
         "Tensor k,"
@@ -63,12 +63,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
         "float softcap = 0.0,"
         "bool deterministic = False,"
         "int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)");
-  m.def("fwd_combine("
+  ops.def("fwd_combine("
         "Tensor out_partial,"
         "Tensor lse_partial,"
         "Tensor(out!)? out = None,"
         "ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)");
-  m.def("get_scheduler_metadata("
+  ops.def("get_scheduler_metadata("
         "int batch_size,"
         "int max_seqlen_q,"
         "int max_seqlen_k,"
@@ -94,10 +94,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
         "bool? pack_gqa = None,"
         "int sm_margin = 0) -> Tensor");
 
-  m.impl("fwd", &mha_fwd);
-  m.impl("bwd", &mha_bwd);
-  m.impl("fwd_combine", &mha_combine);
-  m.impl("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata);
+  ops.impl("fwd", &mha_fwd);
+  ops.impl("bwd", &mha_bwd);
+  ops.impl("fwd_combine", &mha_combine);
+  ops.impl("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata);
 }
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
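For context on this fix: TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) names its registration object `ops`, so the earlier `m.def(...)` / `m.impl(...)` calls referenced an identifier that does not exist in that scope; the change simply uses the correct parameter name. Once registered, the schemas become visible through PyTorch's operator registry. A minimal sketch of how they could be reached from Python, assuming the extension has been built and imported and that its namespace resolves to `flash_attn3` (the real namespace is whatever TORCH_EXTENSION_NAME expands to at build time, and downstream code is expected to go through torch-ext/flash_attn3/flash_attn_interface.py instead):

# Illustrative only, not part of this commit: ops registered with ops.def()
# and ops.impl() are reachable via torch.ops once the extension is loaded.
# "flash_attn3" is an assumed namespace name.
import torch

ops_ns = torch.ops.flash_attn3    # assumed namespace from TORCH_EXTENSION_NAME
print(ops_ns.fwd)                 # op defined via ops.def("fwd(...)")
print(ops_ns.fwd_combine)         # op implemented by &mha_combine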