YangKai0616 committed
Commit c1e53ae · 1 Parent(s): acd39ac

Add support for XPU to run gpt-oss
torch-ext/triton_kernels/matmul_ogs.py CHANGED
@@ -602,6 +602,7 @@ def matmul_ogs_torch(x, w, bias,
                     betas = None,
                     gammas = None,
                     round_x = None, round_y = None,
+                    device: str = "cuda",
                     ):
    is_input_batched = x.ndim == 3
    assert x.dtype.itemsize > 1
@@ -641,7 +642,7 @@ def matmul_ogs_torch(x, w, bias,
        else:
            idx = gather_indx.src_indx[lo:hi] // n_expts_act
        batch = i if is_input_batched else 0
-       out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(),
+       out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device=device)).float(),
                           w[i].float())
        if bias is not None:
            out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]
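The hunk above threads a `device` argument through the reference path so index tensors are no longer pinned to CUDA. A minimal sketch of the pattern, assuming an XPU-enabled PyTorch build; `reference_row_matmul` and the identity `round_x` stub are illustrative only, not functions from this repo:

import torch

def round_x(t, indices):
    # hypothetical stand-in for the real rounding hook: no-op rounding
    return t

def reference_row_matmul(x, w, lo, hi, device: str = "cuda"):
    # the index tensor must live on the same device as `x`;
    # a hardcoded device="cuda" here breaks XPU-resident inputs
    idx = torch.arange(lo, hi, device=device)
    return torch.matmul(round_x(x[idx, :], idx).float(), w.float())

# usage: pick XPU when available, otherwise fall back to CPU
dev = "xpu" if getattr(torch, "xpu", None) and torch.xpu.is_available() else "cpu"
x = torch.randn(8, 16, device=dev)
w = torch.randn(16, 4, device=dev)
print(reference_row_matmul(x, w, 0, 8, device=dev).shape)  # torch.Size([8, 4])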
torch-ext/triton_kernels/matmul_ogs_details/_common.py CHANGED
@@ -7,9 +7,21 @@ from triton.tools.tensor_descriptor import TensorDescriptor
# -----------------------------------------------------------------------------
# Utilities
# -----------------------------------------------------------------------------
+ try:
+     _ver_str = getattr(triton, "__version__", "0.0.0").split("+")[0]
+     _parts = _ver_str.split(".")
+     _ver_tuple = tuple(int(p) for p in (_parts + ["0", "0", "0"])[:3])
+ except Exception:
+     _ver_tuple = (0, 0, 0)

+ if _ver_tuple > (3, 4, 0) and hasattr(triton, "constexpr_function"):
+     _constexpr_function = triton.constexpr_function
+ else:
+     _constexpr_function = tl.constexpr_function

- @tl.constexpr_function
+
+ @_constexpr_function
def get_scaled_dot_format_string(dtype: tl.dtype):
    mapping = {
        tl.float16: "fp16",
torch-ext/triton_kernels/matmul_ogs_details/_finalize_matmul.py CHANGED
@@ -4,25 +4,26 @@ from ..numerics_details.flexpoint import float_to_flex, load_scale, update_scale
from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
from ..target_info import cuda_capability_geq as _cuda_capability_geq
from ..target_info import is_hip as _is_hip
+ from ._common import _constexpr_function


# fmt: off
- @tl.constexpr_function
+ @_constexpr_function
def is_hip():
    return _is_hip()


- @tl.constexpr_function
+ @_constexpr_function
def cuda_capability_geq(x, y):
    return _cuda_capability_geq(x, y)


- @tl.constexpr_function
+ @_constexpr_function
def log2(n):
    return len(bin(n)) - 3


- @tl.constexpr_function
+ @_constexpr_function
def _permute_to_end_order(n: int, axis: int):
    """
    Returns the order of the axes of a tensor to permute `axis` to the end.
@@ -105,7 +106,7 @@ def _finalize_matmul_launch_metadata(grid, kernel, args):
    return ret


- @tl.constexpr_function
+ @_constexpr_function
def _accumulate_f16_into_f32_and_track_absmax_ptx(n_inputs: int, src_type: str, absmax_reg_name: str | None):
    """
    Generate PTX code to take fp16 inputs and sum them into an f32 accumulator using mixed-precision
torch-ext/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py CHANGED
@@ -12,14 +12,14 @@ from ..numerics_details.flexpoint import (
    compute_scale,
)
from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
- from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
+ from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string, _constexpr_function


- @tl.constexpr_function
+ @_constexpr_function
def cuda_capability_geq(major, minor):
    return target_info.cuda_capability_geq(major, minor)

- @tl.constexpr_function
+ @_constexpr_function
def get_dtype(tensor_or_desc: tl.tensor | tl.tensor_descriptor) -> tl.dtype:
    if isinstance(tensor_or_desc, tl.tensor):
        return tensor_or_desc.dtype.element_ty
torch-ext/triton_kernels/matmul_ogs_details/opt_flags.py CHANGED
@@ -4,7 +4,7 @@ from dataclasses import dataclass
import triton
from ..target_info import get_cdna_version
import torch
- from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
+ from .opt_flags_details import opt_flags_amd, opt_flags_nvidia, opt_flags_intel


@dataclass
@@ -30,6 +30,83 @@ class OptFlags:
        raise ValueError("Not supported")


+ def make_default_opt_flags_intel(
+     out_dtype,
+     lhs_dtype,
+     rhs_dtype,
+     precision_config,
+     m,
+     n,
+     k,
+     routing_data,
+     can_use_persistent_tma,
+     can_use_fused_scatter,
+     enforce_bitwise_invariance,
+     epilogue_effective_itemsize,
+     constraints,
+ ):
+     constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages"]
+     assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
+     # tokens per expert
+     if routing_data is None:
+         tokens_per_expt = m
+     elif routing_data.expected_tokens_per_expt is None:
+         tokens_per_expt = max(1, m // routing_data.n_expts_tot)
+     else:
+         tokens_per_expt = routing_data.expected_tokens_per_expt
+     # pid swizzling
+     group_m = 8
+     xcd_swizzle = 1
+     # block_m
+     if constraints.get("block_m", None):
+         block_m = constraints["block_m"]
+     elif enforce_bitwise_invariance:
+         block_m = 128
+     else:
+         block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
+     # block n
+     block_n = opt_flags_intel.compute_block_n(n)
+     # is_persistent
+     is_persistent = constraints.get("is_persistent", False)
+     # block k
+     if constraints.get("block_k", None) is not None:
+         block_k = constraints["block_k"]
+     else:
+         block_k = opt_flags_intel.compute_block_k(k, is_persistent, precision_config)
+     # split_k
+     if constraints.get("split_k", None) is not None:
+         split_k = constraints["split_k"]
+     elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
+         split_k = 1
+     else:
+         estimated_actual_grid_size = opt_flags_intel.compute_grid_size(None, m, n, block_m, block_n)
+         split_k = opt_flags_intel.compute_split_k(block_k, k, estimated_actual_grid_size)
+
+     epilogue_subtile = constraints.get('epilogue_subtile', None)
+     if epilogue_subtile is None:
+         epilogue_subtile = 1
+
+     ret = OptFlags(
+         block_m=block_m,
+         block_n=block_n,
+         block_k=block_k,
+         num_warps=opt_flags_intel.compute_num_warps(block_m, block_n),
+         num_stages=constraints.get("num_stages", 2),
+         fused_scatter=constraints.get('fused_scatter', False),
+         group_m=group_m,
+         xcd_swizzle=xcd_swizzle,
+         w_cache_modifier=None,
+         split_k=split_k,
+         is_persistent=is_persistent,
+         epilogue_subtile=epilogue_subtile,
+         arch=None,
+         target_kernel_kwargs=dict(),
+         idle_sms=0,
+     )
+     # check constraints
+     assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+     return ret
+

def make_default_opt_flags_amd(
    out_dtype,
@@ -292,6 +369,8 @@ def make_opt_flags(
        enforce_bitwise_invariance, epilogue_effective_itemsize,
        _opt_flags_constraints]
    backend = triton.runtime.driver.active.get_current_target().backend
+     if backend == "xpu":
+         return make_default_opt_flags_intel(*args)
    if backend == "hip":
        return make_default_opt_flags_amd(*args)
    if backend == "cuda":
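The new Intel builder is dispatched on the active Triton target's backend string ("xpu"), and its block sizes come from simple power-of-two heuristics on the problem shape. A small worked sketch of the `block_m` rule used above, assuming no user constraints; `pick_block_m` is an illustrative wrapper, not a function in this file:

import triton

def pick_block_m(tokens_per_expt: int, enforce_bitwise_invariance: bool = False) -> int:
    # same rule as make_default_opt_flags_intel: clamp the next power of two
    # of the per-expert token count into [16, 128]
    if enforce_bitwise_invariance:
        return 128
    return max(16, min(triton.next_power_of_2(tokens_per_expt), 128))

# 100 tokens/expert -> next_power_of_2(100) = 128 -> block_m = 128
#  24 tokens/expert -> next_power_of_2(24)  =  32 -> block_m = 32
#   3 tokens/expert -> next_power_of_2(3)   =   4 -> block_m = 16
print(pick_block_m(100), pick_block_m(24), pick_block_m(3))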
torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py ADDED
@@ -0,0 +1,41 @@
+ import torch
+ import triton
+
+
+ def compute_grid_size(routing_data, m, n, block_m, block_n):
+     if routing_data is not None:
+         grid_m = routing_data.n_blocks(m, block_m)
+     else:
+         grid_m = triton.cdiv(m, block_m)
+     grid_n = (n + block_n - 1) // block_n
+     return grid_m * grid_n
+
+
+ def compute_block_n(n: int):
+     # block_n:
+     return max(16, min(128, triton.next_power_of_2(n)))
+
+
+ def compute_block_k(k: int | None, is_persistent: bool, precision_config):
+     if k is not None:
+         block_k = max(32, min(128, triton.next_power_of_2(k)))
+     has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None
+     if is_persistent and has_mx_weight_scale:
+         block_k = min(block_k, 128)
+     return block_k
+
+
+ def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int:
+     device_props = torch.xpu.get_device_properties(0)
+     n_sms = device_props.gpu_subslice_count
+     split_k = n_sms // grid_size
+     if k is not None:
+         # avoid split_k for small k
+         num_block_k = triton.cdiv(k, block_k)
+         split_k = min(split_k, num_block_k // 4)
+     split_k = max(split_k, 1)
+     return split_k
+
+
+ def compute_num_warps(block_m, block_n):
+     return max(block_m * block_n // 4096, 4)
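A quick sketch of how these helpers combine for one GEMM shape. The shape and the `n_sms` fallback are illustrative assumptions; the subslice query mirrors `compute_split_k` above and needs a visible XPU device to run:

import torch
import triton

# hypothetical shape: m=256 tokens, n=512, k=4096, no routing data, not persistent
m, n, k = 256, 512, 4096
block_m = 64                                                # caller-chosen, see opt_flags.py
block_n = max(16, min(128, triton.next_power_of_2(n)))      # compute_block_n -> 128
block_k = max(32, min(128, triton.next_power_of_2(k)))      # compute_block_k -> 128 (no mx scale)
grid = triton.cdiv(m, block_m) * triton.cdiv(n, block_n)    # compute_grid_size -> 16

if getattr(torch, "xpu", None) and torch.xpu.is_available():
    n_sms = torch.xpu.get_device_properties(0).gpu_subslice_count
else:
    n_sms = 64                                              # stand-in value for illustration

# same arithmetic as compute_split_k when k is known
split_k = max(1, min(n_sms // grid, triton.cdiv(k, block_k) // 4))
num_warps = max(block_m * block_n // 4096, 4)               # compute_num_warps -> 4
print(block_n, block_k, grid, split_k, num_warps)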
torch-ext/triton_kernels/numerics_details/flexpoint.py CHANGED
@@ -1,5 +1,6 @@
from ..numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
from .. import target_info
+ from ..matmul_ogs_details._common import _constexpr_function
import triton
import triton.language as tl

@@ -52,7 +53,7 @@ def rcp_max_finite(dtype):
    tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")


- @tl.constexpr_function
+ @_constexpr_function
def cuda_capability_geq(major, minor):
    return target_info.cuda_capability_geq(major, minor)

torch-ext/triton_kernels/target_info.py CHANGED
@@ -1,54 +1,70 @@
import torch
import triton

- cached_capabilities = {}
+ from .matmul_ogs_details._common import _constexpr_function
+ from triton.runtime import driver

+ def current_target():
+     try:
+         active_driver = driver.active
+     except RuntimeError:
+         # If there is no active driver, return None
+         return None
+     return active_driver.get_current_target()

+ current_target.__triton_builtin__ = True
+
+
+ @_constexpr_function
def is_cuda():
-     if "is_cuda" not in cached_capabilities:
-         target = triton.runtime.driver.active.get_current_target()
-         cached_capabilities["is_cuda"] = False if target is None else target.backend == "cuda"
-     return cached_capabilities["is_cuda"]
+     target = current_target()
+     return target is not None and target.backend == "cuda"


+ @_constexpr_function
def is_hip():
-     if "is_hip" not in cached_capabilities:
-         cached_capabilities["is_hip"] = torch.cuda.is_available() and bool(torch.version.hip)
-     return cached_capabilities["is_hip"]
+     target = current_target()
+     return target is not None and target.backend == "hip"


+ @_constexpr_function
+ def is_xpu():
+     target = current_target()
+     return target is not None and target.backend == "xpu"
+
+
+ @_constexpr_function
def is_hip_cdna3():
-     if "is_hip_cdna3" not in cached_capabilities:
-         target = triton.runtime.driver.active.get_current_target()
-         cached_capabilities["is_hip_cdna3"] = (target is not None and target.backend == 'hip'
-                                                and target.arch == 'gfx942')
-     return cached_capabilities["is_hip_cdna3"]
+     target = current_target()
+     return target is not None and target.arch == "gfx942"


+ @_constexpr_function
def is_hip_cdna4():
-     if "is_hip_cdna4" not in cached_capabilities:
-         target = triton.runtime.driver.active.get_current_target()
-         cached_capabilities["is_hip_cdna4"] = (target is not None and target.backend == 'hip'
-                                                and target.arch == 'gfx950')
-     return cached_capabilities["is_hip_cdna4"]
+     target = current_target()
+     return target is not None and target.arch == "gfx950"


+ @_constexpr_function
def cuda_capability_geq(major, minor=0):
    """
    Determines whether we have compute capability >= (major, minor) and
    returns this as a constexpr boolean. This can be used for guarding
    inline asm implementations that require a certain compute capability.
    """
-     if is_hip():
+     target = current_target()
+     if target is None or target.backend != "cuda":
        return False
-     if "cuda" not in cached_capabilities:
-         if torch.cuda.is_available():
-             cached_capabilities["cuda"] = torch.cuda.get_device_capability()
-         else:
-             cached_capabilities["cuda"] = (0, 0)
-     return cached_capabilities["cuda"] >= (major, minor)
+     assert isinstance(target.arch, int)
+     return target.arch >= major * 10 + minor


+ @_constexpr_function
def get_cdna_version():
    """
    Gets the AMD architecture version, i.e. CDNA3 or CDNA4, currently
@@ -65,13 +81,18 @@ def get_cdna_version():
    return -1


+ @_constexpr_function
def has_tma_gather():
    return cuda_capability_geq(10, 0)


+ @_constexpr_function
def has_native_mxfp():
    return cuda_capability_geq(10, 0)


def num_sms():
-     return torch.cuda.get_device_properties(0).multi_processor_count
+     if is_cuda():
+         return torch.cuda.get_device_properties(0).multi_processor_count
+     if is_xpu():
+         return torch.xpu.get_device_properties(0).max_compute_units
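The reworked `target_info` derives everything from the active Triton target instead of cached `torch.cuda` probes, which is what lets `num_sms()` report the XPU compute-unit count. A small usage sketch of that idea; `active_backend` and `device_parallel_units` are illustrative wrappers, not part of the module:

import torch
import triton

def active_backend() -> str | None:
    # mirrors current_target(): None when no Triton driver/target is active
    try:
        target = triton.runtime.driver.active.get_current_target()
    except RuntimeError:
        return None
    return None if target is None else target.backend

def device_parallel_units() -> int:
    # CUDA exposes multi_processor_count, XPU exposes max_compute_units;
    # both play the same role in the split_k heuristic
    backend = active_backend()
    if backend == "cuda":
        return torch.cuda.get_device_properties(0).multi_processor_count
    if backend == "xpu":
        return torch.xpu.get_device_properties(0).max_compute_units
    raise NotImplementedError(f"unsupported backend: {backend}")

print(active_backend(), device_parallel_units())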