danieldk (HF Staff) committed
Commit 0c60fb4 · Parent(s): 2491f56

Enable Torch 2.8 build

Files changed (3)
  1. build.toml +80 -68
  2. flake.lock +79 -27
  3. flake.nix +1 -1
build.toml CHANGED
@@ -1,80 +1,92 @@
 [general]
 name = "quantization_eetq"
+universal = false
 
 [torch]
 src = [
-  "torch-ext/torch_binding.cpp",
-  "torch-ext/torch_binding.h"
+  "torch-ext/torch_binding.cpp",
+  "torch-ext/torch_binding.h",
 ]
 
-[kernel.cutlass_kernels]
+[kernel.weight_only_batched_gemv]
+backend = "cuda"
+depends = [
+  "cutlass_2_10",
+  "torch",
+]
+include = ["cutlass_extensions/include"]
 src = [
-  "cutlass_extensions/include/cutlass_extensions/arch/mma.h",
-  "cutlass_extensions/include/cutlass_extensions/compute_occupancy.h",
-  "cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h",
-  "cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h",
-  "cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h",
-  "cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h",
-  "cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h",
-  "cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h",
-  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h",
-  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h",
-  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h",
-  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h",
-  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h",
-  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h",
-  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h",
-  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h",
-  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h",
-  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h",
-  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h",
-  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h",
-  "cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h",
-  "cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h",
-  "cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h",
-  "cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h",
-  "cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h",
-  "cutlass_kernels/cutlass_heuristic.cu",
-  "cutlass_kernels/cutlass_heuristic.h",
-  "cutlass_kernels/cutlass_preprocessors.cc",
-  "cutlass_kernels/cutlass_preprocessors.h",
-  "cutlass_kernels/fpA_intB_gemm.cu",
-  "cutlass_kernels/fpA_intB_gemm.h",
-  "cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h",
-  "cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h",
-  "cutlass_kernels/fpA_intB_gemm_wrapper.cu",
-  "cutlass_kernels/fpA_intB_gemm_wrapper.h",
-  "weightOnlyBatchedGemv/common.h",
-  "weightOnlyBatchedGemv/enabled.h",
-  "utils/activation_types.h",
-  "utils/cuda_utils.h",
-  "utils/logger.cc",
-  "utils/logger.h",
-  "utils/string_utils.h",
-  "utils/torch_utils.h",
+  "cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h",
+  "weightOnlyBatchedGemv/common.h",
+  "weightOnlyBatchedGemv/enabled.h",
+  "weightOnlyBatchedGemv/kernel.h",
+  "weightOnlyBatchedGemv/kernelLauncher.cu",
+  "weightOnlyBatchedGemv/kernelLauncher.h",
+  "weightOnlyBatchedGemv/utility.h",
+  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu",
+  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu",
+  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu",
+  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu",
+  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu",
+  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu",
+  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu",
+  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu",
 ]
-depends = [ "cutlass_2_10", "torch" ]
-include = [ ".", "utils", "cutlass_extensions/include" ]
 
-[kernel.weight_only_batched_gemv]
+[kernel.cutlass_kernels]
+backend = "cuda"
+depends = [
+  "cutlass_2_10",
+  "torch",
+]
+include = [
+  ".",
+  "utils",
+  "cutlass_extensions/include",
+]
 src = [
-  "cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h",
-  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h",
-  "weightOnlyBatchedGemv/common.h",
-  "weightOnlyBatchedGemv/enabled.h",
-  "weightOnlyBatchedGemv/kernel.h",
-  "weightOnlyBatchedGemv/kernelLauncher.cu",
-  "weightOnlyBatchedGemv/kernelLauncher.h",
-  "weightOnlyBatchedGemv/utility.h",
-  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu",
-  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu",
-  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu",
-  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu",
-  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu",
-  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu",
-  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu",
-  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu",
+  "cutlass_extensions/include/cutlass_extensions/arch/mma.h",
+  "cutlass_extensions/include/cutlass_extensions/compute_occupancy.h",
+  "cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h",
+  "cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h",
+  "cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h",
+  "cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h",
+  "cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h",
+  "cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h",
+  "cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h",
+  "cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h",
+  "cutlass_kernels/cutlass_heuristic.cu",
+  "cutlass_kernels/cutlass_heuristic.h",
+  "cutlass_kernels/cutlass_preprocessors.cc",
+  "cutlass_kernels/cutlass_preprocessors.h",
+  "cutlass_kernels/fpA_intB_gemm.cu",
+  "cutlass_kernels/fpA_intB_gemm.h",
+  "cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h",
+  "cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h",
+  "cutlass_kernels/fpA_intB_gemm_wrapper.cu",
+  "cutlass_kernels/fpA_intB_gemm_wrapper.h",
+  "weightOnlyBatchedGemv/common.h",
+  "weightOnlyBatchedGemv/enabled.h",
+  "utils/activation_types.h",
+  "utils/cuda_utils.h",
+  "utils/logger.cc",
+  "utils/logger.h",
+  "utils/string_utils.h",
+  "utils/torch_utils.h",
 ]
-depends = [ "cutlass_2_10", "torch" ]
-include = [ "cutlass_extensions/include" ]
-
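
Not part of this commit, but a minimal smoke-test sketch of the rebuilt extension, assuming the kernel is published under a Hub repo id like kernels-community/quantization-eetq (hypothetical here) and loaded with the kernels library's get_kernel(); the op names it prints depend on what torch-ext/torch_binding.cpp registers.

# Illustrative smoke test, not part of this commit. The repo id below is an
# assumption; adjust it to wherever the built kernel is published.
import torch
from kernels import get_kernel

print(torch.__version__)  # this commit targets a Torch 2.8 build

eetq = get_kernel("kernels-community/quantization-eetq")
# List the ops exposed by the loaded extension (registered in torch_binding.cpp).
print([name for name in dir(eetq) if not name.startswith("_")])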
 
 
 
flake.lock CHANGED
@@ -1,6 +1,21 @@
 {
   "nodes": {
     "flake-compat": {
+      "locked": {
+        "lastModified": 1747046372,
+        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
+        "type": "github"
+      },
+      "original": {
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "type": "github"
+      }
+    },
+    "flake-compat_2": {
       "locked": {
         "lastModified": 1733328505,
         "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
@@ -33,61 +48,83 @@
         "type": "github"
       }
     },
-    "kernel-builder": {
+    "flake-utils_2": {
       "inputs": {
-        "flake-compat": "flake-compat",
-        "flake-utils": "flake-utils",
-        "nixpkgs": "nixpkgs",
-        "rocm-nix": "rocm-nix"
+        "systems": "systems_2"
       },
       "locked": {
-        "lastModified": 1745320030,
-        "narHash": "sha256-HDGGPgp1pBi90zylndBySdL0XHuFtq+blv/0fH4g0q8=",
-        "owner": "huggingface",
-        "repo": "kernel-builder",
-        "rev": "c12ad49918de63907aaae26d4fe21150a463380b",
+        "lastModified": 1731533236,
+        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
         "type": "github"
       },
       "original": {
-        "owner": "huggingface",
-        "repo": "kernel-builder",
+        "owner": "numtide",
+        "repo": "flake-utils",
         "type": "github"
       }
     },
-    "nixpkgs": {
+    "hf-nix": {
+      "inputs": {
+        "flake-compat": "flake-compat_2",
+        "flake-utils": "flake-utils_2",
+        "nixpkgs": "nixpkgs"
+      },
       "locked": {
-        "lastModified": 1743559129,
-        "narHash": "sha256-7gpAWsENV3tY2HmeHYQ2MoQxGpys+jQWnkS/BHAMXVk=",
-        "owner": "nixos",
-        "repo": "nixpkgs",
-        "rev": "adae22bea8bcc0aa2fd6e8732044660fb7755f5e",
+        "lastModified": 1753354560,
+        "narHash": "sha256-vmOfRmr0Qm/IbZTWB2sBn+UFrABSTTA/cTg+m27Yt/E=",
+        "owner": "huggingface",
+        "repo": "hf-nix",
+        "rev": "7f2aceda2a2e72cd573bdb25e5c0667fd75f89d3",
         "type": "github"
       },
       "original": {
-        "owner": "nixos",
-        "ref": "nixos-unstable-small",
-        "repo": "nixpkgs",
+        "owner": "huggingface",
+        "repo": "hf-nix",
         "type": "github"
       }
     },
-    "rocm-nix": {
+    "kernel-builder": {
       "inputs": {
+        "flake-compat": "flake-compat",
+        "flake-utils": "flake-utils",
+        "hf-nix": "hf-nix",
         "nixpkgs": [
           "kernel-builder",
+          "hf-nix",
           "nixpkgs"
         ]
       },
       "locked": {
-        "lastModified": 1745310663,
-        "narHash": "sha256-1U3PzCO/jt7HUlEgLOY3RpxadKwTo6GSvb2j4m0UFw0=",
+        "lastModified": 1753354632,
+        "narHash": "sha256-31SX3Raiyx0qCuY9JSlx9ZZgxljeUxvW+JdujjxbofQ=",
         "owner": "huggingface",
-        "repo": "rocm-nix",
-        "rev": "e08373a0efa1c297b0c57af070e0a311df47481f",
+        "repo": "kernel-builder",
+        "rev": "524b628fd8e58525dbd28455bffb0628092c5265",
         "type": "github"
       },
       "original": {
         "owner": "huggingface",
-        "repo": "rocm-nix",
+        "ref": "torch-2.8",
+        "repo": "kernel-builder",
+        "type": "github"
+      }
+    },
+    "nixpkgs": {
+      "locked": {
+        "lastModified": 1752785354,
+        "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
+        "owner": "nixos",
+        "repo": "nixpkgs",
+        "rev": "d38025438a6ee456758dc03188ca6873a415463b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nixos",
+        "repo": "nixpkgs",
+        "rev": "d38025438a6ee456758dc03188ca6873a415463b",
         "type": "github"
       }
     },
@@ -110,6 +147,21 @@
         "repo": "default",
         "type": "github"
       }
+    },
+    "systems_2": {
+      "locked": {
+        "lastModified": 1681028828,
+        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+        "owner": "nix-systems",
+        "repo": "default",
+        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nix-systems",
+        "repo": "default",
+        "type": "github"
+      }
     }
   },
   "root": "root",
flake.nix CHANGED
@@ -2,7 +2,7 @@
   description = "Flake for EETQ kernels";
 
   inputs = {
-    kernel-builder.url = "github:huggingface/kernel-builder";
+    kernel-builder.url = "github:huggingface/kernel-builder/torch-2.8";
   };
 
   outputs =
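
Taken together, these changes pin the build to kernel-builder's torch-2.8 branch. As a quick illustration (not part of the commit), the resulting pin can be read straight from the updated flake.lock, which is plain JSON; the expected values are the ones introduced in the diff above.

# Illustrative check only, not part of this commit: read the pinned
# kernel-builder revision out of the updated flake.lock (plain JSON).
import json

with open("flake.lock") as f:
    lock = json.load(f)

kb = lock["nodes"]["kernel-builder"]
print(kb["original"].get("ref"))  # expected: "torch-2.8"
print(kb["locked"]["rev"])        # expected: "524b628fd8e58525dbd28455bffb0628092c5265"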