Commit c1e53ae
Parent(s): acd39ac

Add support for XPU to run gpt-oss

Files changed:
- torch-ext/triton_kernels/matmul_ogs.py +2 -1
- torch-ext/triton_kernels/matmul_ogs_details/_common.py +13 -1
- torch-ext/triton_kernels/matmul_ogs_details/_finalize_matmul.py +6 -5
- torch-ext/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py +3 -3
- torch-ext/triton_kernels/matmul_ogs_details/opt_flags.py +80 -1
- torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py +41 -0
- torch-ext/triton_kernels/numerics_details/flexpoint.py +2 -1
- torch-ext/triton_kernels/target_info.py +47 -26
torch-ext/triton_kernels/matmul_ogs.py
CHANGED
@@ -602,6 +602,7 @@ def matmul_ogs_torch(x, w, bias,
                      betas = None,
                      gammas = None,
                      round_x = None, round_y = None,
+                     device: str = "cuda",
                      ):
     is_input_batched = x.ndim == 3
     assert x.dtype.itemsize > 1
@@ -641,7 +642,7 @@ def matmul_ogs_torch(x, w, bias,
         else:
             idx = gather_indx.src_indx[lo:hi] // n_expts_act
         batch = i if is_input_batched else 0
-        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(),
+        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device=device)).float(),
                            w[i].float())
         if bias is not None:
             out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]
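Note on the change above: the reference implementation previously hard-coded the device for the `torch.arange` used in rounding; the new `device` parameter lets the same reference path run on XPU. A minimal, hypothetical call-site sketch (the `pick_device` helper is illustrative and not part of this commit):

import torch

def pick_device() -> str:
    # Illustrative preference order: XPU if present, then CUDA, else CPU.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

# Hypothetical call: forward the chosen device so the reference matmul's
# torch.arange indices are allocated alongside the inputs.
# out = matmul_ogs_torch(x, w, bias, device=pick_device())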
torch-ext/triton_kernels/matmul_ogs_details/_common.py
CHANGED
@@ -7,9 +7,21 @@ from triton.tools.tensor_descriptor import TensorDescriptor
 # -----------------------------------------------------------------------------
 # Utilities
 # -----------------------------------------------------------------------------
+try:
+    _ver_str = getattr(triton, "__version__", "0.0.0").split("+")[0]
+    _parts = _ver_str.split(".")
+    _ver_tuple = tuple(int(p) for p in (_parts + ["0", "0", "0"])[:3])
+except Exception:
+    _ver_tuple = (0, 0, 0)
 
+if _ver_tuple > (3, 4, 0) and hasattr(triton, "constexpr_function"):
+    _constexpr_function = triton.constexpr_function
+else:
+    _constexpr_function = tl.constexpr_function
 
-@tl.constexpr_function
+
+@_constexpr_function
 def get_scaled_dot_format_string(dtype: tl.dtype):
     mapping = {
         tl.float16: "fp16",
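The shim above resolves to `triton.constexpr_function` on Triton releases newer than 3.4.0 that expose it, and to `tl.constexpr_function` otherwise, so the decorated helpers keep working across Triton versions. A hedged sketch of how such a decorated helper can then be used from a kernel (the kernel and helper below are illustrative, not part of the commit; the import assumes the torch-ext tree is installed as the `triton_kernels` package):

import triton
import triton.language as tl
from triton_kernels.matmul_ogs_details._common import _constexpr_function

@_constexpr_function
def bytes_per_block(block: int, elem_bytes: int) -> int:
    # Evaluated at kernel compile time; the result folds into a constexpr.
    return block * elem_bytes

@triton.jit
def zero_fill(ptr, n, BLOCK: tl.constexpr):
    # The constexpr helper is callable with constexpr arguments inside the kernel.
    tl.static_assert(bytes_per_block(BLOCK, 4) <= 65536)
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    tl.store(ptr + offs, tl.zeros([BLOCK], dtype=tl.float32), mask=offs < n)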
torch-ext/triton_kernels/matmul_ogs_details/_finalize_matmul.py
CHANGED
@@ -4,25 +4,26 @@ from ..numerics_details.flexpoint import float_to_flex, load_scale, update_scale
 from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 from ..target_info import cuda_capability_geq as _cuda_capability_geq
 from ..target_info import is_hip as _is_hip
+from ._common import _constexpr_function
 
 
 # fmt: off
-@tl.constexpr_function
+@_constexpr_function
 def is_hip():
     return _is_hip()
 
 
-@tl.constexpr_function
+@_constexpr_function
 def cuda_capability_geq(x, y):
     return _cuda_capability_geq(x, y)
 
 
-@tl.constexpr_function
+@_constexpr_function
 def log2(n):
     return len(bin(n)) - 3
 
 
-@tl.constexpr_function
+@_constexpr_function
 def _permute_to_end_order(n: int, axis: int):
     """
     Returns the order of the axes of a tensor to permute `axis` to the end.
@@ -105,7 +106,7 @@ def _finalize_matmul_launch_metadata(grid, kernel, args):
     return ret
 
 
-@tl.constexpr_function
+@_constexpr_function
 def _accumulate_f16_into_f32_and_track_absmax_ptx(n_inputs: int, src_type: str, absmax_reg_name: str | None):
     """
     Generate PTX code to take fp16 inputs and sum them into an f32 accumulator using mixed-precision
torch-ext/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
CHANGED
@@ -12,14 +12,14 @@ from ..numerics_details.flexpoint import (
     compute_scale,
 )
 from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
-from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
+from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string, _constexpr_function
 
 
-@tl.constexpr_function
+@_constexpr_function
 def cuda_capability_geq(major, minor):
     return target_info.cuda_capability_geq(major, minor)
 
 
-@tl.constexpr_function
+@_constexpr_function
 def get_dtype(tensor_or_desc: tl.tensor | tl.tensor_descriptor) -> tl.dtype:
     if isinstance(tensor_or_desc, tl.tensor):
         return tensor_or_desc.dtype.element_ty
torch-ext/triton_kernels/matmul_ogs_details/opt_flags.py
CHANGED
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 import triton
 from ..target_info import get_cdna_version
 import torch
-from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
+from .opt_flags_details import opt_flags_amd, opt_flags_nvidia, opt_flags_intel
 
 
 @dataclass
@@ -30,6 +30,83 @@ class OptFlags:
     raise ValueError("Not supported")
 
 
+def make_default_opt_flags_intel(
+    out_dtype,
+    lhs_dtype,
+    rhs_dtype,
+    precision_config,
+    m,
+    n,
+    k,
+    routing_data,
+    can_use_persistent_tma,
+    can_use_fused_scatter,
+    enforce_bitwise_invariance,
+    epilogue_effective_itemsize,
+    constraints,
+):
+    constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages"]
+    assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
+    # tokens per expert
+    if routing_data is None:
+        tokens_per_expt = m
+    elif routing_data.expected_tokens_per_expt is None:
+        tokens_per_expt = max(1, m // routing_data.n_expts_tot)
+    else:
+        tokens_per_expt = routing_data.expected_tokens_per_expt
+    # pid swizzling
+    group_m = 8
+    xcd_swizzle = 1
+    # block_m
+    if constraints.get("block_m", None):
+        block_m = constraints["block_m"]
+    elif enforce_bitwise_invariance:
+        block_m = 128
+    else:
+        block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
+    # block n
+    block_n = opt_flags_intel.compute_block_n(n)
+    # is_persistent
+    is_persistent = constraints.get("is_persistent", False)
+    # block k
+    if constraints.get("block_k", None) is not None:
+        block_k = constraints["block_k"]
+    else:
+        block_k = opt_flags_intel.compute_block_k(k, is_persistent, precision_config)
+    # split_k
+    if constraints.get("split_k", None) is not None:
+        split_k = constraints["split_k"]
+    elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
+        split_k = 1
+    else:
+        estimated_actual_grid_size = opt_flags_intel.compute_grid_size(None, m, n, block_m, block_n)
+        split_k = opt_flags_intel.compute_split_k(block_k, k, estimated_actual_grid_size)
+
+    epilogue_subtile = constraints.get('epilogue_subtile', None)
+    if epilogue_subtile is None:
+        epilogue_subtile = 1
+
+    ret = OptFlags(
+        block_m=block_m,
+        block_n=block_n,
+        block_k=block_k,
+        num_warps=opt_flags_intel.compute_num_warps(block_m, block_n),
+        num_stages=constraints.get("num_stages", 2),
+        fused_scatter=constraints.get('fused_scatter', False),
+        group_m=group_m,
+        xcd_swizzle=xcd_swizzle,
+        w_cache_modifier=None,
+        split_k=split_k,
+        is_persistent=is_persistent,
+        epilogue_subtile=epilogue_subtile,
+        arch=None,
+        target_kernel_kwargs=dict(),
+        idle_sms=0,
+    )
+    # check constraints
+    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+    return ret
+
+
 def make_default_opt_flags_amd(
     out_dtype,
@@ -292,6 +369,8 @@ def make_opt_flags(
             enforce_bitwise_invariance, epilogue_effective_itemsize,
             _opt_flags_constraints]
     backend = triton.runtime.driver.active.get_current_target().backend
+    if backend == "xpu":
+        return make_default_opt_flags_intel(*args)
     if backend == "hip":
         return make_default_opt_flags_amd(*args)
     if backend == "cuda":
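For reference, the Intel `block_m` heuristic above scales the tile height with the expected tokens per expert, clamped to [16, 128]. A small worked example (the token and expert counts are illustrative):

import triton

# e.g. 1000 routed tokens across 32 experts -> tokens_per_expt = 31,
# next_power_of_2(31) = 32, so block_m = max(16, min(32, 128)) = 32.
tokens_per_expt = max(1, 1000 // 32)
block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
assert block_m == 32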
torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py
ADDED
@@ -0,0 +1,41 @@
+import torch
+import triton
+
+
+def compute_grid_size(routing_data, m, n, block_m, block_n):
+    if routing_data is not None:
+        grid_m = routing_data.n_blocks(m, block_m)
+    else:
+        grid_m = triton.cdiv(m, block_m)
+    grid_n = (n + block_n - 1) // block_n
+    return grid_m * grid_n
+
+
+def compute_block_n(n: int):
+    # block_n:
+    return max(16, min(128, triton.next_power_of_2(n)))
+
+
+def compute_block_k(k: int | None, is_persistent: bool, precision_config):
+    if k is not None:
+        block_k = max(32, min(128, triton.next_power_of_2(k)))
+    has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None
+    if is_persistent and has_mx_weight_scale:
+        block_k = min(block_k, 128)
+    return block_k
+
+
+def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int:
+    device_props = torch.xpu.get_device_properties(0)
+    n_sms = device_props.gpu_subslice_count
+    split_k = n_sms // grid_size
+    if k is not None:
+        # avoid split_k for small k
+        num_block_k = triton.cdiv(k, block_k)
+        split_k = min(split_k, num_block_k // 4)
+    split_k = max(split_k, 1)
+    return split_k
+
+
+def compute_num_warps(block_m, block_n):
+    return max(block_m * block_n // 4096, 4)
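A hedged usage sketch of the new helpers (shapes are illustrative; the import assumes the torch-ext tree is installed as the `triton_kernels` package, and `precision_config=None` is passed only because `compute_block_k` merely checks for a `weight_scale` when a config is present):

import torch
from triton_kernels.matmul_ogs_details.opt_flags_details import opt_flags_intel

m, n, k = 512, 4096, 4096          # illustrative GEMM shape
block_m = 64                        # e.g. from the block_m heuristic in opt_flags.py
block_n = opt_flags_intel.compute_block_n(n)                            # -> 128
block_k = opt_flags_intel.compute_block_k(k, False, None)               # -> 128
grid = opt_flags_intel.compute_grid_size(None, m, n, block_m, block_n)  # -> 8 * 32 = 256
num_warps = opt_flags_intel.compute_num_warps(block_m, block_n)         # -> max(2, 4) = 4

if hasattr(torch, "xpu") and torch.xpu.is_available():
    # split_k queries the XPU subslice count, so it needs a real XPU device.
    split_k = opt_flags_intel.compute_split_k(block_k, k, grid)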
torch-ext/triton_kernels/numerics_details/flexpoint.py
CHANGED
@@ -1,5 +1,6 @@
 from ..numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
 from .. import target_info
+from ..matmul_ogs_details._common import _constexpr_function
 import triton
 import triton.language as tl
 
@@ -52,7 +53,7 @@ def rcp_max_finite(dtype):
         tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")
 
 
-@tl.constexpr_function
+@_constexpr_function
 def cuda_capability_geq(major, minor):
     return target_info.cuda_capability_geq(major, minor)
 
torch-ext/triton_kernels/target_info.py
CHANGED
@@ -1,54 +1,70 @@
 import torch
 import triton
 
+from .matmul_ogs_details._common import _constexpr_function
+from triton.runtime import driver
 
+def current_target():
+    try:
+        active_driver = driver.active
+    except RuntimeError:
+        # If there is no active driver, return None
+        return None
+    return active_driver.get_current_target()
 
+current_target.__triton_builtin__ = True
+
+
+@_constexpr_function
 def is_cuda():
-    cached_capabilities["is_cuda"] = False if target is None else target.backend == "cuda"
-    return cached_capabilities["is_cuda"]
+    target = current_target()
+    return target is not None and target.backend == "cuda"
 
 
+@_constexpr_function
 def is_hip():
+    target = current_target()
+    return target is not None and target.backend == "hip"
+
+
+@_constexpr_function
+def is_xpu():
+    target = current_target()
+    return target is not None and target.backend == "xpu"
 
 
+@_constexpr_function
 def is_hip_cdna3():
-    cached_capabilities["is_hip_cdna3"] = (target is not None and target.backend == 'hip'
-                                           and target.arch == 'gfx942')
-    return cached_capabilities["is_hip_cdna3"]
+    target = current_target()
+    return target is not None and target.arch == "gfx942"
 
 
+@_constexpr_function
 def is_hip_cdna4():
-    cached_capabilities["is_hip_cdna4"] = (target is not None and target.backend == 'hip'
-                                           and target.arch == 'gfx950')
-    return cached_capabilities["is_hip_cdna4"]
+    target = current_target()
+    return target is not None and target.arch == "gfx950"
 
 
+@_constexpr_function
 def cuda_capability_geq(major, minor=0):
     """
     Determines whether we have compute capability >= (major, minor) and
     returns this as a constexpr boolean. This can be used for guarding
     inline asm implementations that require a certain compute capability.
     """
+    """
+    Determines whether we have compute capability >= (major, minor) and
+    returns this as a constexpr boolean. This can be used for guarding
+    inline asm implementations that require a certain compute capability.
+    """
+    target = current_target()
+    if target is None or target.backend != "cuda":
         return False
-    if torch.cuda.is_available():
-        cached_capabilities["cuda"] = torch.cuda.get_device_capability()
-    else:
-        cached_capabilities["cuda"] = (0, 0)
-    return cached_capabilities["cuda"] >= (major, minor)
+    assert isinstance(target.arch, int)
+    return target.arch >= major * 10 + minor
 
 
+@_constexpr_function
 def get_cdna_version():
     """
     Gets the AMD architecture version, i.e. CDNA3 or CDNA4, currently
@@ -65,13 +81,18 @@ def get_cdna_version():
     return -1
 
 
+@_constexpr_function
 def has_tma_gather():
     return cuda_capability_geq(10, 0)
 
 
+@_constexpr_function
 def has_native_mxfp():
     return cuda_capability_geq(10, 0)
 
 
 def num_sms():
-    return torch.cuda.get_device_properties(0).multi_processor_count
+    if is_cuda():
+        return torch.cuda.get_device_properties(0).multi_processor_count
+    if is_xpu():
+        return torch.xpu.get_device_properties(0).max_compute_units