danieldk HF Staff commited on
Commit
9f5c3df
·
1 Parent(s): 36d2e7e
build/torch27-cxx11-rocm63-x86_64-linux/residual_rms_rocm/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from .wrapped_rms import residual_rms, reference_residual_rms
2
-
3
- __all__ = ["residual_rms", "reference_residual_rms"]
 
 
 
 
build/torch27-cxx11-rocm63-x86_64-linux/residual_rms_rocm/_ops.py DELETED
@@ -1,9 +0,0 @@
1
- import torch
2
- from . import _residual_rms_rocm_7d048af
3
- ops = torch.ops._residual_rms_rocm_7d048af
4
-
5
- def add_op_namespace_prefix(op_name: str):
6
- """
7
- Prefix op by namespace.
8
- """
9
- return f"_residual_rms_rocm_7d048af::{op_name}"
 
 
 
 
 
 
 
 
 
 
build/torch27-cxx11-rocm63-x86_64-linux/residual_rms_rocm/wrapped_rms.py DELETED
@@ -1,171 +0,0 @@
1
- from typing import Tuple, Optional
2
- import torch
3
- from torch import Tensor
4
-
5
- from ._ops import ops
6
-
7
-
8
- _HIGHEST_RESIDUAL_RMS_MODE = 3
9
-
10
-
11
- def residual_rms_checks(
12
- input: Tensor,
13
- residual: Tensor,
14
- weight: Tensor,
15
- scale_tensor: Tensor,
16
- epsilon: float,
17
- next_buffer: Tensor,
18
- ) -> None:
19
- # Check shapes
20
- assert input.dim() == 2, f"Expected input to have 2 dimensions but got {input.dim() = } instead."
21
- assert residual.shape == input.shape, \
22
- f"Expected input and residual to have same shape but got {input.shape = } and {residual.shape = }"
23
- assert weight.shape == (input.size(1), ), \
24
- f"Expected weight to have shape {(input.size(1), ) = } but got {weight.shape = }"
25
- # Check devices
26
- device = input.device
27
- assert device.type == "cuda", f"Expected input.device to be of type cuda, but got {device.type = } instead."
28
- assert residual.device == device, f"Expected {residual.device = } to be the same as {input.device = }"
29
- if scale_tensor is not None:
30
- assert scale_tensor.device == device, f"Expected {scale_tensor.device = } to be the same as {input.device = }"
31
- assert next_buffer.device == device, f"Expected {next_buffer.device = } to be the same as {input.device = }"
32
- # Check layouts
33
- assert input.is_contiguous(), f"Expected input to be contiguous but got {input.stride() = }"
34
- assert residual.is_contiguous(), f"Expected residual to be contiguous but got {residual.stride() = }"
35
- # Check scalars
36
- assert epsilon > 0, f"Expected RMS epsilon to be > 0 to avoid division by zero, but got {epsilon = }"
37
-
38
-
39
- def residual_rms_choose_mode(
40
- input: Tensor,
41
- residual: Tensor,
42
- weight: Tensor,
43
- next_buffer: Tensor,
44
- mode: int,
45
- ) -> int:
46
- cols_is_multiple_of_8 = (input.size(1) % 8 == 0) and (next_buffer.size(1) % 8 == 0)
47
- tensors_are_16b_aligned = all([x.data_ptr() % 16 == 0 for x in [input, residual, weight]])
48
- if mode == -1:
49
- mode = _HIGHEST_RESIDUAL_RMS_MODE if (tensors_are_16b_aligned and cols_is_multiple_of_8) else 0
50
- elif mode > 0:
51
- assert tensors_are_16b_aligned, (
52
- f"Requested a {mode = } > 0 requires tensors to be 16 bits aligned but got {input.data_ptr() % 16 = }, "
53
- f"{residual.data_ptr() % 16 = }, {weight.data_ptr() % 16 = }"
54
- )
55
- assert cols_is_multiple_of_8, f"Requested {mode = } requires {input.size(1) = } to be a multiple of 8."
56
- return mode
57
-
58
-
59
- def infer_num_threads(rows: int, num_threads: int) -> int:
60
- # Error case
61
- if num_threads < 0 or num_threads > 1024:
62
- raise ValueError(f"{num_threads = } is not between 0 and 1024")
63
- # Case: num_threads was specified
64
- elif num_threads != 0:
65
- return num_threads
66
- # Otherwise, we branch upon the number of rows
67
- if rows <= 16:
68
- return 1024
69
- if rows <= 32:
70
- return 768
71
- if rows <= 64:
72
- return 1024
73
- if rows <= 256:
74
- return 960
75
- return 1024
76
-
77
- ## Main kernel
78
- def residual_rms(
79
- input: Tensor,
80
- residual: Tensor,
81
- weight: Tensor,
82
- epsilon: float,
83
- scale_tensor: Optional[Tensor] = None,
84
- next_buffer: Optional[Tensor] = None,
85
- num_threads: int = 0,
86
- force_scalar: bool = False,
87
- ) -> Tuple[Tensor, Tensor]:
88
- """Kernel that fuses a residual connection, an RMS normalization and a conversion to fp8. The resdiual argument is
89
- modified inplace (residual <- input + residual).
90
- Args:
91
- - input: a fp16 tensor of shape (rows, cols) in row-major format
92
- - residual: a fp16 tensor of shape (rows, cols) in row-major format
93
- - weight: a fp16 tensor of shape (cols, ) in row-major format which contains the weight of the RMS norm
94
- - epsilon: the small epsilon used inside the RMS norm to avoid division by zero
95
- - scale_tensor: a fp32 one-item tensor to divide the output of the RMS norm before their conversion to fp8. If
96
- set to None, then the output dtype is fp16
97
- - next_buffer: an optional tensor of shape (rows, .) to initialize to zero if the output dtype in fp8
98
- - num_threads: the number of threads per block in the kernel. Default value is 0, which then defaults to 1024
99
- Outputs:
100
- - an fp8 tensor of shape (rows, cols) in row-major format
101
- - the residual modified in place
102
- """
103
- if next_buffer is None:
104
- next_buffer = torch.empty(size=(input.size(0), 0), device=input.device, dtype=torch.float16)
105
-
106
- residual_rms_checks(input, residual, weight, scale_tensor, epsilon, next_buffer)
107
- num_threads = infer_num_threads(input.size(0), num_threads)
108
-
109
- if scale_tensor is not None:
110
- output = torch.empty(size=input.shape, dtype=torch.float8_e4m3fnuz, device=input.device)
111
- else:
112
- # TODO: here, we could use input as the output tensor
113
- output = torch.empty(size=input.shape, dtype=torch.float16, device=input.device)
114
- ops.residual_rms(
115
- input=input,
116
- residual=residual,
117
- weight=weight,
118
- scale_tensor=scale_tensor,
119
- epsilon=epsilon,
120
- output=output,
121
- next_buffer=next_buffer,
122
- num_threads=num_threads,
123
- force_scalar=force_scalar,
124
- )
125
- return output, residual
126
-
127
- ## Reference implementation
128
- def fp8_quantize(
129
- x_full_precision: Tensor,
130
- scale: Tensor,
131
- ) -> Tuple[Tensor, Tensor]:
132
- finfo = torch.finfo(torch.float8_e4m3fn)
133
- x_quantized = (x_full_precision * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
134
- x_quantized = x_quantized.to(torch.float8_e4m3fn)
135
- weight_as_int8 = x_quantized.view(torch.int8)
136
- ROCM_FP8_NAN_AS_INT = -128
137
- mask = weight_as_int8 == ROCM_FP8_NAN_AS_INT
138
- weight_as_int8[mask] = 0
139
- x_quantized = weight_as_int8.view(torch.float8_e4m3fnuz)
140
- return x_quantized, scale * 2.0
141
-
142
- def reference_residual_rms(
143
- input: Tensor,
144
- residual: Tensor,
145
- weight: Tensor,
146
- epsilon: float,
147
- scale_tensor: Optional[Tensor],
148
- next_buffer: Optional[Tensor] = None,
149
- ) -> Tuple[Tensor, Tensor, float]:
150
- """Reference for the residual_rms operation. Check its docstring for more details, the only difference here is that
151
- the scale needs to be passed a tensor and not a float."""
152
- assert input.dtype == torch.float16, f"Expected torch.float16 but got {input.dtype = }"
153
- assert residual.dtype == torch.float16, f"Expected torch.float16 but got {residual.dtype = }"
154
- input += residual
155
- residual = input
156
- input = reference_rms(input, epsilon)
157
- if weight.dtype in [torch.float16, torch.bfloat16]:
158
- input = input.to(weight.dtype)
159
- input = weight * input
160
- if scale_tensor is not None:
161
- qinput, scale_tensor = fp8_quantize(input, scale_tensor)
162
- if next_buffer is not None:
163
- next_buffer.fill_(0)
164
- else:
165
- qinput = input
166
- return qinput, residual, scale_tensor
167
-
168
- def reference_rms(x: Tensor, eps: float) -> Tensor:
169
- x = x.to(torch.float32)
170
- variance = x.pow(2).mean(-1, keepdim=True)
171
- return x * torch.rsqrt(variance + eps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch28-cxx11-rocm63-x86_64-linux/residual_rms_rocm/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from .wrapped_rms import residual_rms, reference_residual_rms
2
-
3
- __all__ = ["residual_rms", "reference_residual_rms"]
 
 
 
 
build/torch28-cxx11-rocm63-x86_64-linux/residual_rms_rocm/_ops.py DELETED
@@ -1,9 +0,0 @@
1
- import torch
2
- from . import _residual_rms_rocm_7d048af
3
- ops = torch.ops._residual_rms_rocm_7d048af
4
-
5
- def add_op_namespace_prefix(op_name: str):
6
- """
7
- Prefix op by namespace.
8
- """
9
- return f"_residual_rms_rocm_7d048af::{op_name}"
 
 
 
 
 
 
 
 
 
 
build/torch28-cxx11-rocm63-x86_64-linux/residual_rms_rocm/wrapped_rms.py DELETED
@@ -1,171 +0,0 @@
1
- from typing import Tuple, Optional
2
- import torch
3
- from torch import Tensor
4
-
5
- from ._ops import ops
6
-
7
-
8
- _HIGHEST_RESIDUAL_RMS_MODE = 3
9
-
10
-
11
- def residual_rms_checks(
12
- input: Tensor,
13
- residual: Tensor,
14
- weight: Tensor,
15
- scale_tensor: Tensor,
16
- epsilon: float,
17
- next_buffer: Tensor,
18
- ) -> None:
19
- # Check shapes
20
- assert input.dim() == 2, f"Expected input to have 2 dimensions but got {input.dim() = } instead."
21
- assert residual.shape == input.shape, \
22
- f"Expected input and residual to have same shape but got {input.shape = } and {residual.shape = }"
23
- assert weight.shape == (input.size(1), ), \
24
- f"Expected weight to have shape {(input.size(1), ) = } but got {weight.shape = }"
25
- # Check devices
26
- device = input.device
27
- assert device.type == "cuda", f"Expected input.device to be of type cuda, but got {device.type = } instead."
28
- assert residual.device == device, f"Expected {residual.device = } to be the same as {input.device = }"
29
- if scale_tensor is not None:
30
- assert scale_tensor.device == device, f"Expected {scale_tensor.device = } to be the same as {input.device = }"
31
- assert next_buffer.device == device, f"Expected {next_buffer.device = } to be the same as {input.device = }"
32
- # Check layouts
33
- assert input.is_contiguous(), f"Expected input to be contiguous but got {input.stride() = }"
34
- assert residual.is_contiguous(), f"Expected residual to be contiguous but got {residual.stride() = }"
35
- # Check scalars
36
- assert epsilon > 0, f"Expected RMS epsilon to be > 0 to avoid division by zero, but got {epsilon = }"
37
-
38
-
39
- def residual_rms_choose_mode(
40
- input: Tensor,
41
- residual: Tensor,
42
- weight: Tensor,
43
- next_buffer: Tensor,
44
- mode: int,
45
- ) -> int:
46
- cols_is_multiple_of_8 = (input.size(1) % 8 == 0) and (next_buffer.size(1) % 8 == 0)
47
- tensors_are_16b_aligned = all([x.data_ptr() % 16 == 0 for x in [input, residual, weight]])
48
- if mode == -1:
49
- mode = _HIGHEST_RESIDUAL_RMS_MODE if (tensors_are_16b_aligned and cols_is_multiple_of_8) else 0
50
- elif mode > 0:
51
- assert tensors_are_16b_aligned, (
52
- f"Requested a {mode = } > 0 requires tensors to be 16 bits aligned but got {input.data_ptr() % 16 = }, "
53
- f"{residual.data_ptr() % 16 = }, {weight.data_ptr() % 16 = }"
54
- )
55
- assert cols_is_multiple_of_8, f"Requested {mode = } requires {input.size(1) = } to be a multiple of 8."
56
- return mode
57
-
58
-
59
- def infer_num_threads(rows: int, num_threads: int) -> int:
60
- # Error case
61
- if num_threads < 0 or num_threads > 1024:
62
- raise ValueError(f"{num_threads = } is not between 0 and 1024")
63
- # Case: num_threads was specified
64
- elif num_threads != 0:
65
- return num_threads
66
- # Otherwise, we branch upon the number of rows
67
- if rows <= 16:
68
- return 1024
69
- if rows <= 32:
70
- return 768
71
- if rows <= 64:
72
- return 1024
73
- if rows <= 256:
74
- return 960
75
- return 1024
76
-
77
- ## Main kernel
78
- def residual_rms(
79
- input: Tensor,
80
- residual: Tensor,
81
- weight: Tensor,
82
- epsilon: float,
83
- scale_tensor: Optional[Tensor] = None,
84
- next_buffer: Optional[Tensor] = None,
85
- num_threads: int = 0,
86
- force_scalar: bool = False,
87
- ) -> Tuple[Tensor, Tensor]:
88
- """Kernel that fuses a residual connection, an RMS normalization and a conversion to fp8. The resdiual argument is
89
- modified inplace (residual <- input + residual).
90
- Args:
91
- - input: a fp16 tensor of shape (rows, cols) in row-major format
92
- - residual: a fp16 tensor of shape (rows, cols) in row-major format
93
- - weight: a fp16 tensor of shape (cols, ) in row-major format which contains the weight of the RMS norm
94
- - epsilon: the small epsilon used inside the RMS norm to avoid division by zero
95
- - scale_tensor: a fp32 one-item tensor to divide the output of the RMS norm before their conversion to fp8. If
96
- set to None, then the output dtype is fp16
97
- - next_buffer: an optional tensor of shape (rows, .) to initialize to zero if the output dtype in fp8
98
- - num_threads: the number of threads per block in the kernel. Default value is 0, which then defaults to 1024
99
- Outputs:
100
- - an fp8 tensor of shape (rows, cols) in row-major format
101
- - the residual modified in place
102
- """
103
- if next_buffer is None:
104
- next_buffer = torch.empty(size=(input.size(0), 0), device=input.device, dtype=torch.float16)
105
-
106
- residual_rms_checks(input, residual, weight, scale_tensor, epsilon, next_buffer)
107
- num_threads = infer_num_threads(input.size(0), num_threads)
108
-
109
- if scale_tensor is not None:
110
- output = torch.empty(size=input.shape, dtype=torch.float8_e4m3fnuz, device=input.device)
111
- else:
112
- # TODO: here, we could use input as the output tensor
113
- output = torch.empty(size=input.shape, dtype=torch.float16, device=input.device)
114
- ops.residual_rms(
115
- input=input,
116
- residual=residual,
117
- weight=weight,
118
- scale_tensor=scale_tensor,
119
- epsilon=epsilon,
120
- output=output,
121
- next_buffer=next_buffer,
122
- num_threads=num_threads,
123
- force_scalar=force_scalar,
124
- )
125
- return output, residual
126
-
127
- ## Reference implementation
128
- def fp8_quantize(
129
- x_full_precision: Tensor,
130
- scale: Tensor,
131
- ) -> Tuple[Tensor, Tensor]:
132
- finfo = torch.finfo(torch.float8_e4m3fn)
133
- x_quantized = (x_full_precision * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
134
- x_quantized = x_quantized.to(torch.float8_e4m3fn)
135
- weight_as_int8 = x_quantized.view(torch.int8)
136
- ROCM_FP8_NAN_AS_INT = -128
137
- mask = weight_as_int8 == ROCM_FP8_NAN_AS_INT
138
- weight_as_int8[mask] = 0
139
- x_quantized = weight_as_int8.view(torch.float8_e4m3fnuz)
140
- return x_quantized, scale * 2.0
141
-
142
- def reference_residual_rms(
143
- input: Tensor,
144
- residual: Tensor,
145
- weight: Tensor,
146
- epsilon: float,
147
- scale_tensor: Optional[Tensor],
148
- next_buffer: Optional[Tensor] = None,
149
- ) -> Tuple[Tensor, Tensor, float]:
150
- """Reference for the residual_rms operation. Check its docstring for more details, the only difference here is that
151
- the scale needs to be passed a tensor and not a float."""
152
- assert input.dtype == torch.float16, f"Expected torch.float16 but got {input.dtype = }"
153
- assert residual.dtype == torch.float16, f"Expected torch.float16 but got {residual.dtype = }"
154
- input += residual
155
- residual = input
156
- input = reference_rms(input, epsilon)
157
- if weight.dtype in [torch.float16, torch.bfloat16]:
158
- input = input.to(weight.dtype)
159
- input = weight * input
160
- if scale_tensor is not None:
161
- qinput, scale_tensor = fp8_quantize(input, scale_tensor)
162
- if next_buffer is not None:
163
- next_buffer.fill_(0)
164
- else:
165
- qinput = input
166
- return qinput, residual, scale_tensor
167
-
168
- def reference_rms(x: Tensor, eps: float) -> Tensor:
169
- x = x.to(torch.float32)
170
- variance = x.pow(2).mean(-1, keepdim=True)
171
- return x * torch.rsqrt(variance + eps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch28-cxx11-rocm64-x86_64-linux/residual_rms_rocm/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from .wrapped_rms import residual_rms, reference_residual_rms
2
-
3
- __all__ = ["residual_rms", "reference_residual_rms"]
 
 
 
 
build/torch28-cxx11-rocm64-x86_64-linux/residual_rms_rocm/_ops.py DELETED
@@ -1,9 +0,0 @@
1
- import torch
2
- from . import _residual_rms_rocm_7d048af
3
- ops = torch.ops._residual_rms_rocm_7d048af
4
-
5
- def add_op_namespace_prefix(op_name: str):
6
- """
7
- Prefix op by namespace.
8
- """
9
- return f"_residual_rms_rocm_7d048af::{op_name}"
 
 
 
 
 
 
 
 
 
 
build/torch28-cxx11-rocm64-x86_64-linux/residual_rms_rocm/wrapped_rms.py DELETED
@@ -1,171 +0,0 @@
1
- from typing import Tuple, Optional
2
- import torch
3
- from torch import Tensor
4
-
5
- from ._ops import ops
6
-
7
-
8
- _HIGHEST_RESIDUAL_RMS_MODE = 3
9
-
10
-
11
- def residual_rms_checks(
12
- input: Tensor,
13
- residual: Tensor,
14
- weight: Tensor,
15
- scale_tensor: Tensor,
16
- epsilon: float,
17
- next_buffer: Tensor,
18
- ) -> None:
19
- # Check shapes
20
- assert input.dim() == 2, f"Expected input to have 2 dimensions but got {input.dim() = } instead."
21
- assert residual.shape == input.shape, \
22
- f"Expected input and residual to have same shape but got {input.shape = } and {residual.shape = }"
23
- assert weight.shape == (input.size(1), ), \
24
- f"Expected weight to have shape {(input.size(1), ) = } but got {weight.shape = }"
25
- # Check devices
26
- device = input.device
27
- assert device.type == "cuda", f"Expected input.device to be of type cuda, but got {device.type = } instead."
28
- assert residual.device == device, f"Expected {residual.device = } to be the same as {input.device = }"
29
- if scale_tensor is not None:
30
- assert scale_tensor.device == device, f"Expected {scale_tensor.device = } to be the same as {input.device = }"
31
- assert next_buffer.device == device, f"Expected {next_buffer.device = } to be the same as {input.device = }"
32
- # Check layouts
33
- assert input.is_contiguous(), f"Expected input to be contiguous but got {input.stride() = }"
34
- assert residual.is_contiguous(), f"Expected residual to be contiguous but got {residual.stride() = }"
35
- # Check scalars
36
- assert epsilon > 0, f"Expected RMS epsilon to be > 0 to avoid division by zero, but got {epsilon = }"
37
-
38
-
39
- def residual_rms_choose_mode(
40
- input: Tensor,
41
- residual: Tensor,
42
- weight: Tensor,
43
- next_buffer: Tensor,
44
- mode: int,
45
- ) -> int:
46
- cols_is_multiple_of_8 = (input.size(1) % 8 == 0) and (next_buffer.size(1) % 8 == 0)
47
- tensors_are_16b_aligned = all([x.data_ptr() % 16 == 0 for x in [input, residual, weight]])
48
- if mode == -1:
49
- mode = _HIGHEST_RESIDUAL_RMS_MODE if (tensors_are_16b_aligned and cols_is_multiple_of_8) else 0
50
- elif mode > 0:
51
- assert tensors_are_16b_aligned, (
52
- f"Requested a {mode = } > 0 requires tensors to be 16 bits aligned but got {input.data_ptr() % 16 = }, "
53
- f"{residual.data_ptr() % 16 = }, {weight.data_ptr() % 16 = }"
54
- )
55
- assert cols_is_multiple_of_8, f"Requested {mode = } requires {input.size(1) = } to be a multiple of 8."
56
- return mode
57
-
58
-
59
- def infer_num_threads(rows: int, num_threads: int) -> int:
60
- # Error case
61
- if num_threads < 0 or num_threads > 1024:
62
- raise ValueError(f"{num_threads = } is not between 0 and 1024")
63
- # Case: num_threads was specified
64
- elif num_threads != 0:
65
- return num_threads
66
- # Otherwise, we branch upon the number of rows
67
- if rows <= 16:
68
- return 1024
69
- if rows <= 32:
70
- return 768
71
- if rows <= 64:
72
- return 1024
73
- if rows <= 256:
74
- return 960
75
- return 1024
76
-
77
- ## Main kernel
78
- def residual_rms(
79
- input: Tensor,
80
- residual: Tensor,
81
- weight: Tensor,
82
- epsilon: float,
83
- scale_tensor: Optional[Tensor] = None,
84
- next_buffer: Optional[Tensor] = None,
85
- num_threads: int = 0,
86
- force_scalar: bool = False,
87
- ) -> Tuple[Tensor, Tensor]:
88
- """Kernel that fuses a residual connection, an RMS normalization and a conversion to fp8. The resdiual argument is
89
- modified inplace (residual <- input + residual).
90
- Args:
91
- - input: a fp16 tensor of shape (rows, cols) in row-major format
92
- - residual: a fp16 tensor of shape (rows, cols) in row-major format
93
- - weight: a fp16 tensor of shape (cols, ) in row-major format which contains the weight of the RMS norm
94
- - epsilon: the small epsilon used inside the RMS norm to avoid division by zero
95
- - scale_tensor: a fp32 one-item tensor to divide the output of the RMS norm before their conversion to fp8. If
96
- set to None, then the output dtype is fp16
97
- - next_buffer: an optional tensor of shape (rows, .) to initialize to zero if the output dtype in fp8
98
- - num_threads: the number of threads per block in the kernel. Default value is 0, which then defaults to 1024
99
- Outputs:
100
- - an fp8 tensor of shape (rows, cols) in row-major format
101
- - the residual modified in place
102
- """
103
- if next_buffer is None:
104
- next_buffer = torch.empty(size=(input.size(0), 0), device=input.device, dtype=torch.float16)
105
-
106
- residual_rms_checks(input, residual, weight, scale_tensor, epsilon, next_buffer)
107
- num_threads = infer_num_threads(input.size(0), num_threads)
108
-
109
- if scale_tensor is not None:
110
- output = torch.empty(size=input.shape, dtype=torch.float8_e4m3fnuz, device=input.device)
111
- else:
112
- # TODO: here, we could use input as the output tensor
113
- output = torch.empty(size=input.shape, dtype=torch.float16, device=input.device)
114
- ops.residual_rms(
115
- input=input,
116
- residual=residual,
117
- weight=weight,
118
- scale_tensor=scale_tensor,
119
- epsilon=epsilon,
120
- output=output,
121
- next_buffer=next_buffer,
122
- num_threads=num_threads,
123
- force_scalar=force_scalar,
124
- )
125
- return output, residual
126
-
127
- ## Reference implementation
128
- def fp8_quantize(
129
- x_full_precision: Tensor,
130
- scale: Tensor,
131
- ) -> Tuple[Tensor, Tensor]:
132
- finfo = torch.finfo(torch.float8_e4m3fn)
133
- x_quantized = (x_full_precision * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
134
- x_quantized = x_quantized.to(torch.float8_e4m3fn)
135
- weight_as_int8 = x_quantized.view(torch.int8)
136
- ROCM_FP8_NAN_AS_INT = -128
137
- mask = weight_as_int8 == ROCM_FP8_NAN_AS_INT
138
- weight_as_int8[mask] = 0
139
- x_quantized = weight_as_int8.view(torch.float8_e4m3fnuz)
140
- return x_quantized, scale * 2.0
141
-
142
- def reference_residual_rms(
143
- input: Tensor,
144
- residual: Tensor,
145
- weight: Tensor,
146
- epsilon: float,
147
- scale_tensor: Optional[Tensor],
148
- next_buffer: Optional[Tensor] = None,
149
- ) -> Tuple[Tensor, Tensor, float]:
150
- """Reference for the residual_rms operation. Check its docstring for more details, the only difference here is that
151
- the scale needs to be passed a tensor and not a float."""
152
- assert input.dtype == torch.float16, f"Expected torch.float16 but got {input.dtype = }"
153
- assert residual.dtype == torch.float16, f"Expected torch.float16 but got {residual.dtype = }"
154
- input += residual
155
- residual = input
156
- input = reference_rms(input, epsilon)
157
- if weight.dtype in [torch.float16, torch.bfloat16]:
158
- input = input.to(weight.dtype)
159
- input = weight * input
160
- if scale_tensor is not None:
161
- qinput, scale_tensor = fp8_quantize(input, scale_tensor)
162
- if next_buffer is not None:
163
- next_buffer.fill_(0)
164
- else:
165
- qinput = input
166
- return qinput, residual, scale_tensor
167
-
168
- def reference_rms(x: Tensor, eps: float) -> Tensor:
169
- x = x.to(torch.float32)
170
- variance = x.pow(2).mean(-1, keepdim=True)
171
- return x * torch.rsqrt(variance + eps)