danieldk (HF Staff) committed
Commit 36d2e7e · 1 Parent(s): f0d59b1
build/torch27-cxx11-rocm63-x86_64-linux/residual_rms_rocm/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .wrapped_rms import residual_rms, reference_residual_rms
+
+ __all__ = ["residual_rms", "reference_residual_rms"]
build/torch27-cxx11-rocm63-x86_64-linux/residual_rms_rocm/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _residual_rms_rocm_7d048af
+ ops = torch.ops._residual_rms_rocm_7d048af
+
+ def add_op_namespace_prefix(op_name: str) -> str:
+     """
+     Prefix the op name with this build's private namespace.
+     """
+     return f"_residual_rms_rocm_7d048af::{op_name}"
build/torch27-cxx11-rocm63-x86_64-linux/residual_rms_rocm/_residual_rms_rocm_7d048af.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4824e58a42d3b4a6f66e9fcc08098b60bd3f222eda14bdaceb0c15006ac8744e
+ size 2086672
build/torch27-cxx11-rocm63-x86_64-linux/residual_rms_rocm/wrapped_rms.py ADDED
@@ -0,0 +1,174 @@
+ from typing import Tuple, Optional
+ import torch
+ from torch import Tensor
+
+ from ._ops import ops
+
+
+ _HIGHEST_RESIDUAL_RMS_MODE = 3
+
+
+ def residual_rms_checks(
+     input: Tensor,
+     residual: Tensor,
+     weight: Tensor,
+     scale_tensor: Optional[Tensor],
+     epsilon: float,
+     next_buffer: Tensor,
+ ) -> None:
+     # Check shapes
+     assert input.dim() == 2, f"Expected input to have 2 dimensions but got {input.dim() = } instead."
+     assert residual.shape == input.shape, \
+         f"Expected input and residual to have the same shape but got {input.shape = } and {residual.shape = }"
+     assert weight.shape == (input.size(1), ), \
+         f"Expected weight to have shape {(input.size(1), ) = } but got {weight.shape = }"
+     # Check devices
+     device = input.device
+     assert device.type == "cuda", f"Expected input.device to be of type cuda, but got {device.type = } instead."
+     assert residual.device == device, f"Expected {residual.device = } to be the same as {input.device = }"
+     if scale_tensor is not None:
+         assert scale_tensor.device == device, f"Expected {scale_tensor.device = } to be the same as {input.device = }"
+     assert next_buffer.device == device, f"Expected {next_buffer.device = } to be the same as {input.device = }"
+     # Check layouts
+     assert input.is_contiguous(), f"Expected input to be contiguous but got {input.stride() = }"
+     assert residual.is_contiguous(), f"Expected residual to be contiguous but got {residual.stride() = }"
+     # Check scalars
+     assert epsilon > 0, f"Expected RMS epsilon to be > 0 to avoid division by zero, but got {epsilon = }"
+
+
+ def residual_rms_choose_mode(
+     input: Tensor,
+     residual: Tensor,
+     weight: Tensor,
+     next_buffer: Tensor,
+     mode: int,
+ ) -> int:
+     cols_is_multiple_of_8 = (input.size(1) % 8 == 0) and (next_buffer.size(1) % 8 == 0)
+     tensors_are_16b_aligned = all(x.data_ptr() % 16 == 0 for x in [input, residual, weight])
+     if mode == -1:
+         mode = _HIGHEST_RESIDUAL_RMS_MODE if (tensors_are_16b_aligned and cols_is_multiple_of_8) else 0
+     elif mode > 0:
+         assert tensors_are_16b_aligned, (
+             f"Requested {mode = } > 0, which requires tensors to be 16-byte aligned, but got {input.data_ptr() % 16 = }, "
+             f"{residual.data_ptr() % 16 = }, {weight.data_ptr() % 16 = }"
+         )
+         assert cols_is_multiple_of_8, f"Requested {mode = } requires {input.size(1) = } and {next_buffer.size(1) = } to be multiples of 8."
+     return mode
+
+
+ def infer_num_threads(rows: int, num_threads: int) -> int:
+     # Error case
+     if num_threads < 0 or num_threads > 1024:
+         raise ValueError(f"{num_threads = } is not between 0 and 1024")
+     # Case: num_threads was specified explicitly
+     elif num_threads != 0:
+         return num_threads
+     # Otherwise, branch on the number of rows
+     if rows <= 16:
+         return 1024
+     if rows <= 32:
+         return 768
+     if rows <= 64:
+         return 1024
+     if rows <= 256:
+         return 960
+     return 1024
+
+ ## Main kernel
+ def residual_rms(
+     input: Tensor,
+     residual: Tensor,
+     weight: Tensor,
+     epsilon: float,
+     scale_tensor: Optional[Tensor] = None,
+     next_buffer: Optional[Tensor] = None,
+     num_threads: int = 0,
+     force_scalar: bool = False,
+ ) -> Tuple[Tensor, Tensor]:
+     """Kernel that fuses a residual connection, an RMS normalization and a conversion to fp8. The residual argument
+     is modified in place (residual <- input + residual).
+     Args:
+     - input: an fp16 tensor of shape (rows, cols) in row-major format
+     - residual: an fp16 tensor of shape (rows, cols) in row-major format
+     - weight: an fp16 tensor of shape (cols, ) containing the weight of the RMS norm
+     - epsilon: the small epsilon used inside the RMS norm to avoid division by zero
+     - scale_tensor: an fp32 one-item tensor by which the RMS norm output is divided before its conversion to fp8.
+       If set to None, then the output dtype is fp16
+     - next_buffer: an optional tensor of shape (rows, .) that is zero-initialized when the output dtype is fp8
+     - num_threads: the number of threads per block in the kernel. The default of 0 selects a row-count heuristic
+     Outputs:
+     - an fp8 tensor of shape (rows, cols) in row-major format (fp16 if scale_tensor is None)
+     - the residual, modified in place
+     """
+     if next_buffer is None:
+         next_buffer = torch.empty(size=(input.size(0), 0), device=input.device, dtype=torch.float16)
+
+     residual_rms_checks(input, residual, weight, scale_tensor, epsilon, next_buffer)
+     num_threads = infer_num_threads(input.size(0), num_threads)
+
+     if scale_tensor is not None:
+         output = torch.empty(size=input.shape, dtype=torch.float8_e4m3fnuz, device=input.device)
+     else:
+         # TODO: here, we could use input as the output tensor
+         output = torch.empty(size=input.shape, dtype=torch.float16, device=input.device)
+     ops.residual_rms(
+         input=input,
+         residual=residual,
+         weight=weight,
+         scale_tensor=scale_tensor,
+         epsilon=epsilon,
+         output=output,
+         next_buffer=next_buffer,
+         num_threads=num_threads,
+         force_scalar=force_scalar,
+     )
+     return output, residual
+
+ ## Reference implementation
+ def fp8_quantize(
+     x_full_precision: Tensor,
+     scale: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+     finfo = torch.finfo(torch.float8_e4m3fn)
+     x_quantized = (x_full_precision * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
+     x_quantized = x_quantized.to(torch.float8_e4m3fn)
+     # Reinterpret the OCP e4m3fn bits as e4m3fnuz: the 0x80 bit pattern encodes -0.0
+     # in e4m3fn but NaN in e4m3fnuz, so remap it to zero before the view.
+     x_as_int8 = x_quantized.view(torch.int8)
+     ROCM_FP8_NAN_AS_INT = -128
+     mask = x_as_int8 == ROCM_FP8_NAN_AS_INT
+     x_as_int8[mask] = 0
+     x_quantized = x_as_int8.view(torch.float8_e4m3fnuz)
+     # e4m3fnuz has one extra unit of exponent bias, so the same bits represent half the value: double the scale.
+     return x_quantized, scale * 2.0
+
+ def reference_residual_rms(
+     input: Tensor,
+     residual: Tensor,
+     weight: Tensor,
+     epsilon: float,
+     scale_tensor: Optional[Tensor],
+     next_buffer: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+     """Reference for the residual_rms operation. Check its docstring for more details; the only difference here is
+     that the scale needs to be passed as a tensor, not a float."""
+     assert input.dtype == torch.float16, f"Expected torch.float16 but got {input.dtype = }"
+     assert residual.dtype == torch.float16, f"Expected torch.float16 but got {residual.dtype = }"
+     input += residual
+     residual = input
+     input = reference_rms(input, epsilon)
+     if weight.dtype in [torch.float16, torch.bfloat16]:
+         input = input.to(weight.dtype)
+     input = weight * input
+     if scale_tensor is not None:
+         qinput, scale_tensor = fp8_quantize(input, scale_tensor)
+         if next_buffer is not None:
+             next_buffer.fill_(0)
+     else:
+         qinput = input
+     return qinput, residual, scale_tensor
+
+ def reference_rms(x: Tensor, eps: float) -> Tensor:
+     x = x.to(torch.float32)
+     variance = x.pow(2).mean(-1, keepdim=True)
+     return x * torch.rsqrt(variance + eps)
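For context, a minimal usage sketch of the wrapper above (not part of this commit; shapes and epsilon are illustrative assumptions, and on ROCm builds of PyTorch the device is still exposed as "cuda"):

import torch
from residual_rms_rocm import residual_rms

rows, cols = 32, 4096  # cols is a multiple of 8, so the vectorized modes are eligible
x = torch.randn(rows, cols, device="cuda", dtype=torch.float16)
res = torch.randn(rows, cols, device="cuda", dtype=torch.float16)
w = torch.randn(cols, device="cuda", dtype=torch.float16)

# fp16 path: no scale tensor, so the normalized output stays fp16
out_fp16, res_updated = residual_rms(x, res.clone(), w, epsilon=1e-6)

# fp8 path: passing a one-item fp32 scale requests float8_e4m3fnuz output
scale = torch.tensor([1.0], device="cuda", dtype=torch.float32)
out_fp8, res_updated = residual_rms(x, res.clone(), w, 1e-6, scale_tensor=scale)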
build/torch28-cxx11-rocm63-x86_64-linux/residual_rms_rocm/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .wrapped_rms import residual_rms, reference_residual_rms
+
+ __all__ = ["residual_rms", "reference_residual_rms"]
build/torch28-cxx11-rocm63-x86_64-linux/residual_rms_rocm/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _residual_rms_rocm_7d048af
+ ops = torch.ops._residual_rms_rocm_7d048af
+
+ def add_op_namespace_prefix(op_name: str) -> str:
+     """
+     Prefix the op name with this build's private namespace.
+     """
+     return f"_residual_rms_rocm_7d048af::{op_name}"
build/torch28-cxx11-rocm63-x86_64-linux/residual_rms_rocm/_residual_rms_rocm_7d048af.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:05a6bf9cd460f1b4edb3f8e5e094a54f4208ac863af12d498b0401b9e2675bb3
+ size 2086672
build/torch28-cxx11-rocm63-x86_64-linux/residual_rms_rocm/wrapped_rms.py ADDED
@@ -0,0 +1,174 @@
+ from typing import Tuple, Optional
+ import torch
+ from torch import Tensor
+
+ from ._ops import ops
+
+
+ _HIGHEST_RESIDUAL_RMS_MODE = 3
+
+
+ def residual_rms_checks(
+     input: Tensor,
+     residual: Tensor,
+     weight: Tensor,
+     scale_tensor: Optional[Tensor],
+     epsilon: float,
+     next_buffer: Tensor,
+ ) -> None:
+     # Check shapes
+     assert input.dim() == 2, f"Expected input to have 2 dimensions but got {input.dim() = } instead."
+     assert residual.shape == input.shape, \
+         f"Expected input and residual to have the same shape but got {input.shape = } and {residual.shape = }"
+     assert weight.shape == (input.size(1), ), \
+         f"Expected weight to have shape {(input.size(1), ) = } but got {weight.shape = }"
+     # Check devices
+     device = input.device
+     assert device.type == "cuda", f"Expected input.device to be of type cuda, but got {device.type = } instead."
+     assert residual.device == device, f"Expected {residual.device = } to be the same as {input.device = }"
+     if scale_tensor is not None:
+         assert scale_tensor.device == device, f"Expected {scale_tensor.device = } to be the same as {input.device = }"
+     assert next_buffer.device == device, f"Expected {next_buffer.device = } to be the same as {input.device = }"
+     # Check layouts
+     assert input.is_contiguous(), f"Expected input to be contiguous but got {input.stride() = }"
+     assert residual.is_contiguous(), f"Expected residual to be contiguous but got {residual.stride() = }"
+     # Check scalars
+     assert epsilon > 0, f"Expected RMS epsilon to be > 0 to avoid division by zero, but got {epsilon = }"
+
+
+ def residual_rms_choose_mode(
+     input: Tensor,
+     residual: Tensor,
+     weight: Tensor,
+     next_buffer: Tensor,
+     mode: int,
+ ) -> int:
+     cols_is_multiple_of_8 = (input.size(1) % 8 == 0) and (next_buffer.size(1) % 8 == 0)
+     tensors_are_16b_aligned = all(x.data_ptr() % 16 == 0 for x in [input, residual, weight])
+     if mode == -1:
+         mode = _HIGHEST_RESIDUAL_RMS_MODE if (tensors_are_16b_aligned and cols_is_multiple_of_8) else 0
+     elif mode > 0:
+         assert tensors_are_16b_aligned, (
+             f"Requested {mode = } > 0, which requires tensors to be 16-byte aligned, but got {input.data_ptr() % 16 = }, "
+             f"{residual.data_ptr() % 16 = }, {weight.data_ptr() % 16 = }"
+         )
+         assert cols_is_multiple_of_8, f"Requested {mode = } requires {input.size(1) = } and {next_buffer.size(1) = } to be multiples of 8."
+     return mode
+
+
+ def infer_num_threads(rows: int, num_threads: int) -> int:
+     # Error case
+     if num_threads < 0 or num_threads > 1024:
+         raise ValueError(f"{num_threads = } is not between 0 and 1024")
+     # Case: num_threads was specified explicitly
+     elif num_threads != 0:
+         return num_threads
+     # Otherwise, branch on the number of rows
+     if rows <= 16:
+         return 1024
+     if rows <= 32:
+         return 768
+     if rows <= 64:
+         return 1024
+     if rows <= 256:
+         return 960
+     return 1024
+
+ ## Main kernel
+ def residual_rms(
+     input: Tensor,
+     residual: Tensor,
+     weight: Tensor,
+     epsilon: float,
+     scale_tensor: Optional[Tensor] = None,
+     next_buffer: Optional[Tensor] = None,
+     num_threads: int = 0,
+     force_scalar: bool = False,
+ ) -> Tuple[Tensor, Tensor]:
+     """Kernel that fuses a residual connection, an RMS normalization and a conversion to fp8. The residual argument
+     is modified in place (residual <- input + residual).
+     Args:
+     - input: an fp16 tensor of shape (rows, cols) in row-major format
+     - residual: an fp16 tensor of shape (rows, cols) in row-major format
+     - weight: an fp16 tensor of shape (cols, ) containing the weight of the RMS norm
+     - epsilon: the small epsilon used inside the RMS norm to avoid division by zero
+     - scale_tensor: an fp32 one-item tensor by which the RMS norm output is divided before its conversion to fp8.
+       If set to None, then the output dtype is fp16
+     - next_buffer: an optional tensor of shape (rows, .) that is zero-initialized when the output dtype is fp8
+     - num_threads: the number of threads per block in the kernel. The default of 0 selects a row-count heuristic
+     Outputs:
+     - an fp8 tensor of shape (rows, cols) in row-major format (fp16 if scale_tensor is None)
+     - the residual, modified in place
+     """
+     if next_buffer is None:
+         next_buffer = torch.empty(size=(input.size(0), 0), device=input.device, dtype=torch.float16)
+
+     residual_rms_checks(input, residual, weight, scale_tensor, epsilon, next_buffer)
+     num_threads = infer_num_threads(input.size(0), num_threads)
+
+     if scale_tensor is not None:
+         output = torch.empty(size=input.shape, dtype=torch.float8_e4m3fnuz, device=input.device)
+     else:
+         # TODO: here, we could use input as the output tensor
+         output = torch.empty(size=input.shape, dtype=torch.float16, device=input.device)
+     ops.residual_rms(
+         input=input,
+         residual=residual,
+         weight=weight,
+         scale_tensor=scale_tensor,
+         epsilon=epsilon,
+         output=output,
+         next_buffer=next_buffer,
+         num_threads=num_threads,
+         force_scalar=force_scalar,
+     )
+     return output, residual
+
+ ## Reference implementation
+ def fp8_quantize(
+     x_full_precision: Tensor,
+     scale: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+     finfo = torch.finfo(torch.float8_e4m3fn)
+     x_quantized = (x_full_precision * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
+     x_quantized = x_quantized.to(torch.float8_e4m3fn)
+     # Reinterpret the OCP e4m3fn bits as e4m3fnuz: the 0x80 bit pattern encodes -0.0
+     # in e4m3fn but NaN in e4m3fnuz, so remap it to zero before the view.
+     x_as_int8 = x_quantized.view(torch.int8)
+     ROCM_FP8_NAN_AS_INT = -128
+     mask = x_as_int8 == ROCM_FP8_NAN_AS_INT
+     x_as_int8[mask] = 0
+     x_quantized = x_as_int8.view(torch.float8_e4m3fnuz)
+     # e4m3fnuz has one extra unit of exponent bias, so the same bits represent half the value: double the scale.
+     return x_quantized, scale * 2.0
+
+ def reference_residual_rms(
+     input: Tensor,
+     residual: Tensor,
+     weight: Tensor,
+     epsilon: float,
+     scale_tensor: Optional[Tensor],
+     next_buffer: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+     """Reference for the residual_rms operation. Check its docstring for more details; the only difference here is
+     that the scale needs to be passed as a tensor, not a float."""
+     assert input.dtype == torch.float16, f"Expected torch.float16 but got {input.dtype = }"
+     assert residual.dtype == torch.float16, f"Expected torch.float16 but got {residual.dtype = }"
+     input += residual
+     residual = input
+     input = reference_rms(input, epsilon)
+     if weight.dtype in [torch.float16, torch.bfloat16]:
+         input = input.to(weight.dtype)
+     input = weight * input
+     if scale_tensor is not None:
+         qinput, scale_tensor = fp8_quantize(input, scale_tensor)
+         if next_buffer is not None:
+             next_buffer.fill_(0)
+     else:
+         qinput = input
+     return qinput, residual, scale_tensor
+
+ def reference_rms(x: Tensor, eps: float) -> Tensor:
+     x = x.to(torch.float32)
+     variance = x.pow(2).mean(-1, keepdim=True)
+     return x * torch.rsqrt(variance + eps)
build/torch28-cxx11-rocm64-x86_64-linux/residual_rms_rocm/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .wrapped_rms import residual_rms, reference_residual_rms
+
+ __all__ = ["residual_rms", "reference_residual_rms"]
build/torch28-cxx11-rocm64-x86_64-linux/residual_rms_rocm/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _residual_rms_rocm_7d048af
+ ops = torch.ops._residual_rms_rocm_7d048af
+
+ def add_op_namespace_prefix(op_name: str) -> str:
+     """
+     Prefix the op name with this build's private namespace.
+     """
+     return f"_residual_rms_rocm_7d048af::{op_name}"
build/torch28-cxx11-rocm64-x86_64-linux/residual_rms_rocm/_residual_rms_rocm_7d048af.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e55198597e13d009f6f86906a112280be31e2c6df3e6f186083d1ac555845a33
+ size 2092832
build/torch28-cxx11-rocm64-x86_64-linux/residual_rms_rocm/wrapped_rms.py ADDED
@@ -0,0 +1,174 @@
+ from typing import Tuple, Optional
+ import torch
+ from torch import Tensor
+
+ from ._ops import ops
+
+
+ _HIGHEST_RESIDUAL_RMS_MODE = 3
+
+
+ def residual_rms_checks(
+     input: Tensor,
+     residual: Tensor,
+     weight: Tensor,
+     scale_tensor: Optional[Tensor],
+     epsilon: float,
+     next_buffer: Tensor,
+ ) -> None:
+     # Check shapes
+     assert input.dim() == 2, f"Expected input to have 2 dimensions but got {input.dim() = } instead."
+     assert residual.shape == input.shape, \
+         f"Expected input and residual to have the same shape but got {input.shape = } and {residual.shape = }"
+     assert weight.shape == (input.size(1), ), \
+         f"Expected weight to have shape {(input.size(1), ) = } but got {weight.shape = }"
+     # Check devices
+     device = input.device
+     assert device.type == "cuda", f"Expected input.device to be of type cuda, but got {device.type = } instead."
+     assert residual.device == device, f"Expected {residual.device = } to be the same as {input.device = }"
+     if scale_tensor is not None:
+         assert scale_tensor.device == device, f"Expected {scale_tensor.device = } to be the same as {input.device = }"
+     assert next_buffer.device == device, f"Expected {next_buffer.device = } to be the same as {input.device = }"
+     # Check layouts
+     assert input.is_contiguous(), f"Expected input to be contiguous but got {input.stride() = }"
+     assert residual.is_contiguous(), f"Expected residual to be contiguous but got {residual.stride() = }"
+     # Check scalars
+     assert epsilon > 0, f"Expected RMS epsilon to be > 0 to avoid division by zero, but got {epsilon = }"
+
+
+ def residual_rms_choose_mode(
+     input: Tensor,
+     residual: Tensor,
+     weight: Tensor,
+     next_buffer: Tensor,
+     mode: int,
+ ) -> int:
+     cols_is_multiple_of_8 = (input.size(1) % 8 == 0) and (next_buffer.size(1) % 8 == 0)
+     tensors_are_16b_aligned = all(x.data_ptr() % 16 == 0 for x in [input, residual, weight])
+     if mode == -1:
+         mode = _HIGHEST_RESIDUAL_RMS_MODE if (tensors_are_16b_aligned and cols_is_multiple_of_8) else 0
+     elif mode > 0:
+         assert tensors_are_16b_aligned, (
+             f"Requested {mode = } > 0, which requires tensors to be 16-byte aligned, but got {input.data_ptr() % 16 = }, "
+             f"{residual.data_ptr() % 16 = }, {weight.data_ptr() % 16 = }"
+         )
+         assert cols_is_multiple_of_8, f"Requested {mode = } requires {input.size(1) = } and {next_buffer.size(1) = } to be multiples of 8."
+     return mode
+
+
+ def infer_num_threads(rows: int, num_threads: int) -> int:
+     # Error case
+     if num_threads < 0 or num_threads > 1024:
+         raise ValueError(f"{num_threads = } is not between 0 and 1024")
+     # Case: num_threads was specified explicitly
+     elif num_threads != 0:
+         return num_threads
+     # Otherwise, branch on the number of rows
+     if rows <= 16:
+         return 1024
+     if rows <= 32:
+         return 768
+     if rows <= 64:
+         return 1024
+     if rows <= 256:
+         return 960
+     return 1024
+
+ ## Main kernel
+ def residual_rms(
+     input: Tensor,
+     residual: Tensor,
+     weight: Tensor,
+     epsilon: float,
+     scale_tensor: Optional[Tensor] = None,
+     next_buffer: Optional[Tensor] = None,
+     num_threads: int = 0,
+     force_scalar: bool = False,
+ ) -> Tuple[Tensor, Tensor]:
+     """Kernel that fuses a residual connection, an RMS normalization and a conversion to fp8. The residual argument
+     is modified in place (residual <- input + residual).
+     Args:
+     - input: an fp16 tensor of shape (rows, cols) in row-major format
+     - residual: an fp16 tensor of shape (rows, cols) in row-major format
+     - weight: an fp16 tensor of shape (cols, ) containing the weight of the RMS norm
+     - epsilon: the small epsilon used inside the RMS norm to avoid division by zero
+     - scale_tensor: an fp32 one-item tensor by which the RMS norm output is divided before its conversion to fp8.
+       If set to None, then the output dtype is fp16
+     - next_buffer: an optional tensor of shape (rows, .) that is zero-initialized when the output dtype is fp8
+     - num_threads: the number of threads per block in the kernel. The default of 0 selects a row-count heuristic
+     Outputs:
+     - an fp8 tensor of shape (rows, cols) in row-major format (fp16 if scale_tensor is None)
+     - the residual, modified in place
+     """
+     if next_buffer is None:
+         next_buffer = torch.empty(size=(input.size(0), 0), device=input.device, dtype=torch.float16)
+
+     residual_rms_checks(input, residual, weight, scale_tensor, epsilon, next_buffer)
+     num_threads = infer_num_threads(input.size(0), num_threads)
+
+     if scale_tensor is not None:
+         output = torch.empty(size=input.shape, dtype=torch.float8_e4m3fnuz, device=input.device)
+     else:
+         # TODO: here, we could use input as the output tensor
+         output = torch.empty(size=input.shape, dtype=torch.float16, device=input.device)
+     ops.residual_rms(
+         input=input,
+         residual=residual,
+         weight=weight,
+         scale_tensor=scale_tensor,
+         epsilon=epsilon,
+         output=output,
+         next_buffer=next_buffer,
+         num_threads=num_threads,
+         force_scalar=force_scalar,
+     )
+     return output, residual
+
+ ## Reference implementation
+ def fp8_quantize(
+     x_full_precision: Tensor,
+     scale: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+     finfo = torch.finfo(torch.float8_e4m3fn)
+     x_quantized = (x_full_precision * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
+     x_quantized = x_quantized.to(torch.float8_e4m3fn)
+     # Reinterpret the OCP e4m3fn bits as e4m3fnuz: the 0x80 bit pattern encodes -0.0
+     # in e4m3fn but NaN in e4m3fnuz, so remap it to zero before the view.
+     x_as_int8 = x_quantized.view(torch.int8)
+     ROCM_FP8_NAN_AS_INT = -128
+     mask = x_as_int8 == ROCM_FP8_NAN_AS_INT
+     x_as_int8[mask] = 0
+     x_quantized = x_as_int8.view(torch.float8_e4m3fnuz)
+     # e4m3fnuz has one extra unit of exponent bias, so the same bits represent half the value: double the scale.
+     return x_quantized, scale * 2.0
+
+ def reference_residual_rms(
+     input: Tensor,
+     residual: Tensor,
+     weight: Tensor,
+     epsilon: float,
+     scale_tensor: Optional[Tensor],
+     next_buffer: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+     """Reference for the residual_rms operation. Check its docstring for more details; the only difference here is
+     that the scale needs to be passed as a tensor, not a float."""
+     assert input.dtype == torch.float16, f"Expected torch.float16 but got {input.dtype = }"
+     assert residual.dtype == torch.float16, f"Expected torch.float16 but got {residual.dtype = }"
+     input += residual
+     residual = input
+     input = reference_rms(input, epsilon)
+     if weight.dtype in [torch.float16, torch.bfloat16]:
+         input = input.to(weight.dtype)
+     input = weight * input
+     if scale_tensor is not None:
+         qinput, scale_tensor = fp8_quantize(input, scale_tensor)
+         if next_buffer is not None:
+             next_buffer.fill_(0)
+     else:
+         qinput = input
+     return qinput, residual, scale_tensor
+
+ def reference_rms(x: Tensor, eps: float) -> Tensor:
+     x = x.to(torch.float32)
+     variance = x.pow(2).mean(-1, keepdim=True)
+     return x * torch.rsqrt(variance + eps)
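To close, a hedged sketch of how the fused kernel could be checked against the eager reference shipped in the same file (an assumed test workflow, not part of the commit; the tolerances are illustrative guesses for fp8 round-tripping):

import torch
from residual_rms_rocm import residual_rms, reference_residual_rms

rows, cols, eps = 16, 1024, 1e-6
x = torch.randn(rows, cols, device="cuda", dtype=torch.float16)
res = torch.randn(rows, cols, device="cuda", dtype=torch.float16)
w = torch.randn(cols, device="cuda", dtype=torch.float16)
scale = torch.tensor([1.0], device="cuda", dtype=torch.float32)

out_k, res_k = residual_rms(x.clone(), res.clone(), w, eps, scale_tensor=scale.clone())
out_r, res_r, _ = reference_residual_rms(x.clone(), res.clone(), w, eps, scale.clone())

# The in-place residual update (residual <- input + residual) should match exactly
torch.testing.assert_close(res_k, res_r)
# fp8 tensors cannot be compared directly, so cast to float32 and allow fp8-sized error
torch.testing.assert_close(out_k.float(), out_r.float(), atol=0.25, rtol=0.05)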