Upload custom kernels
- README.md +0 -0
- build.toml +5 -0
- build/torch-universal/liger_kernels/__init__.py +29 -0
- build/torch-universal/liger_kernels/_ops.py +8 -0
- build/torch-universal/liger_kernels/cross_entropy.py +460 -0
- build/torch-universal/liger_kernels/dyt.py +225 -0
- build/torch-universal/liger_kernels/fused_linear_cross_entropy.py +283 -0
- build/torch-universal/liger_kernels/geglu.py +141 -0
- build/torch-universal/liger_kernels/group_norm.py +305 -0
- build/torch-universal/liger_kernels/jsd.py +201 -0
- build/torch-universal/liger_kernels/kl_div.py +262 -0
- build/torch-universal/liger_kernels/layer_norm.py +265 -0
- build/torch-universal/liger_kernels/qwen2vl_mrope.py +222 -0
- build/torch-universal/liger_kernels/rms_norm.py +365 -0
- build/torch-universal/liger_kernels/rope.py +239 -0
- build/torch-universal/liger_kernels/swiglu.py +116 -0
- build/torch-universal/liger_kernels/tvd.py +207 -0
- build/torch-universal/liger_kernels/utils.py +135 -0
- flake.lock +117 -0
- flake.nix +17 -0
- torch-ext/liger_kernels/__init__.py +29 -0
- torch-ext/liger_kernels/cross_entropy.py +460 -0
- torch-ext/liger_kernels/dyt.py +225 -0
- torch-ext/liger_kernels/fused_linear_cross_entropy.py +283 -0
- torch-ext/liger_kernels/geglu.py +141 -0
- torch-ext/liger_kernels/group_norm.py +305 -0
- torch-ext/liger_kernels/jsd.py +201 -0
- torch-ext/liger_kernels/kl_div.py +262 -0
- torch-ext/liger_kernels/layer_norm.py +265 -0
- torch-ext/liger_kernels/qwen2vl_mrope.py +222 -0
- torch-ext/liger_kernels/rms_norm.py +365 -0
- torch-ext/liger_kernels/rope.py +239 -0
- torch-ext/liger_kernels/swiglu.py +116 -0
- torch-ext/liger_kernels/tvd.py +207 -0
- torch-ext/liger_kernels/utils.py +135 -0
README.md
ADDED
File without changes
build.toml
ADDED
@@ -0,0 +1,5 @@
[general]
name = "liger_kernels"

[torch]
universal = true
build/torch-universal/liger_kernels/__init__.py
ADDED
@@ -0,0 +1,29 @@
from cross_entropy import LigerCrossEntropyFunction
from fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
from dyt import LigerDyTFunction
from geglu import LigerGELUMulFunction
from group_norm import LigerGroupNormFunction
from kl_div import LigerKLDivLossFunction
from layer_norm import LigerLayerNormFunction
from qwen2vl_mrope import LigerQwen2VLMRopeFunction
from rms_norm import LigerRMSNormFunction
from jsd import LigerJSDFunction
from rope import LigerRopeFunction
from swiglu import LigerSiLUMulFunction
from tvd import LigerTVDLossFunction

__all__ = [
    "LigerCrossEntropyFunction",
    "LigerFusedLinearCrossEntropyFunction",
    "LigerDyTFunction",
    "LigerGELUMulFunction",
    "LigerGroupNormFunction",
    "LigerKLDivLossFunction",
    "LigerLayerNormFunction",
    "LigerQwen2VLMRopeFunction",
    "LigerRMSNormFunction",
    "LigerJSDFunction",
    "LigerRopeFunction",
    "LigerSiLUMulFunction",
    "LigerTVDLossFunction",
]
build/torch-universal/liger_kernels/_ops.py
ADDED
@@ -0,0 +1,8 @@
import torch
ops = torch.ops._liger_kernels_20250505094412

def add_op_namespace_prefix(op_name: str):
    """
    Prefix op by namespace.
    """
    return f"_liger_kernels_20250505094412::{op_name}"
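Editor's note (not part of the upload): add_op_namespace_prefix only prepends the build's op namespace to a bare op name, e.g.

add_op_namespace_prefix("rms_norm")  # -> "_liger_kernels_20250505094412::rms_norm"

so callers can refer to ops under the versioned namespace without hard-coding the timestamp suffix.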
build/torch-universal/liger_kernels/cross_entropy.py
ADDED
@@ -0,0 +1,460 @@
import operator

from typing import Optional

import torch
import triton
import triton.language as tl

from utils import compare_version
from utils import element_mul_kernel
from utils import is_hip
from utils import infer_device

if compare_version("triton", operator.ge, "3.0.0"):
    try:
        # typical import path with dispatch available
        from triton.language.extra.libdevice import tanh
    except ModuleNotFoundError:
        # for working with NGC containers
        from triton.language.extra.cuda.libdevice import tanh
else:
    from triton.language.math import tanh


@triton.jit
def liger_cross_entropy_kernel(
    X_ptr,
    X_stride,
    Y_ptr,
    Y_stride,
    weight_ptr,
    loss_ptr,
    z_loss_ptr,
    loss_stride,
    n_cols,
    n_non_ignore,
    sum_non_ignore_weight,
    weight_sum,
    ignore_index,
    lse_square_scale: tl.constexpr,
    label_smoothing: tl.constexpr,
    reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
    softcap,
    RETURN_Z_LOSS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_SOFTCAPPING: tl.constexpr,
):
    """
    This kernel computes both cross entropy loss and the gradient of the input.
    We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.

    Parameters:
    X_ptr: Pointer to input tensor.
    X_stride (int): The stride of the input tensor.
    Y_ptr: Pointer to target tensor.
    Y_stride (int): The stride of the target tensor.
    weight_ptr: Pointer to weight tensor.
    loss_ptr: Pointer to tensor to store the loss.
    z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
    loss_stride (int): The stride of the loss tensor.
    n_cols (int): The number of columns in the input tensor.
    n_non_ignore (float): The number of non-ignored elements in the batch.
    sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
    weight_sum (float): The sum of weight tensor.
    ignore_index (int): The index to ignore in the target.
    label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
    lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
    reduction (str): The string for the reduction to apply
    softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
    BLOCK_SIZE (int): The block size for Triton operations.
    HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
    HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
    """

    # https://github.com/triton-lang/triton/issues/1058
    # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64
    program_id = tl.program_id(0).to(tl.int64)

    # 1. Load Y_ptr first because if the target is ignore_index, we can return right away
    Y_ptr += program_id * Y_stride
    y = tl.load(Y_ptr)

    # 2. locate the start index
    X_ptr += program_id * X_stride

    if y == ignore_index:
        # set all X_ptr as 0
        for i in range(0, n_cols, BLOCK_SIZE):
            X_offsets = i + tl.arange(0, BLOCK_SIZE)
            tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
        return

    loss_ptr += program_id * loss_stride
    if RETURN_Z_LOSS:
        z_loss_ptr += program_id * loss_stride

    if HAS_WEIGHT:
        weight_y = tl.load(weight_ptr + y).cast(tl.float32)

    # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
    # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867

    # 3. [Online softmax] first pass: find max + sum
    m = float("-inf")  # m is the max value. use the notation from the paper
    d = 0.0  # d is the sum. use the notation from the paper
    ori_X_y = tl.load(X_ptr + y).cast(tl.float32)  # we need to store the original value of X_y for the loss calculation
    if HAS_SOFTCAPPING:
        ori_X_y = softcap * tanh(ori_X_y / softcap)

    # Label smoothing is a general case of normal cross entropy
    # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
    scaled_x_sum = 0.0
    eps = label_smoothing / n_cols

    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(
            X_ptr + X_offsets,
            mask=X_offsets < n_cols,
            other=float("-inf"),
            # Ensure float32 precision for softmax calculation
        ).cast(tl.float32)
        if HAS_SOFTCAPPING:
            X_block = softcap * tanh(X_block / softcap)
        block_max = tl.max(X_block)
        if label_smoothing > 0:
            # scale X beforehand to avoid overflow
            if HAS_WEIGHT:
                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
                scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0))
            else:
                scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
        m_new = tl.maximum(m, block_max)
        d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
        m = m_new

    # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X)))))
    #                    = log (e^(max(X)) * sum(e ^ (X_i - max(X))))
    #                    = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
    lse = m + tl.log(d)

    # 4. [Online Softmax] Second pass: compute gradients
    # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
    # dx_y = (softmax(x_y) - 1) / N
    # dx_i = softmax(x_i) / N, i != y
    # For label smoothing:
    # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
    #      = dx_i - (1 - label_smoothing) / N
    # With Z loss:
    # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y
    # dx_y = dx_i - (1 - label_smoothing) / N
    # For 'sum' reduction, no normalization is applied:
    # dx_y = softmax(x_y) - 1
    # dx_i = softmax(x_i), for i ≠ y

    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(
            X_ptr + X_offsets,
            mask=X_offsets < n_cols,
            other=float("-inf"),
            # Ensure float32 precision for softmax calculation
        ).cast(tl.float32)
        if HAS_SOFTCAPPING:
            intermediate = tanh(X_block / softcap)
            X_block = softcap * intermediate

        if not HAS_WEIGHT:
            # softmax(x_i)
            X_block = tl.exp(X_block - m) / d
            # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
            X_block += 2 * lse_square_scale * lse * X_block
            # smoothing term
            X_block += -eps
            # special handle dx_y
            X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
            # reduction scale
            if reduction == "mean":
                X_block = X_block / n_non_ignore
        else:
            weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
            softmax_X = tl.exp(X_block - m) / d
            # derivative of original_loss
            dloss_ori = (1 - label_smoothing) * softmax_X
            # specially handle dx_y
            dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
            dloss_ori = dloss_ori * weight_y
            # derivative of smooth_loss
            dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
            # derivative of z-loss
            dz_loss = 2 * lse_square_scale * lse * softmax_X
            # reduction scale
            if reduction == "mean":
                dloss_ori = dloss_ori / sum_non_ignore_weight
                dloss_smooth = dloss_smooth / sum_non_ignore_weight
                # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
                dz_loss = dz_loss / n_non_ignore
            # derivative of total_loss
            X_block = dloss_ori + dloss_smooth + dz_loss

        # chain rule softcapping
        # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
        if HAS_SOFTCAPPING:
            X_block = X_block * (1 - intermediate * intermediate)

        tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

    # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
    # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
    tl.debug_barrier()

    # 5. Calculate the loss

    # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
    #      = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
    #      = X_y - m - log d = X_y - lse
    # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
    # So we can safely calculate log (softmax(X_y)) without overflow
    loss = lse - ori_X_y
    if HAS_WEIGHT:
        loss = weight_y * loss

    # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
    # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
    #          = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
    # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
    #          = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
    # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
    # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
    # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
    if label_smoothing > 0:
        if HAS_WEIGHT:
            smooth_loss = scaled_x_sum + eps * lse * weight_sum
        else:
            smooth_loss = scaled_x_sum + label_smoothing * lse
        loss = loss * (1 - label_smoothing) + smooth_loss

    # An auxiliary loss, z_loss
    # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
    z_loss = lse_square_scale * lse * lse
    # Normalize the loss by the number of non-ignored elements if reduction is "mean"
    if reduction == "mean":
        if HAS_WEIGHT:
            loss = loss / sum_non_ignore_weight
        else:
            loss = loss / n_non_ignore
        # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
        z_loss = z_loss / n_non_ignore
    loss += z_loss

    tl.store(loss_ptr, loss)
    if RETURN_Z_LOSS:
        tl.store(z_loss_ptr, z_loss)


# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2  # the best size we found by manually tuning


def cross_entropy_forward(
    _input,
    target,
    weight,
    ignore_index,
    lse_square_scale,
    label_smoothing,
    reduction,
    softcap,
    return_z_loss,
):
    assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"

    BT, V = _input.shape
    n_rows = BT

    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

    # unreduced loss
    loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
    z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None

    target_mask = target != ignore_index
    n_non_ignore = target_mask.sum().item()
    assert (target * target_mask).max() < _input.shape[-1], (
        f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}"
    )
    assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0"
    sum_non_ignore_weight = n_non_ignore
    weight_sum = 0.0
    if weight is not None:
        assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
        assert torch.is_floating_point(weight), (
            f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
        )
        sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
        weight_sum = weight.sum().item()
        # ensure weight is contiguous
        if weight.stride(-1) != 1:
            weight = weight.contiguous()

    # ensure _input and target are contiguous in the last dimension
    if _input.stride(-1) != 1:
        _input = _input.contiguous()
    if target.stride(-1) != 1:
        target = target.contiguous()

    # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
    liger_cross_entropy_kernel[(n_rows,)](
        X_ptr=_input,
        X_stride=_input.stride(-2),
        Y_ptr=target,
        Y_stride=target.stride(-1),  # always 1
        weight_ptr=weight,  # dummy if None
        loss_ptr=loss_1d,
        z_loss_ptr=z_loss_1d,
        loss_stride=loss_1d.stride(-1),  # always 1
        n_cols=V,
        n_non_ignore=n_non_ignore,
        sum_non_ignore_weight=sum_non_ignore_weight,
        ignore_index=ignore_index,
        weight_sum=weight_sum,
        lse_square_scale=lse_square_scale,
        label_smoothing=label_smoothing,
        reduction=reduction,
        softcap=softcap,
        RETURN_Z_LOSS=return_z_loss,
        BLOCK_SIZE=BLOCK_SIZE,
        HAS_WEIGHT=True if weight is not None else False,
        HAS_SOFTCAPPING=True if softcap is not None else False,
        # TODO: 32 seems to give the best performance
        # Performance is quite sensitive to num_warps
        num_warps=32 if not is_hip() else 16,
    )

    if reduction == "none":
        loss = loss_1d
        z_loss = z_loss_1d if return_z_loss else None
    else:
        loss = torch.sum(loss_1d)
        z_loss = torch.sum(z_loss_1d) if return_z_loss else None

    return loss, z_loss, _input


def cross_entropy_backward(_input, grad_output):
    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        pass
    # If reduction is 'none'
    elif grad_output.ndim > 0:
        _input = _input * grad_output.unsqueeze(dim=1)
    # If reduction is ['mean', 'sum'], grad_output is just a scalar
    # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
    # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
    else:
        BT, V = _input.shape
        n_rows = BT
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

        element_mul_kernel[(n_rows,)](
            _input,
            _input.stride(-2),
            grad_output,
            V,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=32 if not is_hip() else 16,
        )

    return _input


class LigerCrossEntropyFunction(torch.autograd.Function):
    """
    This class implements a custom autograd function for the Liger Cross Entropy loss.
    It overrides the forward and backward methods of the torch.autograd.Function class.
    """

    @staticmethod
    def forward(
        ctx,
        _input: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.FloatTensor],
        ignore_index: int = -100,
        lse_square_scale: float = 0.0,
        label_smoothing: float = 0.0,
        reduction: str = "mean",
        softcap: Optional[float] = None,
        return_z_loss: bool = False,
    ):
        """
        The forward pass of the Liger Cross Entropy loss.

        Parameters:
        ctx : The context object.
        _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
        target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
        weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
        ignore_index (int): The index to ignore in the target.
        lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
        softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`

        Returns:
        tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
        """
        loss, z_loss, _input = cross_entropy_forward(
            _input,
            target,
            weight,
            ignore_index,
            lse_square_scale,
            label_smoothing,
            reduction,
            softcap,
            return_z_loss,
        )
        # TODO: investigation
        # If we don't detach the _input tensor, the memory will double
        # Not sure why but seems that there will be a time both grad and value exist but in different location
        ctx.save_for_backward(_input.detach())
        ctx.return_z_loss = return_z_loss

        return loss, z_loss

    @staticmethod
    def backward(ctx, grad_output, grad_ouput2):
        """
        The backward pass of the Liger Cross Entropy loss.

        Parameters:
        ctx : The context object with saved tensors.
        grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
        grad_output2 (tenosr): No use.
        Returns:
        tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
        """
        if ctx.return_z_loss:
            del grad_ouput2  # z_loss is only for logging

        (_input,) = ctx.saved_tensors
        _input = cross_entropy_backward(_input, grad_output)
        return (
            _input,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
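Editor's note: a minimal usage sketch for the autograd Function above (not part of the upload; it assumes a CUDA device with Triton available and that the package is importable as liger_kernels). Note that the kernel stores the input gradient in place of the logits during the forward pass, so the reference loss is computed first.

import torch
import torch.nn.functional as F

from liger_kernels import LigerCrossEntropyFunction  # assumed import path

logits = torch.randn(8, 128, device="cuda", requires_grad=True)  # (BT, V)
target = torch.randint(0, 128, (8,), device="cuda")              # (BT,)

ref = F.cross_entropy(logits.detach(), target)  # reference, evaluated before the fused call

# forward returns (loss, z_loss); z_loss is None unless return_z_loss=True
loss, _ = LigerCrossEntropyFunction.apply(logits, target, None)
loss.backward()  # gradients were already computed in-place during forward; backward only rescales

print(torch.allclose(loss.detach(), ref, atol=1e-4))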
build/torch-universal/liger_kernels/dyt.py
ADDED
@@ -0,0 +1,225 @@
import operator

import torch
import triton
import triton.language as tl

from utils import calculate_settings
from utils import compare_version
from utils import ensure_contiguous
from utils import infer_device

if compare_version("triton", operator.ge, "3.0.0"):
    try:
        # typical import path with dispatch available
        from triton.language.extra.libdevice import tanh
    except ModuleNotFoundError:
        # for working with NGC containers
        from triton.language.extra.cuda.libdevice import tanh
else:
    from triton.language.math import tanh


@triton.jit
def _dyt_fwd_kernel(
    x_ptr,
    x_row_stride,
    alpha_ptr,
    gamma_ptr,
    beta_ptr,
    y_ptr,
    y_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Reference:
    https://arxiv.org/abs/2503.10622

    Shapes:
        - x: (BT, C)
        - alpha: (1)
        - gamma: (C)
        - beta: (C)
    """
    row_idx = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols

    x_ptr += row_idx * x_row_stride
    y_ptr += row_idx * y_row_stride

    alpha = tl.load(alpha_ptr)
    gamma = tl.load(gamma_ptr + offsets, mask=mask)
    beta = tl.load(beta_ptr + offsets, mask=mask)
    x = tl.load(x_ptr + offsets, mask=mask)
    y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
    tl.store(y_ptr + offsets, y, mask=mask)


@triton.jit
def _dyt_bwd_kernel(
    x_ptr,
    x_row_stride,
    dy_ptr,
    dy_row_stride,
    dx_ptr,
    dx_row_stride,
    alpha_ptr,
    dalpha_ptr,
    gamma_ptr,
    dgamma_ptr,
    dgamma_row_stride,
    n_cols,
    n_rows,
    ROWS_PER_PROGRAM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Reference:
    https://arxiv.org/abs/2503.10622

    Shapes:
        - x: (BT, C)
        - alpha: (1)
        - gamma: (C)
        - dx: (BT, C)
        - dy: (BT, C)
        - dgamma: (sm_count, C)
        - dalpha: (sm_count,)
    """
    # d(gamma * tanh(alpha * x) + beta) / dx
    # = gamma * (1 - tanh^2(alpha * x)) * alpha
    # d(gamma * tanh(alpha * x) + beta) / dalpha
    # = gamma * (1 - tanh^2(alpha * x)) * x
    # d(gamma * tanh(alpha * x) + beta) / dgamma
    # = tanh(alpha * x)
    # d(gamma * tanh(alpha * x)) / dbeta = 1
    pid = tl.program_id(0)

    row_start = pid * ROWS_PER_PROGRAM
    row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols

    dalpha = 0.0
    dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    x_ptr += row_start * x_row_stride
    dx_ptr += row_start * dx_row_stride
    dy_ptr += row_start * dy_row_stride
    alpha = tl.load(alpha_ptr)
    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)

    for _ in tl.range(row_start, row_end):
        dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
        tanh_ax = tanh((alpha * x).cast(tl.float32))
        sech2_ax = 1 - tanh_ax * tanh_ax

        dx = dy * gamma * sech2_ax * alpha
        dalpha += tl.sum(dy * gamma * sech2_ax * x)
        dgamma += dy * tanh_ax
        tl.store(dx_ptr + offsets, dx, mask=mask)

        dy_ptr += dy_row_stride
        x_ptr += x_row_stride
        dx_ptr += dx_row_stride

    tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
    tl.store(dalpha_ptr + pid, dalpha)

    pass


def liger_dyt_fwd(x, alpha, gamma, beta):
    shape = x.shape
    dim = shape[-1]
    x = x.view(-1, dim)
    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    _dyt_fwd_kernel[(n_rows,)](
        x_ptr=x,
        alpha_ptr=alpha,
        gamma_ptr=gamma,
        beta_ptr=beta,
        y_ptr=y,
        x_row_stride=x.stride(0),
        y_row_stride=y.stride(0),
        n_cols=n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return y.view(*shape)


def liger_dyt_bwd(dy, x, alpha, gamma):
    shape = dy.shape
    dtype = x.dtype
    dim = shape[-1]
    dy = dy.view(-1, dim)
    x = x.view(-1, dim)
    n_rows, n_cols = dy.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    sm_count = 1
    device = infer_device()
    if device == "cuda":
        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    elif device == "xpu":
        sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
    if n_cols > BLOCK_SIZE:
        raise RuntimeError(
            f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
        )

    dx = torch.empty_like(x, dtype=torch.float32)
    _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
    _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)

    grid = (sm_count,)
    rows_per_program = triton.cdiv(n_rows, sm_count)
    _dyt_bwd_kernel[grid](
        x_ptr=x,
        x_row_stride=x.stride(0),
        dy_ptr=dy,
        dy_row_stride=dy.stride(0),
        dx_ptr=dx,
        dx_row_stride=dx.stride(0),
        alpha_ptr=alpha,
        dalpha_ptr=_dalpha,
        gamma_ptr=gamma,
        dgamma_ptr=_dgamma,
        dgamma_row_stride=_dgamma.stride(0),
        n_cols=n_cols,
        n_rows=n_rows,
        ROWS_PER_PROGRAM=rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
    dgamma = _dgamma.sum(dim=0).to(dtype)
    dbeta = dy.sum(dim=0).to(dtype)
    return dx.view(*shape), dalpha, dgamma, dbeta


class LigerDyTFunction(torch.autograd.Function):
    @staticmethod
    @ensure_contiguous
    def forward(ctx, x, alpha, gamma, beta):
        y = liger_dyt_fwd(x, alpha, gamma, beta)
        ctx.save_for_backward(x, alpha, gamma)
        return y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        x, alpha, gamma = ctx.saved_tensors
        dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
            grad_output,
            x,
            alpha,
            gamma,
        )

        return (dx, dalpha, dgamma, dbeta)
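Editor's note: DyT (Dynamic Tanh, arXiv:2503.10622) is an elementwise transform, so a plain PyTorch reference is a one-liner; the sketch below (illustrative, not part of the upload) can be used to cross-check liger_dyt_fwd numerically.

import torch

def dyt_reference(x, alpha, gamma, beta):
    # y = gamma * tanh(alpha * x) + beta, with the tanh evaluated in float32 like the kernel
    return gamma * torch.tanh((alpha * x).float()) + beta

x = torch.randn(4, 16)
alpha = torch.tensor(0.5)
gamma = torch.ones(16)
beta = torch.zeros(16)
print(dyt_reference(x, alpha, gamma, beta).shape)  # torch.Size([4, 16])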
build/torch-universal/liger_kernels/fused_linear_cross_entropy.py
ADDED
@@ -0,0 +1,283 @@
import torch
import triton

from cross_entropy import liger_cross_entropy_kernel
from utils import amp_custom_bwd
from utils import amp_custom_fwd
from utils import element_mul_kernel
from utils import is_hip

# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
MAX_FUSED_SIZE = 65536 // 2


def fused_linear_cross_entropy_forward(
    _input,
    weight,
    target,
    ce_weight=None,
    bias=None,
    ignore_index=-100,
    lse_square_scale=0.0,
    label_smoothing=0.0,
    reduction="mean",
    softcap=None,
    return_z_loss=False,
):
    assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
    device = _input.device

    # inputs have shape: BT x H
    # materialized activations will have shape: BT x V
    # the increase in memory = BT x V
    # reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
    # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
    # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
    # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
    BT, H = _input.shape
    V = weight.shape[0]
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

    inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
    num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size

    grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
    grad_input = torch.zeros_like(_input, device=device)
    grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
    # we use fp32 for loss accumulator
    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
    z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None

    # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
    target_mask = target != ignore_index
    total_n_non_ignore = target_mask.sum().item()
    total_sum_non_ignore_ce_weight = total_n_non_ignore
    ce_weight_sum = 0.0
    if ce_weight is not None:
        assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
        assert torch.is_floating_point(ce_weight), (
            f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
        )
        total_sum_non_ignore_ce_weight = (
            torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
        )
        ce_weight_sum = ce_weight.sum().item()
        if ce_weight.stride(-1) != 1:
            ce_weight = ce_weight.contiguous()

    for chunk_id in range(num_chunks):
        start_idx = chunk_id * chunk_size
        end_idx = min((chunk_id + 1) * chunk_size, BT)
        _input_chunk = _input[start_idx:end_idx]  # chunk_size x H

        # when doing matmul, use the original precision
        logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
        if bias is not None:
            logits_chunk = logits_chunk + bias

        target_chunk = target[start_idx:end_idx]  # chunk_size,

        n_rows = logits_chunk.shape[0]

        # unreduced loss
        loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
        z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None

        # ensure _input and target are contiguous
        logits_chunk = logits_chunk.contiguous()
        target_chunk = target_chunk.contiguous()

        # Here we calculate the gradient of logits_chunk in place so we can save memory.
        liger_cross_entropy_kernel[(n_rows,)](
            X_ptr=logits_chunk,
            X_stride=logits_chunk.stride(-2),
            Y_ptr=target_chunk,
            Y_stride=target_chunk.stride(-1),  # always 1
            weight_ptr=ce_weight,
            loss_ptr=loss_1d_slice,
            z_loss_ptr=z_loss_1d_slice,
            loss_stride=loss_1d_slice.stride(-1),  # always 1
            n_cols=V,
            n_non_ignore=total_n_non_ignore,
            sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
            weight_sum=ce_weight_sum,
            ignore_index=ignore_index,
            lse_square_scale=lse_square_scale,
            label_smoothing=label_smoothing,
            reduction=reduction,
            softcap=softcap,
            RETURN_Z_LOSS=return_z_loss,
            HAS_WEIGHT=True if ce_weight is not None else False,
            HAS_SOFTCAPPING=True if softcap is not None else False,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=32 if not is_hip() else 16,
        )

        loss_1d[start_idx:end_idx] = loss_1d_slice
        if return_z_loss:
            z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
        grad_logits_chunk = logits_chunk  # chunk_size x V

        grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

        if grad_weight is not None:
            torch.addmm(
                input=grad_weight,
                mat1=logits_chunk.t().to(
                    _input_chunk.dtype
                ),  # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
                mat2=_input_chunk,
                out=grad_weight,
                alpha=1.0,
                beta=1.0,
            )

        if bias is not None:
            torch.add(
                input=grad_bias,
                other=logits_chunk.sum(dim=0),
                out=grad_bias,
                alpha=1.0,
            )

    # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now.
    # if reduction == "none":
    #     loss = loss_1d
    #     z_loss = z_loss_1d if return_z_loss else None

    else:
        loss = torch.sum(loss_1d)
        z_loss = torch.sum(z_loss_1d) if return_z_loss else None
    return loss, z_loss, grad_input, grad_weight, grad_bias


def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
    if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
        # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
        BT, H = grad_input.shape
        n_rows = BT
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))

        element_mul_kernel[(n_rows,)](
            grad_input,
            grad_input.stride(-2),
            grad_output,
            H,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=32 if not is_hip() else 16,
        )

        # handle grad_weight
        if grad_weight is not None:
            V, H = grad_weight.shape
            n_rows = V

            element_mul_kernel[(n_rows,)](
                grad_weight,
                grad_weight.stride(-2),
                grad_output,
                H,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=32 if not is_hip() else 16,
            )

        if grad_bias is not None:
            V = grad_bias.shape[0]
            n_rows = V

            element_mul_kernel[(n_rows,)](
                grad_bias,
                grad_bias.stride(-1),
                grad_output,
                1,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=32 if not is_hip() else 16,
            )
    return grad_input, grad_weight, grad_bias


class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
    @staticmethod
    @amp_custom_fwd
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ce_weight=None,
        ignore_index=-100,
        lse_square_scale=0.0,
        label_smoothing=0.0,
        reduction="mean",
        softcap=None,
        return_z_loss: bool = False,
    ):
        """
        Fusing the last linear layer with cross-entropy loss
        Reference: https://github.com/mgmalek/efficient_cross_entropy

        Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
        the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
        compute the gradient at the forward pass. By doing so, we don't have to store the _input and target
        for the backward pass.

        _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension.
        target: (B*T) where each value is in [0, V-1]
        weight: (V, H) where V is the number of classes
        bias: (V) where V is the number of classes
        ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
        ignore_index: the index to ignore in the target
        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        reduction: reduction to apply
        """

        loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
            _input=_input,
            weight=weight,
            target=target,
            bias=bias,
            ce_weight=ce_weight,
            ignore_index=ignore_index,
            lse_square_scale=lse_square_scale,
            label_smoothing=label_smoothing,
            reduction=reduction,
            softcap=softcap,
            return_z_loss=return_z_loss,
        )
        # downcast to dtype and store for backward
        ctx.save_for_backward(
            grad_input.detach(),
            grad_weight.detach() if grad_weight is not None else None,
            grad_bias.detach() if bias is not None else None,
        )
        ctx.return_z_loss = return_z_loss
        return loss, z_loss

    @staticmethod
    @amp_custom_bwd
    def backward(ctx, grad_output, grad_output2):
        if ctx.return_z_loss:
            del grad_output2  # z_loss is only for logging
        (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
        grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
            grad_output, grad_input, grad_weight, grad_bias
        )
        return (
            grad_input,
            grad_weight,
            None,
            grad_bias,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
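Editor's note: the chunking comment in fused_linear_cross_entropy_forward can be checked with a few lines of arithmetic. This sketch (illustrative only, not part of the upload) reproduces the worked example BT = 4096*4, V = 32000, H = 4096 from the comment.

def cdiv(a, b):
    return (a + b - 1) // b

def next_power_of_2(n):
    return 1 << (n - 1).bit_length()

BT, V, H = 4096 * 4, 32000, 4096
inc_factor = cdiv(V, H)                              # 8: the logits chunk is ~8x larger than the input chunk
chunk_size = next_power_of_2(cdiv(BT, inc_factor))   # 2048 tokens per chunk
num_chunks = cdiv(BT, chunk_size)                    # 8 chunks
print(inc_factor, chunk_size, num_chunks)            # 8 2048 8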
build/torch-universal/liger_kernels/geglu.py
ADDED
@@ -0,0 +1,141 @@
import operator

import torch
import triton
import triton.language as tl

from utils import calculate_settings
from utils import compare_version
from utils import ensure_contiguous

if compare_version("triton", operator.ge, "3.0.0"):
    try:
        # typical import path with dispatch available
        from triton.language.extra.libdevice import tanh
    except ModuleNotFoundError:
        # for working with NGC containers
        from triton.language.extra.cuda.libdevice import tanh
else:
    from triton.language.math import tanh


@triton.jit
def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    program_id = tl.program_id(0).to(tl.int64)

    # locate start index
    a += program_id * stride
    b += program_id * stride
    c += program_id * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b + col_offsets, mask=mask, other=0)

    # tanh approximation form of GELU is computed with:
    # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    a_cubed = a_row * a_row * a_row
    tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
    tanh_result = tanh(tanh_arg)
    geglu_a = 0.5 * a_row * (1 + tanh_result)
    c_row = geglu_a * b_row
    tl.store(c + col_offsets, c_row, mask=mask)


@triton.jit
def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    program_id = tl.program_id(0).to(tl.int64)

    # locate start index
    dc += program_id * stride
    a += program_id * stride
    b += program_id * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    dc_row = tl.load(dc + col_offsets, mask=mask, other=0)
    a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b + col_offsets, mask=mask, other=0)

    # recomputation to save memory
    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    a_cubed = a_row * a_row * a_row
    tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
    tanh_result = tanh(tanh_arg)
    geglu_a = 0.5 * a_row * (1 + tanh_result)

    db_row = dc_row * geglu_a

    # Gradient w.r.t. a can be computed with:
    # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
    # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
    term1 = 0.5 * (1 + tanh_result)
    tanh_sq = tanh_result * tanh_result
    term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
    da_row = dc_row * b_row * (term1 + term2)

    tl.store(a + col_offsets, da_row, mask=mask)
    tl.store(b + col_offsets, db_row, mask=mask)


def geglu_forward(a, b):
    ori_shape = a.shape

    n_cols = ori_shape[-1]
    a = a.view(-1, n_cols)
    b = b.view(-1, n_cols)
    c = torch.empty_like(a)
    n_rows = a.shape[0]

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    _geglu_tanh_forward_kernel[(n_rows,)](
        a,
        b,
        c,
        c.stride(-2),
        n_cols=n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return a, b, c.view(*ori_shape)


def geglu_backward(a, b, dc):
    ori_shape = dc.shape
    n_cols = ori_shape[-1]
    dc = dc.view(-1, n_cols)
    n_rows = dc.shape[0]

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    _geglu_tanh_backward_kernel[(n_rows,)](
        dc,
        a,
        b,
        dc.stride(-2),
        n_cols=n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return a.view(*ori_shape), b.view(*ori_shape)


class LigerGELUMulFunction(torch.autograd.Function):
    @staticmethod
    @ensure_contiguous
    def forward(ctx, a, b):
        a, b, c = geglu_forward(a, b)
        ctx.save_for_backward(a, b)
        return c

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dc):
        a, b = ctx.saved_tensors
        a, b = geglu_backward(a, b, dc)
        return a, b
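Editor's note: the tanh-approximated GELU used by the kernel above matches PyTorch's approximate="tanh" mode, so an eager reference for GEGLU (a sketch, not part of the upload) is:

import torch
import torch.nn.functional as F

def geglu_reference(a, b):
    # GEGLU: gelu(a) * b, with the same tanh approximation as the Triton kernel
    return F.gelu(a, approximate="tanh") * b

a = torch.randn(2, 8)
b = torch.randn(2, 8)
manual = 0.5 * a * (1 + torch.tanh(0.7978845608028654 * (a + 0.044715 * a**3))) * b
print(torch.allclose(geglu_reference(a, b), manual, atol=1e-5))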
build/torch-universal/liger_kernels/group_norm.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import operator
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import triton
|
5 |
+
import triton.language as tl
|
6 |
+
|
7 |
+
from utils import compare_version
|
8 |
+
from utils import ensure_contiguous
|
9 |
+
|
10 |
+
if compare_version("triton", operator.ge, "3.0.0"):
|
11 |
+
try:
|
12 |
+
# typical import path with dispatch available
|
13 |
+
from triton.language.extra.libdevice import rsqrt
|
14 |
+
except ModuleNotFoundError:
|
15 |
+
# for working with NGC containers
|
16 |
+
from triton.language.extra.cuda.libdevice import rsqrt
|
17 |
+
else:
|
18 |
+
from triton.language.math import rsqrt
|
19 |
+
|
20 |
+
MAX_FUSED_SIZE = 65536
|
21 |
+
|
22 |
+
|
23 |
+
@triton.jit
|
24 |
+
def _group_norm_forward_kernel(
|
25 |
+
Y_ptr, # pointer to output, shape (n_rows, n_groups, hidden_size)
|
26 |
+
Y_row_stride, # stride of each row in output
|
27 |
+
Y_col_stride, # stride of each column in output
|
28 |
+
X_ptr, # pointer to input, shape (n_rows, n_groups, hidden_size)
|
29 |
+
X_row_stride, # stride of each row in input
|
30 |
+
X_col_stride, # stride of each column in input
|
31 |
+
Mean_ptr, # pointer to mean, shape (n_rows, n_groups)
|
32 |
+
Mean_row_stride, # stride of each row in mean
|
33 |
+
Mean_col_stride, # stride of each column in mean
|
34 |
+
RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups)
|
35 |
+
RSTD_row_stride, # stride of each row in rstd
|
36 |
+
RSTD_col_stride, # stride of each column in rstd
|
37 |
+
W_ptr, # pointer to W
|
38 |
+
B_ptr, # pointer to B
|
39 |
+
hidden_size, # hidden size of X
|
40 |
+
channels_per_group, # the number of channels per group
|
41 |
+
eps,
|
42 |
+
BLOCK_SIZE: tl.constexpr,
|
43 |
+
):
|
44 |
+
"""
|
45 |
+
References:
|
46 |
+
https://nn.labml.ai/normalization/group_norm/index.html
|
47 |
+
"""
|
48 |
+
batch_idx = tl.program_id(0)
|
49 |
+
group_idx = tl.program_id(1)
|
50 |
+
|
51 |
+
X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride
|
52 |
+
Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride
|
53 |
+
|
54 |
+
block_range = tl.arange(0, BLOCK_SIZE)
|
55 |
+
|
56 |
+
# Compute mean and variance using the online algorithm
|
57 |
+
s = 0.0
|
58 |
+
squared_sum = 0.0
|
59 |
+
for i in tl.range(0, hidden_size, BLOCK_SIZE):
|
60 |
+
hidden_size_offsets = i + block_range
|
61 |
+
mask = hidden_size_offsets < hidden_size
|
62 |
+
X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0)
|
63 |
+
s += tl.sum(X)
|
64 |
+
# X**2
|
65 |
+
squared_sum += tl.sum(X * X)
|
66 |
+
|
67 |
+
m = s / hidden_size
|
68 |
+
|
69 |
+
# variance = E[X**2] - E[X]**2
|
70 |
+
variance = (squared_sum / hidden_size) - (m * m)
|
71 |
+
|
72 |
+
# 1/std
|
73 |
+
rstd = rsqrt(variance + eps)
|
74 |
+
|
75 |
+
# Normalize
|
76 |
+
hidden_size_per_channel = hidden_size // channels_per_group
|
77 |
+
for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
|
78 |
+
W = tl.load(W_ptr + channel_idx)
|
79 |
+
B = tl.load(B_ptr + channel_idx)
|
80 |
+
for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
|
81 |
+
hidden_size_offsets = i + block_range
|
82 |
+
mask = hidden_size_offsets < hidden_size_per_channel
|
83 |
+
X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
|
84 |
+
Y = (X - m) * rstd * W + B
|
85 |
+
tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
|
86 |
+
|
87 |
+
X_ptr += hidden_size_per_channel
|
88 |
+
Y_ptr += hidden_size_per_channel
|
89 |
+
|
90 |
+
tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
|
91 |
+
tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
|
92 |
+
|
93 |
+
|
94 |
+
@triton.jit
|
95 |
+
def _group_norm_backward_kernel(
|
96 |
+
X_ptr, # pointer to input, shape (n_rows, n_channels, hidden_size)
|
97 |
+
X_row_stride, # stride of each row in input
|
98 |
+
X_col_stride, # stride of each column in input
|
99 |
+
W_ptr, # pointer to weights, shape (n_channels)
|
100 |
+
Mean_ptr, # pointer to mean, shape (n_rows, n_groups)
|
101 |
+
Mean_ptr_row_stride,  # stride of each row in mean
|
102 |
+
Mean_ptr_col_stride, # stride of each column in mean
|
103 |
+
RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups)
|
104 |
+
DX_ptr, # pointer to input grad, shape (n_rows, n_groups, hidden_size)
|
105 |
+
DW_ptr, # pointer to weights grad, shape (n_channels)
|
106 |
+
DB_ptr, # pointer to bias grad, shape (n_channels)
|
107 |
+
UPSTREAM_ptr, # pointer to output grad, shape (n_rows, n_channels, hidden_size)
|
108 |
+
hidden_size: tl.constexpr, # hidden size
|
109 |
+
channels_per_group: tl.constexpr,  # number of channels per group
|
110 |
+
BLOCK_SIZE: tl.constexpr,
|
111 |
+
dtype: tl.constexpr,
|
112 |
+
):
|
113 |
+
"""
|
114 |
+
References:
|
115 |
+
https://nn.labml.ai/normalization/group_norm/index.html
|
116 |
+
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
117 |
+
|
118 |
+
The backprop equations are the same for group_norm and layer_norm.
|
119 |
+
The only difference is that we load the Mean and RSTD corresponding to the
|
120 |
+
group we're computing gradients for, and that the mean and rstd are computed over all channels in the group,
|
121 |
+
so the total number of elements we reduce over is channels_per_group * hidden_size.
|
122 |
+
|
123 |
+
We also need to load the Weights corresponding to the current channel to compute the gradients.
|
124 |
+
"""
|
125 |
+
batch_idx = tl.program_id(0)
|
126 |
+
group_idx = tl.program_id(1)
|
127 |
+
|
128 |
+
# Move the pointers to the correct batch
|
129 |
+
X_ptr += batch_idx * X_row_stride
|
130 |
+
DX_ptr += batch_idx * X_row_stride
|
131 |
+
UPSTREAM_ptr += batch_idx * X_row_stride
|
132 |
+
|
133 |
+
# Mean and rstd are the same shape so have the same strides
|
134 |
+
mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
|
135 |
+
rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
|
136 |
+
|
137 |
+
c1 = 0.0
|
138 |
+
c2 = 0.0
|
139 |
+
block_range = tl.arange(0, BLOCK_SIZE)
|
140 |
+
|
141 |
+
# We need to compute the sum terms of the backprop equations across all channels in the group
|
142 |
+
for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
|
143 |
+
dW = 0.0
|
144 |
+
dB = 0.0
|
145 |
+
# Move the pointers to the correct channel
|
146 |
+
W = tl.load(W_ptr + channel_idx)
|
147 |
+
for i in tl.range(0, hidden_size, BLOCK_SIZE):
|
148 |
+
hidden_size_offsets = i + block_range
|
149 |
+
mask = hidden_size_offsets < hidden_size
|
150 |
+
X = tl.load(
|
151 |
+
X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
|
152 |
+
mask=mask,
|
153 |
+
other=0.0,
|
154 |
+
)
|
155 |
+
UPSTREAM_grad = tl.load(
|
156 |
+
UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
|
157 |
+
mask=mask,
|
158 |
+
other=0.0,
|
159 |
+
)
|
160 |
+
|
161 |
+
x_hat = (X - mean) * rstd
|
162 |
+
dW += tl.sum(UPSTREAM_grad * x_hat)
|
163 |
+
dB += tl.sum(UPSTREAM_grad)
|
164 |
+
|
165 |
+
wdy = W * UPSTREAM_grad
|
166 |
+
c1 += tl.sum(x_hat * wdy)
|
167 |
+
c2 += tl.sum(wdy)
|
168 |
+
|
169 |
+
# Need to ensure additions to the same channel are atomic
|
170 |
+
tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype))
|
171 |
+
tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype))
|
172 |
+
|
173 |
+
N = hidden_size * channels_per_group
|
174 |
+
c1 = c1 / N
|
175 |
+
c2 = c2 / N
|
176 |
+
|
177 |
+
for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
|
178 |
+
# Move the pointers to the correct channel
|
179 |
+
W = tl.load(W_ptr + channel_idx)
|
180 |
+
for i in range(0, hidden_size, BLOCK_SIZE):
|
181 |
+
hidden_size_offsets = i + block_range
|
182 |
+
mask = hidden_size_offsets < hidden_size
|
183 |
+
X = tl.load(
|
184 |
+
X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
|
185 |
+
mask=mask,
|
186 |
+
other=0.0,
|
187 |
+
)
|
188 |
+
UPSTREAM_grad = tl.load(
|
189 |
+
UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
|
190 |
+
mask=mask,
|
191 |
+
other=0.0,
|
192 |
+
)
|
193 |
+
|
194 |
+
x_hat = (X - mean) * rstd
|
195 |
+
wdy = W * UPSTREAM_grad
|
196 |
+
dx = (wdy - (x_hat * c1 + c2)) * rstd
|
197 |
+
tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)
|
198 |
+
|
199 |
+
|
200 |
+
def group_norm_forward(X, num_channels, num_groups, W, B, eps):
|
201 |
+
shape = X.shape
|
202 |
+
batch_size = shape[0]
|
203 |
+
channels_per_group = num_channels // num_groups
|
204 |
+
# Reshape X so that the mean and std are computed within each group
|
205 |
+
X = X.view(batch_size, num_groups, -1).contiguous()
|
206 |
+
hidden_size = X.shape[-1]
|
207 |
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
|
208 |
+
Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
|
209 |
+
Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
|
210 |
+
RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
|
211 |
+
|
212 |
+
_group_norm_forward_kernel[(batch_size, num_groups)](
|
213 |
+
Y,
|
214 |
+
Y.stride(0),
|
215 |
+
Y.stride(1),
|
216 |
+
X,
|
217 |
+
X.stride(0),
|
218 |
+
X.stride(1),
|
219 |
+
Mean,
|
220 |
+
Mean.stride(0),
|
221 |
+
Mean.stride(1),
|
222 |
+
RSTD,
|
223 |
+
RSTD.stride(0),
|
224 |
+
RSTD.stride(1),
|
225 |
+
W,
|
226 |
+
B,
|
227 |
+
hidden_size,
|
228 |
+
channels_per_group,
|
229 |
+
eps,
|
230 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
231 |
+
)
|
232 |
+
# Return tensors in the original shape
|
233 |
+
return Y.view(*shape), X.view(*shape), Mean, RSTD, BLOCK_SIZE
|
234 |
+
|
235 |
+
|
236 |
+
def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
|
237 |
+
shape = dY.shape
|
238 |
+
batch_size = shape[0]
|
239 |
+
hidden_size = dY.shape[-1]
|
240 |
+
channels_per_group = num_channels // num_groups
|
241 |
+
dY = dY.view(batch_size, num_groups, -1)
|
242 |
+
DX = torch.empty(
|
243 |
+
(batch_size, num_groups, hidden_size * channels_per_group),
|
244 |
+
dtype=X.dtype,
|
245 |
+
device=X.device,
|
246 |
+
)
|
247 |
+
DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device)
|
248 |
+
DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device)
|
249 |
+
triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
|
250 |
+
|
251 |
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
|
252 |
+
_group_norm_backward_kernel[(batch_size, num_groups)](
|
253 |
+
X,
|
254 |
+
X.stride(0),
|
255 |
+
X.stride(1),
|
256 |
+
W,
|
257 |
+
Mean,
|
258 |
+
Mean.stride(0),
|
259 |
+
Mean.stride(1),
|
260 |
+
RSTD,
|
261 |
+
DX,
|
262 |
+
DW,
|
263 |
+
DB,
|
264 |
+
dY,
|
265 |
+
hidden_size,
|
266 |
+
channels_per_group,
|
267 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
268 |
+
dtype=triton_dtype,
|
269 |
+
)
|
270 |
+
|
271 |
+
# Return tensors in the original shape
|
272 |
+
return DX.view(*shape), DW, DB
|
273 |
+
|
274 |
+
|
275 |
+
class LigerGroupNormFunction(torch.autograd.Function):
|
276 |
+
@staticmethod
|
277 |
+
@ensure_contiguous
|
278 |
+
def forward(
|
279 |
+
ctx,
|
280 |
+
X,
|
281 |
+
affine_scaling_weight,
|
282 |
+
affine_shifting_bias,
|
283 |
+
num_channels,
|
284 |
+
num_groups,
|
285 |
+
eps,
|
286 |
+
):
|
287 |
+
Y, X, Mean, RSTD, BLOCK_SIZE = group_norm_forward(
|
288 |
+
X,
|
289 |
+
num_channels,
|
290 |
+
num_groups,
|
291 |
+
affine_scaling_weight,
|
292 |
+
affine_shifting_bias,
|
293 |
+
eps,
|
294 |
+
)
|
295 |
+
ctx.num_channels = num_channels
|
296 |
+
ctx.num_groups = num_groups
|
297 |
+
ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
|
298 |
+
return Y
|
299 |
+
|
300 |
+
@staticmethod
|
301 |
+
@ensure_contiguous
|
302 |
+
def backward(ctx, dY):
|
303 |
+
X, W, B, Mean, RSTD = ctx.saved_tensors
|
304 |
+
DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
|
305 |
+
return DX, DW, DB, None, None, None
|
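For reference, a minimal usage sketch of the group-norm Function defined in this file. The import path (`liger_kernels`), the CUDA device, and all shapes are illustrative assumptions rather than part of this commit; the Triton kernels need a supported GPU/XPU device.

```python
# Illustrative sketch only: exercise LigerGroupNormFunction end to end.
import torch
from liger_kernels import LigerGroupNormFunction  # assumed import path

batch, num_channels, num_groups, hidden = 2, 8, 4, 64
X = torch.randn(batch, num_channels, hidden, device="cuda", requires_grad=True)
W = torch.ones(num_channels, device="cuda", requires_grad=True)   # affine scaling weight
B = torch.zeros(num_channels, device="cuda", requires_grad=True)  # affine shifting bias

Y = LigerGroupNormFunction.apply(X, W, B, num_channels, num_groups, 1e-6)
Y.sum().backward()  # runs the Triton backward kernel to produce dX, dW, dB
```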
build/torch-universal/liger_kernels/jsd.py
ADDED
@@ -0,0 +1,201 @@
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import triton
|
5 |
+
import triton.language as tl
|
6 |
+
|
7 |
+
from utils import ensure_contiguous
|
8 |
+
from utils import infer_device
|
9 |
+
|
10 |
+
|
11 |
+
@triton.jit
|
12 |
+
def _jsd_kernel(
|
13 |
+
X_ptr, # input in logspace, X = log Q
|
14 |
+
X_stride,
|
15 |
+
Y_ptr, # ground truth in logspace, Y = log P
|
16 |
+
Y_stride,
|
17 |
+
loss_ptr,
|
18 |
+
loss_stride,
|
19 |
+
dX_ptr,
|
20 |
+
dX_stride,
|
21 |
+
label_ptr,
|
22 |
+
beta: tl.constexpr,
|
23 |
+
n_non_ignore: int,
|
24 |
+
ignore_index: tl.constexpr,
|
25 |
+
n_cols,
|
26 |
+
BLOCK_SIZE: tl.constexpr,
|
27 |
+
HAS_LABEL: tl.constexpr,
|
28 |
+
):
|
29 |
+
# JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X)
|
30 |
+
# = sum(P * log P + Q * log Q - 2 * M * log M) / 2
|
31 |
+
# = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2
|
32 |
+
# grad_x_i = 0.5 * Q * (X - log_M)
|
33 |
+
pid = tl.program_id(0).to(tl.int64)
|
34 |
+
X_ptr += pid * X_stride
|
35 |
+
dX_ptr += pid * dX_stride
|
36 |
+
Y_ptr += pid * Y_stride
|
37 |
+
loss_ptr += pid * loss_stride
|
38 |
+
label_ptr += pid
|
39 |
+
|
40 |
+
if HAS_LABEL:
|
41 |
+
label = tl.load(label_ptr)
|
42 |
+
if label == ignore_index:
|
43 |
+
for i in range(0, n_cols, BLOCK_SIZE):
|
44 |
+
offsets = i + tl.arange(0, BLOCK_SIZE)
|
45 |
+
tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols)
|
46 |
+
return
|
47 |
+
|
48 |
+
for i in range(0, n_cols, BLOCK_SIZE):
|
49 |
+
offsets = i + tl.arange(0, BLOCK_SIZE)
|
50 |
+
mask = offsets < n_cols
|
51 |
+
X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
|
52 |
+
Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
|
53 |
+
|
54 |
+
if beta == 0.0: # forward KL
|
55 |
+
Y_max = tl.max(Y, axis=0)
|
56 |
+
Y_shifted = Y - Y_max
|
57 |
+
Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max) # Compensate for the shift
|
58 |
+
loss = Y_prob * (Y - X)
|
59 |
+
dX = -Y_prob
|
60 |
+
elif beta == 1.0: # reverse KL
|
61 |
+
X_max = tl.max(X, axis=0)
|
62 |
+
X_shifted = X - X_max
|
63 |
+
X_prob = tl.exp(X_shifted) * tl.exp(X_max) # Compensate for the shift
|
64 |
+
loss = X_prob * (X - Y)
|
65 |
+
dX = loss + X_prob
|
66 |
+
else:
|
67 |
+
max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
|
68 |
+
X_shifted = X - max_val
|
69 |
+
Y_shifted = Y - max_val
|
70 |
+
|
71 |
+
# Pre-compute exp(max_val) since it's used twice
|
72 |
+
exp_max = tl.exp(max_val)
|
73 |
+
|
74 |
+
# Compute exp terms with compensation
|
75 |
+
Q = tl.exp(X_shifted) * exp_max # = exp(X)
|
76 |
+
P = tl.exp(Y_shifted) * exp_max # = exp(Y)
|
77 |
+
|
78 |
+
# Pre-compute common terms
|
79 |
+
beta_P = beta * P
|
80 |
+
one_minus_beta_Q = (1 - beta) * Q
|
81 |
+
M = beta_P + one_minus_beta_Q
|
82 |
+
log_M = tl.log(M) # No need to compensate as M is already in original scale
|
83 |
+
|
84 |
+
loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
|
85 |
+
dX = one_minus_beta_Q * (X - log_M)
|
86 |
+
|
87 |
+
# Pre-compute scaling factor
|
88 |
+
scale = 1.0 / n_non_ignore
|
89 |
+
loss = loss * scale
|
90 |
+
dX = dX * scale
|
91 |
+
|
92 |
+
tl.store(loss_ptr + offsets, loss, mask=mask)
|
93 |
+
tl.store(dX_ptr + offsets, dX, mask=mask)
|
94 |
+
|
95 |
+
|
96 |
+
MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536
|
97 |
+
|
98 |
+
|
99 |
+
def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
|
100 |
+
BT, V = _input.shape
|
101 |
+
n_rows = BT
|
102 |
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
103 |
+
# per-element (unreduced) loss buffer
|
104 |
+
loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
|
105 |
+
dX = torch.empty_like(_input)
|
106 |
+
|
107 |
+
if has_label:
|
108 |
+
n_non_ignore = (shift_labels != ignore_index).sum().item()
|
109 |
+
else:
|
110 |
+
n_non_ignore = BT
|
111 |
+
|
112 |
+
_jsd_kernel[(n_rows,)](
|
113 |
+
X_ptr=_input, # input in logspace, X = log Q
|
114 |
+
X_stride=_input.stride(-2),
|
115 |
+
Y_ptr=target, # ground truth in logspace, Y = log P
|
116 |
+
Y_stride=target.stride(-2),
|
117 |
+
loss_ptr=loss,
|
118 |
+
loss_stride=loss.stride(-2),
|
119 |
+
dX_ptr=dX,
|
120 |
+
dX_stride=dX.stride(-2),
|
121 |
+
label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)), # dummy ptr if no label
|
122 |
+
beta=beta,
|
123 |
+
n_non_ignore=n_non_ignore,
|
124 |
+
ignore_index=ignore_index,
|
125 |
+
n_cols=V,
|
126 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
127 |
+
HAS_LABEL=has_label,
|
128 |
+
)
|
129 |
+
|
130 |
+
loss = torch.sum(loss)
|
131 |
+
return loss.to(_input.dtype), dX
|
132 |
+
|
133 |
+
|
134 |
+
def jsd_backward(dX, grad_output):
|
135 |
+
# If jsd is the last layer, grad_output is 1.0. Skip the mul to save time
|
136 |
+
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
|
137 |
+
return dX
|
138 |
+
else:
|
139 |
+
return grad_output * dX
|
140 |
+
|
141 |
+
|
142 |
+
class LigerJSDFunction(torch.autograd.Function):
|
143 |
+
r"""
|
144 |
+
This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.
|
145 |
+
.. math::
|
146 |
+
JSD(\beta)(P || Q)
|
147 |
+
= \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
|
148 |
+
|
149 |
+
.. note::
|
150 |
+
As all the other losses in PyTorch, this function expects the first argument,
|
151 |
+
:attr:`_input`, to be the predictions, the output of the student model, in log-space
|
152 |
+
and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space.
|
153 |
+
This differs from the standard mathematical notation :math:`JSD(P || Q)` where
|
154 |
+
:math:`P` denotes the teacher model and :math:`Q` denotes the student model.
|
155 |
+
"""
|
156 |
+
|
157 |
+
@staticmethod
|
158 |
+
@ensure_contiguous
|
159 |
+
def forward(
|
160 |
+
ctx,
|
161 |
+
_input: torch.Tensor,
|
162 |
+
target: torch.Tensor,
|
163 |
+
shift_labels: Optional[torch.Tensor] = None,
|
164 |
+
beta: float = 0.5,
|
165 |
+
ignore_index: int = -100,
|
166 |
+
) -> torch.Tensor:
|
167 |
+
"""
|
168 |
+
Args:
|
169 |
+
_input (torch.Tensor): predict values with shape (BT, V) in logspace
|
170 |
+
target (torch.Tensor): ground truth values with shape (BT, V) in logspace
|
171 |
+
shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
|
172 |
+
beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
|
173 |
+
ignore_index (int): the index to ignore. Default: -100
|
174 |
+
|
175 |
+
Returns:
|
176 |
+
loss (torch.Tensor): generalized JSD
|
177 |
+
"""
|
178 |
+
has_label = False
|
179 |
+
if shift_labels is not None:
|
180 |
+
assert shift_labels.shape == (_input.shape[0],), (
|
181 |
+
f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
|
182 |
+
)
|
183 |
+
shift_labels = shift_labels.contiguous()
|
184 |
+
has_label = True
|
185 |
+
|
186 |
+
loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
|
187 |
+
ctx.save_for_backward(dX)
|
188 |
+
return loss
|
189 |
+
|
190 |
+
@staticmethod
|
191 |
+
@ensure_contiguous
|
192 |
+
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
|
193 |
+
(dX,) = ctx.saved_tensors
|
194 |
+
dX = jsd_backward(dX, grad_output)
|
195 |
+
return (
|
196 |
+
dX,
|
197 |
+
None,
|
198 |
+
None,
|
199 |
+
None,
|
200 |
+
None,
|
201 |
+
)
|
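A hypothetical call into the generalized JSD Function above, in a distillation-style setting. The tensors, device, and `liger_kernels` import path are assumptions made for illustration only.

```python
# Illustrative sketch only: generalized JSD between student and teacher log-probabilities.
import torch
from liger_kernels import LigerJSDFunction  # assumed import path

BT, V = 4, 1024
student_logits = torch.randn(BT, V, device="cuda", requires_grad=True)
teacher_logits = torch.randn(BT, V, device="cuda")

student_logp = torch.log_softmax(student_logits, dim=-1)  # _input: log Q
teacher_logp = torch.log_softmax(teacher_logits, dim=-1)  # target: log P

# beta=0.5 gives the symmetric JSD; beta=0.0 / 1.0 fall back to forward / reverse KL.
loss = LigerJSDFunction.apply(student_logp, teacher_logp, None, 0.5, -100)
loss.backward()
```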
build/torch-universal/liger_kernels/kl_div.py
ADDED
@@ -0,0 +1,262 @@
1 |
+
from typing import Literal
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import triton
|
5 |
+
import triton.language as tl
|
6 |
+
|
7 |
+
from utils import ensure_contiguous
|
8 |
+
from utils import is_hip
|
9 |
+
from utils import infer_device
|
10 |
+
|
11 |
+
|
12 |
+
def get_num_warps(BLOCK_SIZE):
|
13 |
+
num_warps = 4
|
14 |
+
if BLOCK_SIZE >= 32768:
|
15 |
+
num_warps = 32 if not is_hip() else 16
|
16 |
+
elif BLOCK_SIZE >= 8192:
|
17 |
+
num_warps = 16
|
18 |
+
elif BLOCK_SIZE >= 2048:
|
19 |
+
num_warps = 8
|
20 |
+
|
21 |
+
return num_warps
|
22 |
+
|
23 |
+
|
24 |
+
MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
|
25 |
+
|
26 |
+
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
|
27 |
+
|
28 |
+
_REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0)
|
29 |
+
_REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1)
|
30 |
+
_REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2)
|
31 |
+
_REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3)
|
32 |
+
|
33 |
+
_str_to_reduction_mode = {
|
34 |
+
"none": _REDUCTION_MODE_NONE.value,
|
35 |
+
"sum": _REDUCTION_MODE_SUM.value,
|
36 |
+
"mean": _REDUCTION_MODE_MEAN.value,
|
37 |
+
"batchmean": _REDUCTION_MODE_BATCHMEAN.value,
|
38 |
+
}
|
39 |
+
|
40 |
+
|
41 |
+
@triton.jit
|
42 |
+
def _kldiv_kernel_forward(
|
43 |
+
y_ptr, # [B, S], prediction ptr, the kernel expects the prediction in log-space
|
44 |
+
y_stride, # int, prediction stride
|
45 |
+
gt_ptr, # [B, S], ground truth ptr
|
46 |
+
gt_stride, # int, ground truth stride
|
47 |
+
loss_ptr, # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr
|
48 |
+
loss_stride, # int, output stride
|
49 |
+
n_cols, # int, number of columns in the input tensor
|
50 |
+
eps,
|
51 |
+
BLOCK_SIZE: tl.constexpr,
|
52 |
+
log_target: tl.constexpr = False,
|
53 |
+
reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
|
54 |
+
):
|
55 |
+
pid = tl.program_id(0).to(tl.int64)
|
56 |
+
y_ptr += pid * y_stride
|
57 |
+
gt_ptr += pid * gt_stride
|
58 |
+
loss_ptr += pid * loss_stride
|
59 |
+
|
60 |
+
base_offsets = tl.arange(0, BLOCK_SIZE)
|
61 |
+
|
62 |
+
loss_sum = 0.0
|
63 |
+
for i in range(0, n_cols, BLOCK_SIZE):
|
64 |
+
offsets = i + base_offsets
|
65 |
+
mask = offsets < n_cols
|
66 |
+
y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
|
67 |
+
y_true = tl.load(gt_ptr + offsets, mask=mask, other=0.0)
|
68 |
+
|
69 |
+
# KL(y_true || y) = y_true * (log(y_true) - log(y))
|
70 |
+
# We compute KL(y_true || y) with y in the log-space
|
71 |
+
if not log_target:
|
72 |
+
loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y)
|
73 |
+
else:
|
74 |
+
loss = tl.exp(y_true) * (y_true - y)
|
75 |
+
|
76 |
+
if reduction == _REDUCTION_MODE_NONE:
|
77 |
+
tl.store(loss_ptr + offsets, loss, mask=mask)
|
78 |
+
else:
|
79 |
+
loss_sum += tl.sum(loss, axis=0)
|
80 |
+
|
81 |
+
if reduction != _REDUCTION_MODE_NONE:
|
82 |
+
tl.store(loss_ptr, loss_sum)
|
83 |
+
|
84 |
+
|
85 |
+
@triton.jit
|
86 |
+
def _kldiv_kernel_backward(
|
87 |
+
target_ptr,
|
88 |
+
target_stride,
|
89 |
+
new_grads_ptr,
|
90 |
+
new_grads_stride,
|
91 |
+
n_cols,
|
92 |
+
BLOCK_SIZE: tl.constexpr,
|
93 |
+
log_target: tl.constexpr = False,
|
94 |
+
):
|
95 |
+
pid = tl.program_id(0).to(tl.int64)
|
96 |
+
|
97 |
+
target_ptr += pid * target_stride
|
98 |
+
new_grads_ptr += pid * new_grads_stride
|
99 |
+
|
100 |
+
offsets = tl.arange(0, BLOCK_SIZE)
|
101 |
+
mask = offsets < n_cols
|
102 |
+
|
103 |
+
for i in range(0, n_cols, BLOCK_SIZE):
|
104 |
+
offsets = i + tl.arange(0, BLOCK_SIZE)
|
105 |
+
mask = offsets < n_cols
|
106 |
+
|
107 |
+
target = tl.load(target_ptr + offsets, mask=mask, other=0.0)
|
108 |
+
|
109 |
+
if not log_target:
|
110 |
+
res = target * -1
|
111 |
+
else:
|
112 |
+
res = -tl.exp(target)
|
113 |
+
|
114 |
+
tl.store(new_grads_ptr + offsets, res, mask=mask)
|
115 |
+
|
116 |
+
|
117 |
+
def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
|
118 |
+
BT, V = y_pred.shape
|
119 |
+
BLOCK_SIZE = (
|
120 |
+
min(8192, triton.next_power_of_2(V))
|
121 |
+
if infer_device() == "xpu"
|
122 |
+
else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
123 |
+
)
|
124 |
+
num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
|
125 |
+
|
126 |
+
grid = (BT,)
|
127 |
+
reduction = _str_to_reduction_mode[reduction]
|
128 |
+
|
129 |
+
out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
|
130 |
+
output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32)
|
131 |
+
|
132 |
+
_kldiv_kernel_forward[grid](
|
133 |
+
y_pred,
|
134 |
+
y_pred.stride(0),
|
135 |
+
y_true,
|
136 |
+
y_true.stride(0),
|
137 |
+
output_tensor,
|
138 |
+
output_tensor.stride(0),
|
139 |
+
V,
|
140 |
+
eps=eps,
|
141 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
142 |
+
num_warps=num_warps,
|
143 |
+
log_target=log_target,
|
144 |
+
reduction=reduction,
|
145 |
+
)
|
146 |
+
|
147 |
+
# calculated according to the reduction mode same as in Pytorch. In the later versions, `mean` will be changed to the same behavior as `batchmean`
|
148 |
+
# https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
|
149 |
+
# https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372
|
150 |
+
if reduction == _REDUCTION_MODE_BATCHMEAN.value:
|
151 |
+
return output_tensor.sum() / BT
|
152 |
+
elif reduction == _REDUCTION_MODE_SUM.value:
|
153 |
+
return output_tensor.sum(dim=0)
|
154 |
+
elif reduction == _REDUCTION_MODE_MEAN.value:
|
155 |
+
return output_tensor.sum() / (BT * V)
|
156 |
+
else:
|
157 |
+
return output_tensor
|
158 |
+
|
159 |
+
|
160 |
+
def kldiv_backward_triton(target, grad_output, new_grads, log_target):
|
161 |
+
BT, V = target.shape
|
162 |
+
BLOCK_SIZE = (
|
163 |
+
min(8192, triton.next_power_of_2(V))
|
164 |
+
if infer_device() == "xpu"
|
165 |
+
else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
166 |
+
)
|
167 |
+
num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
|
168 |
+
|
169 |
+
grid = (BT,)
|
170 |
+
|
171 |
+
# We store the gradients in-place in the input tensor
|
172 |
+
_kldiv_kernel_backward[grid](
|
173 |
+
target,
|
174 |
+
target.stride(0),
|
175 |
+
new_grads,
|
176 |
+
new_grads.stride(0),
|
177 |
+
V,
|
178 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
179 |
+
num_warps=num_warps,
|
180 |
+
log_target=log_target,
|
181 |
+
)
|
182 |
+
|
183 |
+
# If the KL divergence loss is the last layer, grad_output is 1.0. Skip the mul then.
|
184 |
+
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
|
185 |
+
return new_grads
|
186 |
+
|
187 |
+
return new_grads * grad_output
|
188 |
+
|
189 |
+
|
190 |
+
class LigerKLDivLossFunction(torch.autograd.Function):
|
191 |
+
"""
|
192 |
+
Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:
|
193 |
+
```python
|
194 |
+
if log_target:
|
195 |
+
loss = target.exp() * (target - input)
|
196 |
+
else:
|
197 |
+
loss = target * (target.log() - input)
|
198 |
+
```
|
199 |
+
The loss is then reduced according to the `reduction` parameter,
|
200 |
+
as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
|
201 |
+
"""
|
202 |
+
|
203 |
+
@staticmethod
|
204 |
+
@ensure_contiguous
|
205 |
+
def forward(
|
206 |
+
ctx,
|
207 |
+
y_pred: torch.Tensor,
|
208 |
+
y_true: torch.Tensor,
|
209 |
+
reduction: REDUCTION_LITERAL = "batchmean",
|
210 |
+
log_target: bool = False,
|
211 |
+
eps: float = 1e-10,
|
212 |
+
) -> torch.Tensor:
|
213 |
+
"""A forward pass for the KL Divergence Loss.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
ctx: Torch autograd context
|
217 |
+
y_pred (torch.Tensor): A tensor of shape (BT, V) containing the predicted values, expected to be log-probabilities.
|
218 |
+
y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`.
|
219 |
+
reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean".
|
220 |
+
log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False.
|
221 |
+
eps: (float, optional): A small value to avoid division by zero. Defaults to 1e-10.
|
222 |
+
|
223 |
+
Returns:
|
224 |
+
torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar.
|
225 |
+
"""
|
226 |
+
ctx.save_for_backward(y_true)
|
227 |
+
ctx.reduction = reduction
|
228 |
+
ctx.log_target = log_target
|
229 |
+
return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)
|
230 |
+
|
231 |
+
@staticmethod
|
232 |
+
@ensure_contiguous
|
233 |
+
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
|
234 |
+
"""A backward pass for the KL Divergence Loss.
|
235 |
+
|
236 |
+
Args:
|
237 |
+
ctx: Torch autograd context
|
238 |
+
grad_output (torch.Tensor): The gradient of the loss with respect to the output.
|
239 |
+
|
240 |
+
Returns:
|
241 |
+
tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
|
242 |
+
"""
|
243 |
+
(y_true,) = ctx.saved_tensors
|
244 |
+
|
245 |
+
new_grads = torch.empty_like(y_true)
|
246 |
+
|
247 |
+
derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target)
|
248 |
+
|
249 |
+
if ctx.reduction == "batchmean":
|
250 |
+
derivative = derivative / y_true.shape[0]
|
251 |
+
elif ctx.reduction == "sum" or ctx.reduction == "none":
|
252 |
+
pass
|
253 |
+
elif ctx.reduction == "mean":
|
254 |
+
derivative = derivative / (y_true.shape[0] * y_true.shape[1])
|
255 |
+
|
256 |
+
return (
|
257 |
+
derivative,
|
258 |
+
None,
|
259 |
+
None,
|
260 |
+
None,
|
261 |
+
None,
|
262 |
+
)
|
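A short usage sketch for the KL divergence Function above, with log-space predictions and probability-space targets. Shapes, device, and the `liger_kernels` import path are illustrative assumptions.

```python
# Illustrative sketch only: KLDiv loss with batchmean reduction.
import torch
from liger_kernels import LigerKLDivLossFunction  # assumed import path

BT, V = 8, 512
y_pred = torch.log_softmax(torch.randn(BT, V, device="cuda", requires_grad=True), dim=-1)
y_true = torch.softmax(torch.randn(BT, V, device="cuda"), dim=-1)  # probabilities, so log_target=False

loss = LigerKLDivLossFunction.apply(y_pred, y_true, "batchmean", False, 1e-10)
loss.backward()
```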
build/torch-universal/liger_kernels/layer_norm.py
ADDED
@@ -0,0 +1,265 @@
1 |
+
import math
|
2 |
+
import operator
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import triton
|
6 |
+
import triton.language as tl
|
7 |
+
|
8 |
+
from utils import calculate_settings
|
9 |
+
from utils import compare_version
|
10 |
+
from utils import ensure_contiguous
|
11 |
+
|
12 |
+
if compare_version("triton", operator.ge, "3.0.0"):
|
13 |
+
try:
|
14 |
+
# typical import path with dispatch available
|
15 |
+
from triton.language.extra.libdevice import rsqrt
|
16 |
+
except ModuleNotFoundError:
|
17 |
+
# for working with NGC containers
|
18 |
+
from triton.language.extra.cuda.libdevice import rsqrt
|
19 |
+
else:
|
20 |
+
from triton.language.math import rsqrt
|
21 |
+
|
22 |
+
|
23 |
+
@triton.jit
|
24 |
+
def _layer_norm_forward_kernel(
|
25 |
+
Y_ptr, # pointer to output, shape (n_rows, n_cols)
|
26 |
+
Y_row_stride, # stride of each row in output
|
27 |
+
X_ptr, # pointer to input, shape (n_rows, n_cols)
|
28 |
+
X_row_stride, # stride of each row in input
|
29 |
+
W_ptr, # pointer to weights, shape (n_cols,)
|
30 |
+
W_row_stride, # stride of each row in weights
|
31 |
+
B_ptr, # pointer to bias, shape (n_cols,)
|
32 |
+
B_row_stride, # stride of each row in bias
|
33 |
+
Mean_ptr, # pointer to mean, shape (n_rows,)
|
34 |
+
Mean_row_stride, # stride of each row in mean
|
35 |
+
RSTD_ptr, # pointer to rstd, shape (n_rows,)
|
36 |
+
RSTD_row_stride, # stride of each row in rstd
|
37 |
+
n_cols,
|
38 |
+
eps,
|
39 |
+
BLOCK_SIZE: tl.constexpr,
|
40 |
+
):
|
41 |
+
"""
|
42 |
+
References:
|
43 |
+
https://arxiv.org/abs/1607.06450
|
44 |
+
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
45 |
+
"""
|
46 |
+
row_idx = tl.program_id(0)
|
47 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
48 |
+
mask = col_offsets < n_cols
|
49 |
+
|
50 |
+
Y_ptr += row_idx * Y_row_stride
|
51 |
+
X_ptr += row_idx * X_row_stride
|
52 |
+
Mean_ptr += row_idx * Mean_row_stride
|
53 |
+
RSTD_ptr += row_idx * RSTD_row_stride
|
54 |
+
|
55 |
+
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
|
56 |
+
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
|
57 |
+
B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
|
58 |
+
|
59 |
+
mean = tl.sum(X_row, axis=0) / n_cols
|
60 |
+
Xmm = tl.where(mask, X_row - mean, 0)
|
61 |
+
var = tl.sum(Xmm * Xmm, axis=0) / n_cols
|
62 |
+
rstd = rsqrt(var + eps)
|
63 |
+
|
64 |
+
tl.store(Mean_ptr, mean)
|
65 |
+
tl.store(RSTD_ptr, rstd)
|
66 |
+
|
67 |
+
Y_row = Xmm * rstd * W_row + B_row
|
68 |
+
|
69 |
+
tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
|
70 |
+
|
71 |
+
|
72 |
+
@triton.jit
|
73 |
+
def _layer_norm_backward_kernel(
|
74 |
+
X_ptr, # pointer to input, shape (n_rows, n_cols)
|
75 |
+
W_ptr, # pointer to weights, shape (n_cols,)
|
76 |
+
Mean_ptr, # pointer to mean, shape (n_rows,)
|
77 |
+
RSTD_ptr, # pointer to rstd, shape (n_rows,)
|
78 |
+
DX_ptr, # pointer to input grad, shape (n_rows, n_cols)
|
79 |
+
DW_ptr, # pointer to weights grad, shape (n_cols,)
|
80 |
+
DB_ptr, # pointer to bias grad, shape (n_cols,)
|
81 |
+
DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
|
82 |
+
stride_x, # stride of each row in input
|
83 |
+
stride_dx, # stride of each row in input grad
|
84 |
+
stride_dw, # stride of each row in weights grad
|
85 |
+
stride_db, # stride of each row in bias grad
|
86 |
+
stride_dy, # stride of each row in output grad
|
87 |
+
n_rows,
|
88 |
+
n_cols,
|
89 |
+
rows_per_program: tl.constexpr,
|
90 |
+
BLOCK_SIZE: tl.constexpr,
|
91 |
+
dtype: tl.constexpr,
|
92 |
+
):
|
93 |
+
"""
|
94 |
+
References:
|
95 |
+
https://arxiv.org/abs/1607.06450
|
96 |
+
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
97 |
+
https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
98 |
+
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
|
99 |
+
"""
|
100 |
+
row_block_id = tl.program_id(0)
|
101 |
+
row_start = row_block_id * rows_per_program
|
102 |
+
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
103 |
+
cols = tl.arange(0, BLOCK_SIZE)
|
104 |
+
mask = cols < n_cols
|
105 |
+
|
106 |
+
dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
107 |
+
db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
108 |
+
|
109 |
+
X_ptr += row_start * stride_x
|
110 |
+
Mean_ptr += row_start
|
111 |
+
RSTD_ptr += row_start
|
112 |
+
DX_ptr += row_start * stride_dx
|
113 |
+
DY_ptr += row_start * stride_dy
|
114 |
+
|
115 |
+
for _ in range(row_start, row_end):
|
116 |
+
x = tl.load(X_ptr + cols, mask=mask, other=0.0)
|
117 |
+
w = tl.load(W_ptr + cols, mask=mask, other=0.0)
|
118 |
+
dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)
|
119 |
+
mean = tl.load(Mean_ptr)
|
120 |
+
rstd = tl.load(RSTD_ptr)
|
121 |
+
|
122 |
+
x_hat = (x - mean) * rstd
|
123 |
+
wdy = w * dy
|
124 |
+
c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
|
125 |
+
c2 = tl.sum(wdy, axis=0) / n_cols
|
126 |
+
dx = (wdy - (x_hat * c1 + c2)) * rstd
|
127 |
+
tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)
|
128 |
+
|
129 |
+
dw_row += dy * x_hat
|
130 |
+
db_row += dy
|
131 |
+
|
132 |
+
X_ptr += stride_x
|
133 |
+
Mean_ptr += 1
|
134 |
+
RSTD_ptr += 1
|
135 |
+
DX_ptr += stride_dx
|
136 |
+
DY_ptr += stride_dy
|
137 |
+
|
138 |
+
tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)
|
139 |
+
tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)
|
140 |
+
|
141 |
+
|
142 |
+
def layer_norm_forward(X, W, B, eps):
|
143 |
+
shape = X.shape
|
144 |
+
dim = shape[-1]
|
145 |
+
X = X.view(-1, dim)
|
146 |
+
n_rows, n_cols = X.shape
|
147 |
+
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
148 |
+
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
149 |
+
Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
150 |
+
RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
151 |
+
if X.shape[1] != W.shape[0]:
|
152 |
+
raise ValueError(
|
153 |
+
f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
|
154 |
+
f"must match weight size (W.shape[0]={W.shape[0]})"
|
155 |
+
)
|
156 |
+
|
157 |
+
# XPU-specific optimization
|
158 |
+
kernel_args = {}
|
159 |
+
if X.device.type == "xpu":
|
160 |
+
kernel_args["grf_mode"] = "large"
|
161 |
+
|
162 |
+
_layer_norm_forward_kernel[(n_rows,)](
|
163 |
+
Y,
|
164 |
+
Y.stride(0),
|
165 |
+
X,
|
166 |
+
X.stride(0),
|
167 |
+
W,
|
168 |
+
W.stride(0),
|
169 |
+
B,
|
170 |
+
B.stride(0),
|
171 |
+
Mean,
|
172 |
+
Mean.stride(0),
|
173 |
+
RSTD,
|
174 |
+
RSTD.stride(0),
|
175 |
+
n_cols,
|
176 |
+
eps,
|
177 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
178 |
+
num_warps=num_warps,
|
179 |
+
**kernel_args, # XPU-specific optimization
|
180 |
+
)
|
181 |
+
return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
|
182 |
+
|
183 |
+
|
184 |
+
def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
185 |
+
shape = dY.shape
|
186 |
+
dim = shape[-1]
|
187 |
+
dY = dY.view(-1, dim)
|
188 |
+
n_rows, n_cols = dY.shape
|
189 |
+
|
190 |
+
sm_count = 1
|
191 |
+
if X.device.type == "cuda":
|
192 |
+
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
193 |
+
elif X.device.type == "xpu":
|
194 |
+
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
|
195 |
+
|
196 |
+
DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
197 |
+
_DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
|
198 |
+
_DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
|
199 |
+
|
200 |
+
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
201 |
+
if n_cols > BLOCK_SIZE:
|
202 |
+
raise RuntimeError(
|
203 |
+
f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
|
204 |
+
)
|
205 |
+
|
206 |
+
rows_per_program = math.ceil(n_rows / sm_count)
|
207 |
+
grid = (sm_count,)
|
208 |
+
triton_dtype = (
|
209 |
+
tl.float32
|
210 |
+
if X.dtype == torch.float32
|
211 |
+
else tl.bfloat16
|
212 |
+
if X.dtype == torch.bfloat16
|
213 |
+
else tl.float16
|
214 |
+
if X.dtype == torch.float16
|
215 |
+
else tl.float32 # fallback to float32 for other types
|
216 |
+
)
|
217 |
+
|
218 |
+
# XPU-specific optimization
|
219 |
+
kernel_args = {}
|
220 |
+
if X.device.type == "xpu":
|
221 |
+
kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
|
222 |
+
|
223 |
+
_layer_norm_backward_kernel[grid](
|
224 |
+
X,
|
225 |
+
W,
|
226 |
+
Mean,
|
227 |
+
RSTD,
|
228 |
+
DX,
|
229 |
+
_DW,
|
230 |
+
_DB,
|
231 |
+
dY,
|
232 |
+
X.stride(0),
|
233 |
+
DX.stride(0),
|
234 |
+
_DW.stride(0),
|
235 |
+
_DB.stride(0),
|
236 |
+
dY.stride(0),
|
237 |
+
n_rows,
|
238 |
+
n_cols,
|
239 |
+
rows_per_program,
|
240 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
241 |
+
dtype=triton_dtype,
|
242 |
+
**kernel_args, # XPU-specific optimization
|
243 |
+
)
|
244 |
+
|
245 |
+
DW = _DW.sum(dim=0).to(W.dtype)
|
246 |
+
DB = _DB.sum(dim=0).to(W.dtype)
|
247 |
+
|
248 |
+
DX = DX.view(*shape)
|
249 |
+
return DX, DW, DB
|
250 |
+
|
251 |
+
|
252 |
+
class LigerLayerNormFunction(torch.autograd.Function):
|
253 |
+
@staticmethod
|
254 |
+
@ensure_contiguous
|
255 |
+
def forward(ctx, X, W, B, eps):
|
256 |
+
Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps)
|
257 |
+
ctx.save_for_backward(X, W, B, Mean, RSTD)
|
258 |
+
return Y
|
259 |
+
|
260 |
+
@staticmethod
|
261 |
+
@ensure_contiguous
|
262 |
+
def backward(ctx, dY):
|
263 |
+
X, W, B, Mean, RSTD = ctx.saved_tensors
|
264 |
+
DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD)
|
265 |
+
return DX, DW, DB, None
|
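A minimal sketch of calling the layer-norm Function above on a (batch, seq, hidden) activation; normalization runs over the last dimension. Device, shapes, and the `liger_kernels` import path are assumptions for illustration.

```python
# Illustrative sketch only: layer norm over the last dimension.
import torch
from liger_kernels import LigerLayerNormFunction  # assumed import path

hidden = 2048
X = torch.randn(2, 16, hidden, device="cuda", requires_grad=True)
W = torch.ones(hidden, device="cuda", requires_grad=True)
B = torch.zeros(hidden, device="cuda", requires_grad=True)

Y = LigerLayerNormFunction.apply(X, W, B, 1e-5)
Y.mean().backward()  # produces DX, DW, DB via the Triton backward kernel
```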
build/torch-universal/liger_kernels/qwen2vl_mrope.py
ADDED
@@ -0,0 +1,222 @@
1 |
+
import torch
|
2 |
+
import triton
|
3 |
+
import triton.language as tl
|
4 |
+
|
5 |
+
|
6 |
+
@triton.jit
|
7 |
+
def _triton_qwen2vl_mrope(
|
8 |
+
q_ptr,
|
9 |
+
k_ptr,
|
10 |
+
cos,
|
11 |
+
sin,
|
12 |
+
sl,
|
13 |
+
bs: tl.constexpr,
|
14 |
+
n_qh: tl.constexpr,
|
15 |
+
n_kh: tl.constexpr,
|
16 |
+
hd: tl.constexpr,
|
17 |
+
pad_n_qh: tl.constexpr,
|
18 |
+
pad_n_kh: tl.constexpr,
|
19 |
+
pad_hd: tl.constexpr,
|
20 |
+
mrope_section_t: tl.constexpr,
|
21 |
+
mrope_section_h: tl.constexpr,
|
22 |
+
BLOCK_SIZE: tl.constexpr,
|
23 |
+
BACKWARD_PASS: tl.constexpr = False,
|
24 |
+
):
|
25 |
+
pid = tl.program_id(0)
|
26 |
+
|
27 |
+
# locate start address
|
28 |
+
q_ptr = q_ptr + pid * (n_qh * hd)
|
29 |
+
k_ptr = k_ptr + pid * (n_kh * hd)
|
30 |
+
|
31 |
+
# ####################################################################
|
32 |
+
# get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
|
33 |
+
# m of this program instance
|
34 |
+
# ####################################################################
|
35 |
+
|
36 |
+
# 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
|
37 |
+
# effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
|
38 |
+
# being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index
|
39 |
+
# and pid % sl to get the sequence index.
|
40 |
+
# 2. We only need the left half of cos and sin matrix because the right half is just
|
41 |
+
# a clone of the left half.
|
42 |
+
t_end = mrope_section_t
|
43 |
+
h_end = t_end + mrope_section_h
|
44 |
+
|
45 |
+
t_cos = cos + pid * hd
|
46 |
+
h_cos = t_cos + bs * sl * hd
|
47 |
+
w_cos = h_cos + bs * sl * hd
|
48 |
+
t_sin = sin + pid * hd
|
49 |
+
h_sin = t_sin + bs * sl * hd
|
50 |
+
w_sin = h_sin + bs * sl * hd
|
51 |
+
|
52 |
+
cos_offsets = tl.arange(0, pad_hd // 2)
|
53 |
+
t_mask = cos_offsets < t_end
|
54 |
+
h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
|
55 |
+
w_mask = (h_end <= cos_offsets) & (cos_offsets < hd // 2)
|
56 |
+
t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
|
57 |
+
h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
|
58 |
+
w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
|
59 |
+
t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
|
60 |
+
h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
|
61 |
+
w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)
|
62 |
+
cos_row = t_cos_row + h_cos_row + w_cos_row
|
63 |
+
sin_row = t_sin_row + h_sin_row + w_sin_row
|
64 |
+
|
65 |
+
# ####################################################################
|
66 |
+
# Load the left and right half of q and k for the current
|
67 |
+
# program instance (i.e. for the current token) separately
|
68 |
+
# ####################################################################
|
69 |
+
# left half of the head
|
70 |
+
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
|
71 |
+
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
|
72 |
+
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
|
73 |
+
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
|
74 |
+
q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
|
75 |
+
k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
|
76 |
+
|
77 |
+
# right half of the head
|
78 |
+
second_half_q_offsets = first_half_q_offsets + (hd // 2)
|
79 |
+
second_half_k_offsets = first_half_k_offsets + (hd // 2)
|
80 |
+
second_q_mask = first_q_mask
|
81 |
+
second_k_mask = first_k_mask
|
82 |
+
q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
|
83 |
+
k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
|
84 |
+
|
85 |
+
if not BACKWARD_PASS:
|
86 |
+
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
|
87 |
+
new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
|
88 |
+
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
|
89 |
+
new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
|
90 |
+
tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
|
91 |
+
|
92 |
+
new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
|
93 |
+
tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
|
94 |
+
new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
|
95 |
+
tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
|
96 |
+
else:
|
97 |
+
# with some math, we can get:
|
98 |
+
# dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
|
99 |
+
new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
|
100 |
+
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
|
101 |
+
new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
|
102 |
+
tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
|
103 |
+
|
104 |
+
new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
|
105 |
+
tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
|
106 |
+
new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
|
107 |
+
tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
|
108 |
+
|
109 |
+
|
110 |
+
def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
|
111 |
+
# transpose it back to the physical shape because Triton looks at the physical storage
|
112 |
+
# note: q and k are non-contiguous before the transformation and will become contiguous after transpose
|
113 |
+
q = q.transpose(1, 2)
|
114 |
+
k = k.transpose(1, 2)
|
115 |
+
|
116 |
+
batch_size, seq_len, n_q_head, head_dim = q.shape
|
117 |
+
n_kv_head = k.shape[2]
|
118 |
+
pad_hd = triton.next_power_of_2(head_dim)
|
119 |
+
pad_n_q_head = triton.next_power_of_2(n_q_head)
|
120 |
+
pad_n_kv_head = triton.next_power_of_2(n_kv_head)
|
121 |
+
BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
|
122 |
+
|
123 |
+
n_row = batch_size * seq_len
|
124 |
+
|
125 |
+
# ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
|
126 |
+
q = q.contiguous()
|
127 |
+
k = k.contiguous()
|
128 |
+
cos = cos.contiguous()
|
129 |
+
sin = sin.contiguous()
|
130 |
+
|
131 |
+
_triton_qwen2vl_mrope[(n_row,)](
|
132 |
+
q,
|
133 |
+
k,
|
134 |
+
cos,
|
135 |
+
sin,
|
136 |
+
seq_len,
|
137 |
+
batch_size,
|
138 |
+
n_q_head,
|
139 |
+
n_kv_head,
|
140 |
+
head_dim,
|
141 |
+
pad_n_q_head,
|
142 |
+
pad_n_kv_head,
|
143 |
+
pad_hd,
|
144 |
+
mrope_section[0],
|
145 |
+
mrope_section[1],
|
146 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
147 |
+
BACKWARD_PASS=False,
|
148 |
+
)
|
149 |
+
return q.transpose(1, 2), k.transpose(1, 2), cos, sin
|
150 |
+
|
151 |
+
|
152 |
+
def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
|
153 |
+
dq = dq.transpose(1, 2)
|
154 |
+
dk = dk.transpose(1, 2)
|
155 |
+
|
156 |
+
batch_size, seq_len, n_q_head, head_dim = dq.shape
|
157 |
+
n_kv_head = dk.shape[2]
|
158 |
+
pad_hd = triton.next_power_of_2(head_dim)
|
159 |
+
pad_n_q_head = triton.next_power_of_2(n_q_head)
|
160 |
+
pad_n_kv_head = triton.next_power_of_2(n_kv_head)
|
161 |
+
BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
|
162 |
+
|
163 |
+
n_row = batch_size * seq_len
|
164 |
+
|
165 |
+
# ensure dq and dk are contiguous
|
166 |
+
dq = dq.contiguous()
|
167 |
+
dk = dk.contiguous()
|
168 |
+
|
169 |
+
# backward is similar to forward except swapping few ops
|
170 |
+
_triton_qwen2vl_mrope[(n_row,)](
|
171 |
+
dq,
|
172 |
+
dk,
|
173 |
+
cos,
|
174 |
+
sin,
|
175 |
+
seq_len,
|
176 |
+
batch_size,
|
177 |
+
n_q_head,
|
178 |
+
n_kv_head,
|
179 |
+
head_dim,
|
180 |
+
pad_n_q_head,
|
181 |
+
pad_n_kv_head,
|
182 |
+
pad_hd,
|
183 |
+
mrope_section[0],
|
184 |
+
mrope_section[1],
|
185 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
186 |
+
BACKWARD_PASS=True,
|
187 |
+
)
|
188 |
+
return dq.transpose(1, 2), dk.transpose(1, 2)
|
189 |
+
|
190 |
+
|
191 |
+
class LigerQwen2VLMRopeFunction(torch.autograd.Function):
|
192 |
+
"""
|
193 |
+
Triton implementation of the Qwen2VL Multimodal Rotary Positional Embedding (M-RoPE) operation.
|
194 |
+
|
195 |
+
Please find the corresponding HuggingFace implementation here:
|
196 |
+
https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
|
197 |
+
"""
|
198 |
+
|
199 |
+
@staticmethod
|
200 |
+
def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
201 |
+
"""
|
202 |
+
q size: (bsz, n_q_head, seq_len, head_dim)
|
203 |
+
k size: (bsz, n_kv_head, seq_len, head_dim)
|
204 |
+
cos size: (3, bsz, seq_len, head_dim)
|
205 |
+
sin size: (3, bsz, seq_len, head_dim)
|
206 |
+
"""
|
207 |
+
q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
|
208 |
+
ctx.save_for_backward(cos, sin)
|
209 |
+
ctx.mrope_section = mrope_section
|
210 |
+
return q, k
|
211 |
+
|
212 |
+
@staticmethod
def backward(ctx, dq, dk):
|
213 |
+
"""
|
214 |
+
dq size: (bsz, n_q_head, seq_len, head_dim)
|
215 |
+
dk size: (bsz, n_kv_head, seq_len, head_dim)
|
216 |
+
cos size: (3, bsz, seq_len, head_dim)
|
217 |
+
sin size: (3, bsz, seq_len, head_dim)
|
218 |
+
"""
|
219 |
+
cos, sin = ctx.saved_tensors
|
220 |
+
mrope_section = ctx.mrope_section
|
221 |
+
dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
|
222 |
+
return dq, dk, None, None, None, None
|
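A hypothetical invocation of the M-RoPE Function above. The head counts, sequence length, `mrope_section` split, device, and `liger_kernels` import path are illustrative assumptions; the only hard constraint from the kernel is that the section sizes cover half the head dimension.

```python
# Illustrative sketch only: M-RoPE applied to query/key states.
import torch
from liger_kernels import LigerQwen2VLMRopeFunction  # assumed import path

bsz, n_q_head, n_kv_head, seq_len, head_dim = 1, 8, 2, 16, 128
q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda", requires_grad=True)
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda", requires_grad=True)
cos = torch.randn(3, bsz, seq_len, head_dim, device="cuda")  # temporal / height / width tables
sin = torch.randn(3, bsz, seq_len, head_dim, device="cuda")
mrope_section = [16, 24, 24]  # assumed split; must sum to head_dim // 2

q_rot, k_rot = LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section)
(q_rot.sum() + k_rot.sum()).backward()
```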
build/torch-universal/liger_kernels/rms_norm.py
ADDED
@@ -0,0 +1,365 @@
1 |
+
"""
|
2 |
+
This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
|
3 |
+
See the original Unsloth repository at https://github.com/unslothai/unsloth.
|
4 |
+
|
5 |
+
The following line
|
6 |
+
https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/rms_norm.py#L30
|
7 |
+
is based on code from Unsloth, located at:
|
8 |
+
https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
|
9 |
+
|
10 |
+
Modifications made by Yanning Chen, 2024.
|
11 |
+
"""
|
12 |
+
|
13 |
+
import math
|
14 |
+
import operator
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import triton
|
18 |
+
import triton.language as tl
|
19 |
+
|
20 |
+
from utils import calculate_settings
|
21 |
+
from utils import compare_version
|
22 |
+
from utils import ensure_contiguous
|
23 |
+
from utils import torch_to_triton_dtype
|
24 |
+
|
25 |
+
if compare_version("triton", operator.ge, "3.0.0"):
|
26 |
+
try:
|
27 |
+
# typical import path with dispatch available
|
28 |
+
from triton.language.extra.libdevice import rsqrt
|
29 |
+
except ModuleNotFoundError:
|
30 |
+
# for working with NGC containers
|
31 |
+
from triton.language.extra.cuda.libdevice import rsqrt
|
32 |
+
else:
|
33 |
+
from triton.language.math import rsqrt
|
34 |
+
|
35 |
+
|
36 |
+
_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1)
|
37 |
+
_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
|
38 |
+
_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)
|
39 |
+
|
40 |
+
|
41 |
+
@triton.jit
|
42 |
+
def _rms_norm_forward_kernel(
|
43 |
+
Y_ptr,
|
44 |
+
Y_row_stride,
|
45 |
+
X_ptr,
|
46 |
+
X_row_stride,
|
47 |
+
W_ptr,
|
48 |
+
W_row_stride,
|
49 |
+
RSTD_ptr,
|
50 |
+
RSTD_row_stride,
|
51 |
+
n_cols,
|
52 |
+
eps,
|
53 |
+
offset,
|
54 |
+
casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
|
55 |
+
BLOCK_SIZE: tl.constexpr,
|
56 |
+
):
|
57 |
+
"""
|
58 |
+
y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)
|
59 |
+
|
60 |
+
Reference:
|
61 |
+
1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
62 |
+
2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
|
63 |
+
3. https://arxiv.org/pdf/1910.07467
|
64 |
+
"""
|
65 |
+
|
66 |
+
row_idx = tl.program_id(0)
|
67 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
68 |
+
mask = col_offsets < n_cols
|
69 |
+
|
70 |
+
Y_ptr += row_idx * Y_row_stride
|
71 |
+
X_ptr += row_idx * X_row_stride
|
72 |
+
RSTD_ptr += row_idx * RSTD_row_stride
|
73 |
+
|
74 |
+
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
|
75 |
+
X_row_dtype = X_row.dtype
|
76 |
+
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
|
77 |
+
|
78 |
+
# On Llama, only rstd is computed on fp32
|
79 |
+
if casting_mode == _CASTING_MODE_LLAMA:
|
80 |
+
X_row = X_row.to(tl.float32)
|
81 |
+
|
82 |
+
# Gemma computes everything on fp32, and then casts back the output to the original dtype
|
83 |
+
if casting_mode == _CASTING_MODE_GEMMA:
|
84 |
+
W_row = W_row.to(tl.float32)
|
85 |
+
X_row = X_row.to(tl.float32)
|
86 |
+
|
87 |
+
if casting_mode == _CASTING_MODE_NONE:
|
88 |
+
eps = eps.to(X_row_dtype)
|
89 |
+
offset = offset.to(X_row_dtype)
|
90 |
+
|
91 |
+
mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
|
92 |
+
rstd = rsqrt(mean_square + eps)
|
93 |
+
|
94 |
+
# We can save time by caching rms with minimal memory overhead
|
95 |
+
# because rms is much smaller compared to X_row, as rms is for each row.
|
96 |
+
# However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
|
97 |
+
tl.store(RSTD_ptr, rstd)
|
98 |
+
|
99 |
+
X_row = X_row * rstd
|
100 |
+
|
101 |
+
# On Llama, the multiplication with the weight is done on the original dtype
|
102 |
+
if casting_mode == _CASTING_MODE_LLAMA:
|
103 |
+
X_row = X_row.to(X_row_dtype)
|
104 |
+
|
105 |
+
Y_row = X_row * (offset + W_row)
|
106 |
+
|
107 |
+
if casting_mode == _CASTING_MODE_GEMMA:
|
108 |
+
Y_row = Y_row.to(X_row_dtype)
|
109 |
+
|
110 |
+
tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
|
111 |
+
|
112 |
+
|
113 |
+
@triton.jit
|
114 |
+
def _rms_norm_backward_kernel(
|
115 |
+
dY_ptr,
|
116 |
+
dY_row_stride,
|
117 |
+
dX_ptr,
|
118 |
+
dX_row_stride,
|
119 |
+
X_ptr,
|
120 |
+
X_row_stride,
|
121 |
+
X_dtype: tl.constexpr,
|
122 |
+
W_ptr,
|
123 |
+
W_row_stride,
|
124 |
+
RSTD_ptr,
|
125 |
+
RSTD_row_stride,
|
126 |
+
dW_ptr,
|
127 |
+
dW_row_stride,
|
128 |
+
n_rows,
|
129 |
+
n_cols,
|
130 |
+
offset,
|
131 |
+
rows_per_program: tl.constexpr,
|
132 |
+
casting_mode: tl.constexpr,
|
133 |
+
BLOCK_SIZE: tl.constexpr,
|
134 |
+
):
|
135 |
+
"""
|
136 |
+
dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
|
137 |
+
dw = sum(dy * (x / RMS)). summation over BxT dimension
|
138 |
+
"""
|
139 |
+
|
140 |
+
row_block_id = tl.program_id(0)
|
141 |
+
row_start = row_block_id * rows_per_program
|
142 |
+
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
143 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
144 |
+
mask = col_offsets < n_cols
|
145 |
+
|
146 |
+
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
147 |
+
|
148 |
+
dY_ptr += row_start * dY_row_stride
|
149 |
+
dX_ptr += row_start * dX_row_stride
|
150 |
+
|
151 |
+
X_ptr += row_start * X_row_stride
|
152 |
+
RSTD_ptr += row_start
|
153 |
+
|
154 |
+
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
|
155 |
+
W_row = W_row + offset
|
156 |
+
|
157 |
+
for _ in range(row_start, row_end):
|
158 |
+
dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
|
159 |
+
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
|
160 |
+
|
161 |
+
# Get cached rms
|
162 |
+
rstd_row = tl.load(RSTD_ptr)
|
163 |
+
|
164 |
+
X_row = X_row.to(tl.float32)
|
165 |
+
|
166 |
+
# Different bacward graphs for different casting modes
|
167 |
+
if casting_mode == _CASTING_MODE_LLAMA:
|
168 |
+
m = (dY_row * W_row).to(tl.float32)
|
169 |
+
|
170 |
+
elif casting_mode == _CASTING_MODE_GEMMA:
|
171 |
+
dY_row = dY_row.to(tl.float32)
|
172 |
+
m = dY_row * W_row
|
173 |
+
else:
|
174 |
+
m = dY_row * W_row
|
175 |
+
|
176 |
+
dX_row = rstd_row * m
|
177 |
+
|
178 |
+
dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
|
179 |
+
|
180 |
+
# calculate the gradient of W
|
181 |
+
if casting_mode == _CASTING_MODE_LLAMA:
|
182 |
+
dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
|
183 |
+
else:
|
184 |
+
# here X_row is already in fp32 (see previous if block)
|
185 |
+
dW_row += dY_row * (X_row * rstd_row)
|
186 |
+
|
187 |
+
tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
|
188 |
+
|
189 |
+
dY_ptr += dY_row_stride
|
190 |
+
dX_ptr += dX_row_stride
|
191 |
+
X_ptr += X_row_stride
|
192 |
+
RSTD_ptr += RSTD_row_stride
|
193 |
+
|
194 |
+
tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
|
195 |
+
|
196 |
+
|
197 |
+
_str_to_casting_mode = {
|
198 |
+
"llama": _CASTING_MODE_LLAMA.value,
|
199 |
+
"gemma": _CASTING_MODE_GEMMA.value,
|
200 |
+
"none": _CASTING_MODE_NONE.value,
|
201 |
+
}
|
202 |
+
|
203 |
+
|
204 |
+
def rms_norm_forward(X, W, eps, offset, casting_mode):
|
205 |
+
if not isinstance(casting_mode, int):
|
206 |
+
assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
|
207 |
+
casting_mode = _str_to_casting_mode[casting_mode]
|
208 |
+
else:
|
209 |
+
assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"
|
210 |
+
|
211 |
+
shape = X.shape
|
212 |
+
dim = shape[-1]
|
213 |
+
X = X.view(-1, dim)
|
214 |
+
n_rows, n_cols = X.shape
|
215 |
+
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
216 |
+
|
217 |
+
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
218 |
+
# RSTD is to cache rstd for each row
|
219 |
+
# RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
|
220 |
+
rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
|
221 |
+
RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
|
222 |
+
|
223 |
+
# Check constraints.
|
224 |
+
assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
|
225 |
+
|
226 |
+
# XPU-specific optimization
|
227 |
+
kernel_args = {}
|
228 |
+
if X.device.type == "xpu":
|
229 |
+
kernel_args["grf_mode"] = "large"
|
230 |
+
_rms_norm_forward_kernel[(n_rows,)](
|
231 |
+
Y,
|
232 |
+
Y.stride(0),
|
233 |
+
X,
|
234 |
+
X.stride(0),
|
235 |
+
W,
|
236 |
+
W.stride(0),
|
237 |
+
RSTD,
|
238 |
+
RSTD.stride(0),
|
239 |
+
n_cols,
|
240 |
+
eps,
|
241 |
+
offset,
|
242 |
+
casting_mode,
|
243 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
244 |
+
num_warps=num_warps,
|
245 |
+
**kernel_args, # XPU-specific optimization
|
246 |
+
)
|
247 |
+
return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
|
248 |
+
|
249 |
+
|
250 |
+
def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
|
251 |
+
shape = dY.shape
|
252 |
+
dim = shape[-1]
|
253 |
+
dY = dY.view(-1, dim)
|
254 |
+
n_rows, n_cols = dY.shape
|
255 |
+
|
256 |
+
sm_count = 1
|
257 |
+
if X.device.type == "cuda":
|
258 |
+
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
259 |
+
elif X.device.type == "xpu":
|
260 |
+
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
|
261 |
+
|
262 |
+
# fp32 for numerical stability especially.
|
263 |
+
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
264 |
+
|
265 |
+
if n_cols > BLOCK_SIZE:
|
266 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
267 |
+
rows_per_program = math.ceil(n_rows / sm_count)
|
268 |
+
grid = (sm_count,)
|
269 |
+
|
270 |
+
if in_place is True:
|
271 |
+
dX = dY
|
272 |
+
else:
|
273 |
+
dX = torch.zeros_like(dY)
|
274 |
+
|
275 |
+
# XPU-specific optimization
|
276 |
+
kernel_args = {}
|
277 |
+
if X.device.type == "xpu":
|
278 |
+
kernel_args["grf_mode"] = "large"
|
279 |
+
|
280 |
+
_rms_norm_backward_kernel[grid](
|
281 |
+
dY,
|
282 |
+
dY.stride(0),
|
283 |
+
dX,
|
284 |
+
dX.stride(0),
|
285 |
+
X,
|
286 |
+
X.stride(0),
|
287 |
+
torch_to_triton_dtype[X.dtype],
|
288 |
+
W,
|
289 |
+
W.stride(0),
|
290 |
+
RSTD,
|
291 |
+
RSTD.stride(0),
|
292 |
+
_dW,
|
293 |
+
_dW.stride(0),
|
294 |
+
n_rows,
|
295 |
+
n_cols,
|
296 |
+
offset,
|
297 |
+
rows_per_program,
|
298 |
+
casting_mode,
|
299 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
300 |
+
num_warps=num_warps,
|
301 |
+
**kernel_args, # XPU-specific optimization
|
302 |
+
)
|
303 |
+
dX = dX.view(*shape)
|
304 |
+
dW = _dW.sum(dim=0).to(W.dtype)
|
305 |
+
|
306 |
+
return dX, dW
|
307 |
+
|
308 |
+
|
309 |
+
class LigerRMSNormFunction(torch.autograd.Function):
    """
    Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the
    weight tensor `W`, with an optional offset and casting mode.

    Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
    uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
    `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.

    In addition, different models cast their inputs at different places during RMSNorm computation. For
    example, Gemma casts everything to fp32 before starting the computation, while Llama casts only the
    inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
    support the following casting modes (they match HuggingFace Transformers' implementations):
    - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
    - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
    - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.

    `in_place` determines whether dY is modified in place to store dX. It defaults to `True` to save memory, but in certain cases it can produce incorrect gradients.
    For example, Gemma2 uses two RMSNorms sequentially with a residual in between. The residual path still needs dY, so it cannot be modified in place.
    Therefore, when patching RMSNorm in Gemma2, we set `in_place` to `False`.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
        """
        X: (B, T, H) or (BxT, H)
        W: (H,)
        """
        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
        ctx.offset = offset
        ctx.casting_mode = casting_mode
        ctx.in_place = in_place
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.save_for_backward(X, W, RSTD)
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dY):
        """
        Y: (B, T, H) or (BxT, H)
        """
        X, W, RSTD = ctx.saved_tensors
        dX, dW = rms_norm_backward(
            dY,
            X,
            W,
            RSTD,
            ctx.offset,
            ctx.casting_mode,
            ctx.BLOCK_SIZE,
            ctx.num_warps,
            ctx.in_place,
        )
        return dX, dW, None, None, None, None
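A minimal eager-PyTorch sketch (not one of the uploaded files) of the computation the RMSNorm kernels above implement; the `rms_norm_reference` helper, the shapes, and the tolerances are illustrative assumptions.

import torch

def rms_norm_reference(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6, offset: float = 0.0) -> torch.Tensor:
    # y = (x / RMS(x)) * (offset + w), with RMS(x) = sqrt(mean(x^2) + eps).
    # Mirrors the 'llama' casting mode: only the inverse RMS is computed in fp32.
    rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x.float() * rstd).to(x.dtype) * (offset + w)

# Hypothetical comparison against the autograd function defined above,
# assuming a CUDA device is available:
# x = torch.randn(2, 16, 64, device="cuda", dtype=torch.bfloat16)
# w = torch.ones(64, device="cuda", dtype=torch.bfloat16)
# y = LigerRMSNormFunction.apply(x, w, 1e-6, 0.0, "llama", False)
# torch.testing.assert_close(y, rms_norm_reference(x, w), rtol=1e-2, atol=1e-2)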
build/torch-universal/liger_kernels/rope.py
ADDED
@@ -0,0 +1,239 @@
import torch
import triton
import triton.language as tl


@triton.jit
def _triton_rope(
    q_ptr,
    q_row_stride,
    k_ptr,
    k_row_stride,
    cos,
    cos_row_stride,
    sin,
    sin_row_stride,
    sl,
    bs: tl.constexpr,
    cos_bs: tl.constexpr,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    pad_n_qh: tl.constexpr,
    pad_n_kh: tl.constexpr,
    pad_hd: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BACKWARD_PASS: tl.constexpr = False,
):
    # q size: (bsz, seq_len, num_q_heads, head_dim)
    # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1)
    # k size: (bsz, seq_len, num_kv_heads, head_dim)
    # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)

    # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
    # stride: (seq_len * head_dim, head_dim, 1)
    pid = tl.program_id(0)

    # locate start address
    q_ptr = q_ptr + pid * q_row_stride
    k_ptr = k_ptr + pid * k_row_stride

    # ####################################################################
    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
    # m of this program instance
    # ####################################################################

    # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
    # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
    # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index
    # and pid % sl to get the sequence index.
    # 2. We only need the left half of cos and sin matrix because the right half is just
    # a clone of the left half.
    batch_idx = pid // sl
    cos_row_idx = pid % sl
    cos = cos + tl.where(
        cos_bs == 1,
        cos_row_idx * cos_row_stride,
        batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
    )
    sin = sin + tl.where(
        cos_bs == 1,
        cos_row_idx * sin_row_stride,
        batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
    )

    cos_offsets = tl.arange(0, pad_hd // 2)
    cos_mask = cos_offsets < hd // 2
    cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
    sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0)

    # ####################################################################
    # Load the left and right half of q and k for the current
    # program instance (i.e. for the current token) separately
    # ####################################################################
    # left half of the head
    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)

    # right half of the head
    second_half_q_offsets = first_half_q_offsets + (hd // 2)
    second_half_k_offsets = first_half_k_offsets + (hd // 2)
    second_q_mask = first_q_mask
    second_k_mask = first_k_mask
    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

    if not BACKWARD_PASS:
        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
    else:
        # with some math, we can get:
        # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
        new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)


def rope_forward(q, k, cos, sin):
    # transpose it back to the physical shape because Triton looks at the physical storage
    # note: q and k are not contiguous before the transformation and will become contiguous after transpose
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)

    batch_size, seq_len, n_q_head, head_dim = q.shape
    n_kv_head = k.shape[2]
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)

    n_row = batch_size * seq_len

    # ensure tensors passed into the kernel are contiguous. It will be a no-op if they are already contiguous
    q = q.contiguous()
    k = k.contiguous()
    cos = cos.contiguous()
    sin = sin.contiguous()
    cos_batch_size = cos.shape[0]

    _triton_rope[(n_row,)](
        q,
        q.stride(1),
        k,
        k.stride(1),
        cos,
        cos.stride(-2),
        sin,
        sin.stride(-2),
        seq_len,
        batch_size,
        cos_batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        BLOCK_SIZE=BLOCK_SIZE,
        BACKWARD_PASS=False,
    )
    return q.transpose(1, 2), k.transpose(1, 2), cos, sin


def rope_backward(dq, dk, cos, sin):
    dq = dq.transpose(1, 2)
    dk = dk.transpose(1, 2)

    batch_size, seq_len, n_q_head, head_dim = dq.shape
    cos_batch_size = cos.shape[0]
    n_kv_head = dk.shape[2]
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)

    n_row = batch_size * seq_len

    # ensure dq and dk are contiguous
    dq = dq.contiguous()
    dk = dk.contiguous()

    # backward is similar to forward except for swapping a few ops
    _triton_rope[(n_row,)](
        dq,
        dq.stride(1),
        dk,
        dk.stride(1),
        cos,
        cos.stride(-2),
        sin,
        sin.stride(-2),
        seq_len,
        batch_size,
        cos_batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        BLOCK_SIZE=BLOCK_SIZE,
        BACKWARD_PASS=True,
    )
    return dq.transpose(1, 2), dk.transpose(1, 2)


class LigerRopeFunction(torch.autograd.Function):
    """
    Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that
    this implements the HuggingFace Llama & Mistral version, whose rotation matrix is slightly different
    from the original RoPE paper.

    Please find the corresponding HuggingFace implementation here:
    https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184

    For more details about the rotation matrix used here, please refer to:
    https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2
    """

    @staticmethod
    def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        """
        q size: (bsz, n_q_head, seq_len, head_dim)
        k size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        """
        q, k, cos, sin = rope_forward(q, k, cos, sin)
        ctx.save_for_backward(cos, sin)
        return q, k

    def backward(ctx, dq, dk):
        """
        dq size: (bsz, n_q_head, seq_len, head_dim)
        dk size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
        """

        cos, sin = ctx.saved_tensors
        dq, dk = rope_backward(dq, dk, cos, sin)
        return dq, dk, None, None, None, None
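For reference, the rotation that _triton_rope applies per head can be written in a few lines of eager PyTorch. This sketch is not one of the uploaded files; the `rope_reference` name and shapes are illustrative assumptions following the forward() docstring above.

import torch

def rope_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (bsz, n_head, seq_len, head_dim); cos/sin: (1 or bsz, seq_len, head_dim)
    x1, x2 = x.chunk(2, dim=-1)                  # left and right halves of each head
    half = cos.shape[-1] // 2
    cos_half = cos[..., :half].unsqueeze(1)      # keep the left half; broadcast over heads
    sin_half = sin[..., :half].unsqueeze(1)
    # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
    return torch.cat((x1 * cos_half - x2 * sin_half, x2 * cos_half + x1 * sin_half), dim=-1)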
build/torch-universal/liger_kernels/swiglu.py
ADDED
@@ -0,0 +1,116 @@
import torch
import triton
import triton.language as tl

from utils import calculate_settings
from utils import ensure_contiguous


@triton.jit
def silu(x):
    return x * tl.sigmoid(x)


@triton.jit
def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    program_id = tl.program_id(0).to(tl.int64)

    # locate start index
    a_ptr += program_id * stride
    b_ptr += program_id * stride
    c_ptr += program_id * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # sigmoid requires type float32
    a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
    c_row = silu(a_row) * b_row
    tl.store(c_ptr + col_offsets, c_row, mask=mask)


@triton.jit
def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    program_id = tl.program_id(0).to(tl.int64)

    # locate start index
    dc_ptr += program_id * stride
    a_ptr += program_id * stride
    b_ptr += program_id * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)
    # sigmoid requires type float32
    a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)

    # recomputation to save memory
    sig_a = tl.sigmoid(a_row)
    silu_a = a_row * sig_a
    db_row = dc_row * silu_a
    da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row

    tl.store(a_ptr + col_offsets, da_row, mask=mask)
    tl.store(b_ptr + col_offsets, db_row, mask=mask)


def swiglu_forward(a, b):
    ori_shape = a.shape

    n_cols = ori_shape[-1]
    a = a.view(-1, n_cols)
    b = b.view(-1, n_cols)
    c = torch.empty_like(a)
    n_rows = a.shape[0]

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    _swiglu_forward_kernel[(n_rows,)](
        a,
        b,
        c,
        c.stride(-2),
        n_cols=n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return a, b, c.view(*ori_shape)


def swiglu_backward(a, b, dc):
    ori_shape = dc.shape
    n_cols = ori_shape[-1]
    dc = dc.view(-1, n_cols)
    n_rows = dc.shape[0]

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    _swiglu_backward_kernel[(n_rows,)](
        dc,
        a,
        b,
        dc.stride(-2),
        n_cols=n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return a.view(*ori_shape), b.view(*ori_shape)


class LigerSiLUMulFunction(torch.autograd.Function):
    @staticmethod
    @ensure_contiguous
    def forward(ctx, a, b):
        a, b, c = swiglu_forward(a, b)
        ctx.save_for_backward(a, b)
        return c

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dc):
        a, b = ctx.saved_tensors
        a, b = swiglu_backward(a, b, dc)
        return a, b
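An eager sketch (not one of the uploaded files) of what the fused SwiGLU kernels compute: the forward is silu(a) * b, and the backward recomputes sigmoid(a) rather than saving activations, matching the derivative d/da[silu(a)] = sigmoid(a) + silu(a) * (1 - sigmoid(a)). Function names here are illustrative assumptions.

import torch

def swiglu_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.silu(a) * b

def swiglu_backward_reference(a: torch.Tensor, b: torch.Tensor, dc: torch.Tensor):
    sig_a = torch.sigmoid(a)
    silu_a = a * sig_a
    db = dc * silu_a                               # d(silu(a) * b) / db
    da = dc * (silu_a * (1 - sig_a) + sig_a) * b   # d(silu(a) * b) / da
    return da, db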
build/torch-universal/liger_kernels/tvd.py
ADDED
@@ -0,0 +1,207 @@
from typing import Literal
from typing import Optional

import torch
import triton
import triton.language as tl

from utils import ensure_contiguous

MAX_FUSED_SIZE = 65536 // 4

REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]

_REDUCTION_MODE_NONE = tl.constexpr(0)
_REDUCTION_MODE_SUM = tl.constexpr(1)
_REDUCTION_MODE_MEAN = tl.constexpr(2)
_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)

_str_to_reduction_mode = {
    "none": _REDUCTION_MODE_NONE.value,
    "sum": _REDUCTION_MODE_SUM.value,
    "mean": _REDUCTION_MODE_MEAN.value,
    "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
}


def get_num_warps(BLOCK_SIZE):
    num_warps = 4
    if BLOCK_SIZE >= 32768:
        num_warps = 32
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8

    return num_warps


@triton.jit
def _tv_distance_kernel(
    p_ptr,
    p_stride,
    q_ptr,
    q_stride,
    loss_ptr,
    loss_stride,
    grads_ptr,
    grads_stride,
    label_ptr,
    ignore_index: tl.constexpr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    HAS_LABEL: tl.constexpr,
    reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
):
    pid = tl.program_id(0).to(tl.int64)
    p_ptr += pid * p_stride
    q_ptr += pid * q_stride
    loss_ptr += pid * loss_stride
    grads_ptr += pid * grads_stride
    label_ptr += pid

    base_offsets = tl.arange(0, BLOCK_SIZE)

    if HAS_LABEL:
        label = tl.load(label_ptr)
        if label == ignore_index:
            for i in range(0, n_cols, BLOCK_SIZE):
                offsets = i + base_offsets
                mask = offsets < n_cols
                tl.store(grads_ptr + offsets, 0.0, mask=mask)
                if reduction == _REDUCTION_MODE_NONE:
                    tl.store(loss_ptr + offsets, 0.0, mask=mask)
            return

    loss_sum = 0.0
    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + base_offsets
        mask = offsets < n_cols

        p = tl.load(p_ptr + offsets, mask=mask, other=0.0)
        q = tl.load(q_ptr + offsets, mask=mask, other=0.0)

        # TVD(P || Q) = 0.5 * |P - Q|
        tv_loss = 0.5 * tl.abs(p - q)

        grad_res = tl.where(p > q, 0.5, -0.5)

        tl.store(grads_ptr + offsets, grad_res, mask=mask)

        if reduction == _REDUCTION_MODE_NONE:
            tl.store(loss_ptr + offsets, tv_loss, mask=mask)
        else:
            loss_sum += tl.sum(tv_loss, axis=0)

    if reduction != _REDUCTION_MODE_NONE:
        tl.store(loss_ptr, loss_sum)


def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
    BT, V = p.shape

    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    num_warps = get_num_warps(BLOCK_SIZE)

    grid = (BT,)

    reduction = _str_to_reduction_mode[reduction]

    out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
    output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
    grads = torch.empty_like(p)

    n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT

    _tv_distance_kernel[grid](
        p,
        p.stride(0),
        q,
        q.stride(0),
        output_tensor,
        output_tensor.stride(0),
        grads,
        grads.stride(0),
        shift_labels if has_label else torch.empty(1, device=p.device),
        ignore_index,
        V,
        BLOCK_SIZE=BLOCK_SIZE,
        HAS_LABEL=has_label,
        num_warps=num_warps,
        reduction=reduction,
    )

    if reduction == _REDUCTION_MODE_BATCHMEAN.value:
        return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
    elif reduction == _REDUCTION_MODE_SUM.value:
        return output_tensor.sum(dim=0), grads
    elif reduction == _REDUCTION_MODE_MEAN.value:
        return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
    else:
        return output_tensor, grads


def tvd_backward_triton(grad_output, grads):
    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then.
    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        return grads

    return grads * grad_output


class LigerTVDLossFunction(torch.autograd.Function):
    """
    Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        p: torch.Tensor,
        q: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
        reduction: REDUCTION_LITERAL = "batchmean",
        ignore_index: int = -100,
    ) -> torch.Tensor:
        """A forward pass for the Total Variation Distance Loss.

        Args:
            ctx: Torch autograd context
            p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
            q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
            shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
            reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
            ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.

        Returns:
            torch.Tensor: The computed Total Variation Distance Loss.
        """
        has_label = False
        if shift_labels is not None:
            assert shift_labels.shape == (p.shape[0],), (
                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
            )
            shift_labels = shift_labels.contiguous()
            has_label = True

        loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
        ctx.save_for_backward(grads)
        return loss

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """A backward pass for the Total Variation Distance Loss.

        Args:
            ctx: Torch autograd context
            grad_output (torch.Tensor): The gradient of the loss with respect to the output.

        Returns:
            tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs.
        """
        (grads,) = ctx.saved_tensors
        grads = tvd_backward_triton(grad_output, grads)

        return grads, None, None, None, None
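An eager sketch (not one of the uploaded files) of the "batchmean" TVD loss computed above: 0.5 * sum_v |p_v - q_v| per row, averaged over the non-ignored rows. The `tvd_reference` name is an illustrative assumption.

from typing import Optional

import torch

def tvd_reference(p: torch.Tensor, q: torch.Tensor,
                  shift_labels: Optional[torch.Tensor] = None,
                  ignore_index: int = -100) -> torch.Tensor:
    per_row = 0.5 * (p - q).abs().sum(dim=-1)   # (BT,)
    if shift_labels is None:
        return per_row.sum() / p.shape[0]
    keep = shift_labels != ignore_index
    return (per_row * keep).sum() / keep.sum()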
build/torch-universal/liger_kernels/utils.py
ADDED
@@ -0,0 +1,135 @@
"""
This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
See the original Unsloth repository at https://github.com/unslothai/unsloth.

The following line
https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23
is based on code from Unsloth, located at:
https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43

Modifications made by Yanning Chen, 2024.
"""

import functools
import importlib
import operator

from typing import Callable

import torch
import triton
import triton.language as tl

from packaging.version import Version


def infer_device():
    """
    Get current device name based on available devices
    """
    if torch.cuda.is_available():  # Works for both Nvidia and AMD
        return "cuda"
    elif torch.xpu.is_available():
        return "xpu"
    else:
        return "cpu"


def is_hip() -> bool:
    return torch.version.hip is not None


def ensure_contiguous(fn):
    @functools.wraps(fn)
    def wrapper(ctx, *args, **kwargs):
        def maybe_to_contiguous(x):
            return x.contiguous() if isinstance(x, torch.Tensor) else x

        args = [maybe_to_contiguous(arg) for arg in args]
        kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()}
        return fn(ctx, *args, **kwargs)

    return wrapper


def calculate_settings(n):
    # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43

    MAX_FUSED_SIZE = 65536
    BLOCK_SIZE = triton.next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(
            f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}."
        )

    num_warps = 4
    if BLOCK_SIZE >= 32768:
        num_warps = 32 if not is_hip() else 16
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8
    return BLOCK_SIZE, num_warps


def compare_version(package: str, operator: Callable, target: str):
    try:
        pkg = importlib.import_module(package)
    except ImportError:
        return False
    pkg_version = Version(pkg.__version__)
    return operator(pkg_version, Version(target))


def get_amp_custom_fwd_bwd() -> Callable:
    device = infer_device()
    if compare_version("torch", operator.ge, "2.4.0"):
        return (
            functools.partial(torch.amp.custom_fwd, device_type=device),
            functools.partial(torch.amp.custom_bwd, device_type=device),
        )
    return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd


amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()


torch_to_triton_dtype = {
    torch.float32: tl.float32,
    torch.float16: tl.float16,
    torch.bfloat16: tl.bfloat16,
}


@triton.jit
def element_mul_kernel(
    X_ptr,
    X_stride,
    grad_output_ptr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
    The multiplication is performed in-place on the tensor pointed by X_ptr.

    Parameters:
    X_ptr: Pointer to the input tensor.
    X_stride (int): The stride of the input tensor.
    grad_output_ptr: Pointer to the gradient output value.
    n_cols (int): The number of columns in the input tensor.
    BLOCK_SIZE (int): The block size for Triton operations.
    """

    # Get the program ID and convert it to int64 to avoid overflow
    program_id = tl.program_id(0).to(tl.int64)

    # Locate the start index
    X_ptr += program_id * X_stride

    # Load the gradient output value
    grad_output = tl.load(grad_output_ptr)

    # Perform the element-wise multiplication
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
        tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
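The kernels above use calculate_settings to pick their launch configuration. A small worked example (not one of the uploaded files): for a hidden size of 4096, next_power_of_2 gives a block size of 4096, which falls in the >= 2048 branch and therefore uses 8 warps.

from utils import calculate_settings

BLOCK_SIZE, num_warps = calculate_settings(4096)
assert BLOCK_SIZE == 4096 and num_warps == 8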
flake.lock
ADDED
@@ -0,0 +1,117 @@
1 |
+
{
|
2 |
+
"nodes": {
|
3 |
+
"flake-compat": {
|
4 |
+
"locked": {
|
5 |
+
"lastModified": 1733328505,
|
6 |
+
"narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
|
7 |
+
"owner": "edolstra",
|
8 |
+
"repo": "flake-compat",
|
9 |
+
"rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
|
10 |
+
"type": "github"
|
11 |
+
},
|
12 |
+
"original": {
|
13 |
+
"owner": "edolstra",
|
14 |
+
"repo": "flake-compat",
|
15 |
+
"type": "github"
|
16 |
+
}
|
17 |
+
},
|
18 |
+
"flake-utils": {
|
19 |
+
"inputs": {
|
20 |
+
"systems": "systems"
|
21 |
+
},
|
22 |
+
"locked": {
|
23 |
+
"lastModified": 1731533236,
|
24 |
+
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
25 |
+
"owner": "numtide",
|
26 |
+
"repo": "flake-utils",
|
27 |
+
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
28 |
+
"type": "github"
|
29 |
+
},
|
30 |
+
"original": {
|
31 |
+
"owner": "numtide",
|
32 |
+
"repo": "flake-utils",
|
33 |
+
"type": "github"
|
34 |
+
}
|
35 |
+
},
|
36 |
+
"kernel-builder": {
|
37 |
+
"inputs": {
|
38 |
+
"flake-compat": "flake-compat",
|
39 |
+
"flake-utils": "flake-utils",
|
40 |
+
"nixpkgs": "nixpkgs",
|
41 |
+
"rocm-nix": "rocm-nix"
|
42 |
+
},
|
43 |
+
"locked": {
|
44 |
+
"lastModified": 1745579622,
|
45 |
+
"narHash": "sha256-g8BXijChxDCZNu17M4Jj0GPv/7faVnArbHBOMNMpHjM=",
|
46 |
+
"owner": "huggingface",
|
47 |
+
"repo": "kernel-builder",
|
48 |
+
"rev": "e2f6f338737c6f1c570f9b59e43182633c0879c1",
|
49 |
+
"type": "github"
|
50 |
+
},
|
51 |
+
"original": {
|
52 |
+
"owner": "huggingface",
|
53 |
+
"repo": "kernel-builder",
|
54 |
+
"type": "github"
|
55 |
+
}
|
56 |
+
},
|
57 |
+
"nixpkgs": {
|
58 |
+
"locked": {
|
59 |
+
"lastModified": 1743559129,
|
60 |
+
"narHash": "sha256-7gpAWsENV3tY2HmeHYQ2MoQxGpys+jQWnkS/BHAMXVk=",
|
61 |
+
"owner": "nixos",
|
62 |
+
"repo": "nixpkgs",
|
63 |
+
"rev": "adae22bea8bcc0aa2fd6e8732044660fb7755f5e",
|
64 |
+
"type": "github"
|
65 |
+
},
|
66 |
+
"original": {
|
67 |
+
"owner": "nixos",
|
68 |
+
"ref": "nixos-unstable-small",
|
69 |
+
"repo": "nixpkgs",
|
70 |
+
"type": "github"
|
71 |
+
}
|
72 |
+
},
|
73 |
+
"rocm-nix": {
|
74 |
+
"inputs": {
|
75 |
+
"nixpkgs": [
|
76 |
+
"kernel-builder",
|
77 |
+
"nixpkgs"
|
78 |
+
]
|
79 |
+
},
|
80 |
+
"locked": {
|
81 |
+
"lastModified": 1745310663,
|
82 |
+
"narHash": "sha256-1U3PzCO/jt7HUlEgLOY3RpxadKwTo6GSvb2j4m0UFw0=",
|
83 |
+
"owner": "huggingface",
|
84 |
+
"repo": "rocm-nix",
|
85 |
+
"rev": "e08373a0efa1c297b0c57af070e0a311df47481f",
|
86 |
+
"type": "github"
|
87 |
+
},
|
88 |
+
"original": {
|
89 |
+
"owner": "huggingface",
|
90 |
+
"repo": "rocm-nix",
|
91 |
+
"type": "github"
|
92 |
+
}
|
93 |
+
},
|
94 |
+
"root": {
|
95 |
+
"inputs": {
|
96 |
+
"kernel-builder": "kernel-builder"
|
97 |
+
}
|
98 |
+
},
|
99 |
+
"systems": {
|
100 |
+
"locked": {
|
101 |
+
"lastModified": 1681028828,
|
102 |
+
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
103 |
+
"owner": "nix-systems",
|
104 |
+
"repo": "default",
|
105 |
+
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
106 |
+
"type": "github"
|
107 |
+
},
|
108 |
+
"original": {
|
109 |
+
"owner": "nix-systems",
|
110 |
+
"repo": "default",
|
111 |
+
"type": "github"
|
112 |
+
}
|
113 |
+
}
|
114 |
+
},
|
115 |
+
"root": "root",
|
116 |
+
"version": 7
|
117 |
+
}
|
flake.nix
ADDED
@@ -0,0 +1,17 @@
{
  description = "Flake for Unsloth Kernels";

  inputs = {
    kernel-builder.url = "github:huggingface/kernel-builder";
  };

  outputs =
    {
      self,
      kernel-builder,
    }:
    kernel-builder.lib.genFlakeOutputs {
      path = ./.;
      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
    };
}
torch-ext/liger_kernels/__init__.py
ADDED
@@ -0,0 +1,29 @@
from cross_entropy import LigerCrossEntropyFunction
from fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
from dyt import LigerDyTFunction
from geglu import LigerGELUMulFunction
from group_norm import LigerGroupNormFunction
from kl_div import LigerKLDivLossFunction
from layer_norm import LigerLayerNormFunction
from qwen2vl_mrope import LigerQwen2VLMRopeFunction
from rms_norm import LigerRMSNormFunction
from jsd import LigerJSDFunction
from rope import LigerRopeFunction
from swiglu import LigerSiLUMulFunction
from tvd import LigerTVDLossFunction

__all__ = [
    "LigerCrossEntropyFunction",
    "LigerFusedLinearCrossEntropyFunction",
    "LigerDyTFunction",
    "LigerGELUMulFunction",
    "LigerGroupNormFunction",
    "LigerKLDivLossFunction",
    "LigerLayerNormFunction",
    "LigerQwen2VLMRopeFunction",
    "LigerRMSNormFunction",
    "LigerJSDFunction",
    "LigerRopeFunction",
    "LigerSiLUMulFunction",
    "LigerTVDLossFunction",
]
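A hedged usage sketch (not one of the uploaded files) of exercising one of the exported autograd functions once the package is built and importable. The `liger_kernels` import path and the CUDA device are assumptions; adjust to your environment.

import torch
from liger_kernels import LigerSiLUMulFunction

a = torch.randn(4, 128, device="cuda", requires_grad=True)
b = torch.randn(4, 128, device="cuda", requires_grad=True)
c = LigerSiLUMulFunction.apply(a, b)   # fused silu(a) * b
c.sum().backward()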
torch-ext/liger_kernels/cross_entropy.py
ADDED
@@ -0,0 +1,460 @@
1 |
+
import operator
|
2 |
+
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import triton
|
7 |
+
import triton.language as tl
|
8 |
+
|
9 |
+
from utils import compare_version
|
10 |
+
from utils import element_mul_kernel
|
11 |
+
from utils import is_hip
|
12 |
+
from utils import infer_device
|
13 |
+
|
14 |
+
if compare_version("triton", operator.ge, "3.0.0"):
|
15 |
+
try:
|
16 |
+
# typical import path with dispatch available
|
17 |
+
from triton.language.extra.libdevice import tanh
|
18 |
+
except ModuleNotFoundError:
|
19 |
+
# for working with NGC containers
|
20 |
+
from triton.language.extra.cuda.libdevice import tanh
|
21 |
+
else:
|
22 |
+
from triton.language.math import tanh
|
23 |
+
|
24 |
+
|
25 |
+
@triton.jit
|
26 |
+
def liger_cross_entropy_kernel(
|
27 |
+
X_ptr,
|
28 |
+
X_stride,
|
29 |
+
Y_ptr,
|
30 |
+
Y_stride,
|
31 |
+
weight_ptr,
|
32 |
+
loss_ptr,
|
33 |
+
z_loss_ptr,
|
34 |
+
loss_stride,
|
35 |
+
n_cols,
|
36 |
+
n_non_ignore,
|
37 |
+
sum_non_ignore_weight,
|
38 |
+
weight_sum,
|
39 |
+
ignore_index,
|
40 |
+
lse_square_scale: tl.constexpr,
|
41 |
+
label_smoothing: tl.constexpr,
|
42 |
+
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
|
43 |
+
softcap,
|
44 |
+
RETURN_Z_LOSS: tl.constexpr,
|
45 |
+
BLOCK_SIZE: tl.constexpr,
|
46 |
+
HAS_WEIGHT: tl.constexpr,
|
47 |
+
HAS_SOFTCAPPING: tl.constexpr,
|
48 |
+
):
|
49 |
+
"""
|
50 |
+
This kernel computes both cross entropy loss and the gradient of the input.
|
51 |
+
We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.
|
52 |
+
|
53 |
+
Parameters:
|
54 |
+
X_ptr: Pointer to input tensor.
|
55 |
+
X_stride (int): The stride of the input tensor.
|
56 |
+
Y_ptr: Pointer to target tensor.
|
57 |
+
Y_stride (int): The stride of the target tensor.
|
58 |
+
weight_ptr: Pointer to weight tensor.
|
59 |
+
loss_ptr: Pointer to tensor to store the loss.
|
60 |
+
z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
|
61 |
+
loss_stride (int): The stride of the loss tensor.
|
62 |
+
n_cols (int): The number of columns in the input tensor.
|
63 |
+
n_non_ignore (float): The number of non-ignored elements in the batch.
|
64 |
+
sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
|
65 |
+
weight_sum (float): The sum of weight tensor.
|
66 |
+
ignore_index (int): The index to ignore in the target.
|
67 |
+
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
|
68 |
+
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
|
69 |
+
reduction (str): The string for the reduction to apply
|
70 |
+
softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
|
71 |
+
RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
|
72 |
+
BLOCK_SIZE (int): The block size for Triton operations.
|
73 |
+
HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
|
74 |
+
HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
|
75 |
+
"""
|
76 |
+
|
77 |
+
# https://github.com/triton-lang/triton/issues/1058
|
78 |
+
# If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64
|
79 |
+
program_id = tl.program_id(0).to(tl.int64)
|
80 |
+
|
81 |
+
# 1. Load Y_ptr first because if the target is ignore_index, we can return right away
|
82 |
+
Y_ptr += program_id * Y_stride
|
83 |
+
y = tl.load(Y_ptr)
|
84 |
+
|
85 |
+
# 2. locate the start index
|
86 |
+
X_ptr += program_id * X_stride
|
87 |
+
|
88 |
+
if y == ignore_index:
|
89 |
+
# set all X_ptr as 0
|
90 |
+
for i in range(0, n_cols, BLOCK_SIZE):
|
91 |
+
X_offsets = i + tl.arange(0, BLOCK_SIZE)
|
92 |
+
tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
|
93 |
+
return
|
94 |
+
|
95 |
+
loss_ptr += program_id * loss_stride
|
96 |
+
if RETURN_Z_LOSS:
|
97 |
+
z_loss_ptr += program_id * loss_stride
|
98 |
+
|
99 |
+
if HAS_WEIGHT:
|
100 |
+
weight_y = tl.load(weight_ptr + y).cast(tl.float32)
|
101 |
+
|
102 |
+
# Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
|
103 |
+
# Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
|
104 |
+
|
105 |
+
# 3. [Online softmax] first pass: find max + sum
|
106 |
+
m = float("-inf") # m is the max value. use the notation from the paper
|
107 |
+
d = 0.0 # d is the sum. use the notation from the paper
|
108 |
+
ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
|
109 |
+
if HAS_SOFTCAPPING:
|
110 |
+
ori_X_y = softcap * tanh(ori_X_y / softcap)
|
111 |
+
|
112 |
+
# Label smoothing is a general case of normal cross entropy
|
113 |
+
# See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
|
114 |
+
scaled_x_sum = 0.0
|
115 |
+
eps = label_smoothing / n_cols
|
116 |
+
|
117 |
+
for i in range(0, n_cols, BLOCK_SIZE):
|
118 |
+
X_offsets = i + tl.arange(0, BLOCK_SIZE)
|
119 |
+
X_block = tl.load(
|
120 |
+
X_ptr + X_offsets,
|
121 |
+
mask=X_offsets < n_cols,
|
122 |
+
other=float("-inf"),
|
123 |
+
# Ensure float32 precision for softmax calculation
|
124 |
+
).cast(tl.float32)
|
125 |
+
if HAS_SOFTCAPPING:
|
126 |
+
X_block = softcap * tanh(X_block / softcap)
|
127 |
+
block_max = tl.max(X_block)
|
128 |
+
if label_smoothing > 0:
|
129 |
+
# scale X beforehand to avoid overflow
|
130 |
+
if HAS_WEIGHT:
|
131 |
+
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
|
132 |
+
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0))
|
133 |
+
else:
|
134 |
+
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
|
135 |
+
m_new = tl.maximum(m, block_max)
|
136 |
+
d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
|
137 |
+
m = m_new
|
138 |
+
|
139 |
+
# log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X)))))
|
140 |
+
# = log (e^(max(X)) * sum(e ^ (X_i - max(X))))
|
141 |
+
# = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
|
142 |
+
lse = m + tl.log(d)
|
143 |
+
|
144 |
+
# 4. [Online Softmax] Second pass: compute gradients
|
145 |
+
# For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
|
146 |
+
# dx_y = (softmax(x_y) - 1) / N
|
147 |
+
# dx_i = softmax(x_i) / N, i != y
|
148 |
+
# For label smoothing:
|
149 |
+
# dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
|
150 |
+
# dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
|
151 |
+
# = dx_i - (1 - label_smoothing) / N
|
152 |
+
# With Z loss:
|
153 |
+
# dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y
|
154 |
+
# dx_y = dx_i - (1 - label_smoothing) / N
|
155 |
+
# For 'sum' reduction, no normalization is applied:
|
156 |
+
# dx_y = softmax(x_y) - 1
|
157 |
+
# dx_i = softmax(x_i), for i ≠ y
|
158 |
+
|
159 |
+
for i in range(0, n_cols, BLOCK_SIZE):
|
160 |
+
X_offsets = i + tl.arange(0, BLOCK_SIZE)
|
161 |
+
X_block = tl.load(
|
162 |
+
X_ptr + X_offsets,
|
163 |
+
mask=X_offsets < n_cols,
|
164 |
+
other=float("-inf"),
|
165 |
+
# Ensure float32 precision for softmax calculation
|
166 |
+
).cast(tl.float32)
|
167 |
+
if HAS_SOFTCAPPING:
|
168 |
+
intermediate = tanh(X_block / softcap)
|
169 |
+
X_block = softcap * intermediate
|
170 |
+
|
171 |
+
if not HAS_WEIGHT:
|
172 |
+
# softmax(x_i)
|
173 |
+
X_block = tl.exp(X_block - m) / d
|
174 |
+
# derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
|
175 |
+
X_block += 2 * lse_square_scale * lse * X_block
|
176 |
+
# smoothing term
|
177 |
+
X_block += -eps
|
178 |
+
# special handle dx_y
|
179 |
+
X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
|
180 |
+
# reduction scale
|
181 |
+
if reduction == "mean":
|
182 |
+
X_block = X_block / n_non_ignore
|
183 |
+
else:
|
184 |
+
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
|
185 |
+
softmax_X = tl.exp(X_block - m) / d
|
186 |
+
# derivative of original_loss
|
187 |
+
dloss_ori = (1 - label_smoothing) * softmax_X
|
188 |
+
# specially handle dx_y
|
189 |
+
dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
|
190 |
+
dloss_ori = dloss_ori * weight_y
|
191 |
+
# derivative of smooth_loss
|
192 |
+
dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
|
193 |
+
# derivative of z-loss
|
194 |
+
dz_loss = 2 * lse_square_scale * lse * softmax_X
|
195 |
+
# reduction scale
|
196 |
+
if reduction == "mean":
|
197 |
+
dloss_ori = dloss_ori / sum_non_ignore_weight
|
198 |
+
dloss_smooth = dloss_smooth / sum_non_ignore_weight
|
199 |
+
# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
|
200 |
+
dz_loss = dz_loss / n_non_ignore
|
201 |
+
# derivative of total_loss
|
202 |
+
X_block = dloss_ori + dloss_smooth + dz_loss
|
203 |
+
|
204 |
+
# chain rule softcapping
|
205 |
+
# d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
|
206 |
+
if HAS_SOFTCAPPING:
|
207 |
+
X_block = X_block * (1 - intermediate * intermediate)
|
208 |
+
|
209 |
+
tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
|
210 |
+
|
211 |
+
# We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
|
212 |
+
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
|
213 |
+
tl.debug_barrier()
|
214 |
+
|
215 |
+
# 5. Calculate the loss
|
216 |
+
|
217 |
+
# loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
|
218 |
+
# = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
|
219 |
+
# = X_y - m - log d = X_y - lse
|
220 |
+
# sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
|
221 |
+
# So we can safely calculate log (softmax(X_y)) without overflow
|
222 |
+
loss = lse - ori_X_y
|
223 |
+
if HAS_WEIGHT:
|
224 |
+
loss = weight_y * loss
|
225 |
+
|
226 |
+
# Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
|
227 |
+
# H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
|
228 |
+
# = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
|
229 |
+
# By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
|
230 |
+
# = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
|
231 |
+
# Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
|
232 |
+
# pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
|
233 |
+
# See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
|
234 |
+
if label_smoothing > 0:
|
235 |
+
if HAS_WEIGHT:
|
236 |
+
smooth_loss = scaled_x_sum + eps * lse * weight_sum
|
237 |
+
else:
|
238 |
+
smooth_loss = scaled_x_sum + label_smoothing * lse
|
239 |
+
loss = loss * (1 - label_smoothing) + smooth_loss
|
240 |
+
|
241 |
+
# An auxiliary loss, z_loss
|
242 |
+
# Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
|
243 |
+
z_loss = lse_square_scale * lse * lse
|
244 |
+
# Normalize the loss by the number of non-ignored elements if reduction is "mean"
|
245 |
+
if reduction == "mean":
|
246 |
+
if HAS_WEIGHT:
|
247 |
+
loss = loss / sum_non_ignore_weight
|
248 |
+
else:
|
249 |
+
loss = loss / n_non_ignore
|
250 |
+
# TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
|
251 |
+
z_loss = z_loss / n_non_ignore
|
252 |
+
loss += z_loss
|
253 |
+
|
254 |
+
tl.store(loss_ptr, loss)
|
255 |
+
if RETURN_Z_LOSS:
|
256 |
+
tl.store(z_loss_ptr, z_loss)
|
257 |
+
|
258 |
+
|
259 |
+
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
|
260 |
+
# However, setting the limit to 65536, as in the LayerNorm tutorial, is faster because of less register spilling
|
261 |
+
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
|
262 |
+
MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2 # the best size we found by manually tuning
|
263 |
+
|
264 |
+
|
265 |
+
def cross_entropy_forward(
|
266 |
+
_input,
|
267 |
+
target,
|
268 |
+
weight,
|
269 |
+
ignore_index,
|
270 |
+
lse_square_scale,
|
271 |
+
label_smoothing,
|
272 |
+
reduction,
|
273 |
+
softcap,
|
274 |
+
return_z_loss,
|
275 |
+
):
|
276 |
+
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
|
277 |
+
|
278 |
+
BT, V = _input.shape
|
279 |
+
n_rows = BT
|
280 |
+
|
281 |
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
282 |
+
|
283 |
+
# unreduced loss
|
284 |
+
loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
|
285 |
+
z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
|
286 |
+
|
287 |
+
target_mask = target != ignore_index
|
288 |
+
n_non_ignore = target_mask.sum().item()
|
289 |
+
assert (target * target_mask).max() < _input.shape[-1], (
|
290 |
+
f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}"
|
291 |
+
)
|
292 |
+
assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0"
|
293 |
+
sum_non_ignore_weight = n_non_ignore
|
294 |
+
weight_sum = 0.0
|
295 |
+
if weight is not None:
|
296 |
+
assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
|
297 |
+
assert torch.is_floating_point(weight), (
|
298 |
+
f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
|
299 |
+
)
|
300 |
+
sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
|
301 |
+
weight_sum = weight.sum().item()
|
302 |
+
# ensure weight is contiguous
|
303 |
+
if weight.stride(-1) != 1:
|
304 |
+
weight = weight.contiguous()
|
305 |
+
|
306 |
+
# ensure _input and target are contiguous in the last dimension
|
307 |
+
if _input.stride(-1) != 1:
|
308 |
+
_input = _input.contiguous()
|
309 |
+
if target.stride(-1) != 1:
|
310 |
+
target = target.contiguous()
|
311 |
+
|
312 |
+
# Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
|
313 |
+
liger_cross_entropy_kernel[(n_rows,)](
|
314 |
+
X_ptr=_input,
|
315 |
+
X_stride=_input.stride(-2),
|
316 |
+
Y_ptr=target,
|
317 |
+
Y_stride=target.stride(-1), # always 1
|
318 |
+
weight_ptr=weight, # dummy if None
|
319 |
+
loss_ptr=loss_1d,
|
320 |
+
z_loss_ptr=z_loss_1d,
|
321 |
+
loss_stride=loss_1d.stride(-1), # always 1
|
322 |
+
n_cols=V,
|
323 |
+
n_non_ignore=n_non_ignore,
|
324 |
+
sum_non_ignore_weight=sum_non_ignore_weight,
|
325 |
+
ignore_index=ignore_index,
|
326 |
+
weight_sum=weight_sum,
|
327 |
+
lse_square_scale=lse_square_scale,
|
328 |
+
label_smoothing=label_smoothing,
|
329 |
+
reduction=reduction,
|
330 |
+
softcap=softcap,
|
331 |
+
RETURN_Z_LOSS=return_z_loss,
|
332 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
333 |
+
HAS_WEIGHT=True if weight is not None else False,
|
334 |
+
HAS_SOFTCAPPING=True if softcap is not None else False,
|
335 |
+
# TODO: 32 seems to give the best performance
|
336 |
+
# Performance is quite sensitive to num_warps
|
337 |
+
num_warps=32 if not is_hip() else 16,
|
338 |
+
)
|
339 |
+
|
340 |
+
if reduction == "none":
|
341 |
+
loss = loss_1d
|
342 |
+
z_loss = z_loss_1d if return_z_loss else None
|
343 |
+
else:
|
344 |
+
loss = torch.sum(loss_1d)
|
345 |
+
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
|
346 |
+
|
347 |
+
return loss, z_loss, _input
|
348 |
+
|
349 |
+
|
350 |
+
def cross_entropy_backward(_input, grad_output):
|
351 |
+
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
|
352 |
+
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
|
353 |
+
pass
|
354 |
+
# If reduction is 'none'
|
355 |
+
elif grad_output.ndim > 0:
|
356 |
+
_input = _input * grad_output.unsqueeze(dim=1)
|
357 |
+
# If reduction is ['mean', 'sum'], grad_output is just a scalar
|
358 |
+
# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
|
359 |
+
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
|
360 |
+
else:
|
361 |
+
BT, V = _input.shape
|
362 |
+
n_rows = BT
|
363 |
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
364 |
+
|
365 |
+
element_mul_kernel[(n_rows,)](
|
366 |
+
_input,
|
367 |
+
_input.stride(-2),
|
368 |
+
grad_output,
|
369 |
+
V,
|
370 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
371 |
+
num_warps=32 if not is_hip() else 16,
|
372 |
+
)
|
373 |
+
|
374 |
+
return _input
|
375 |
+
|
376 |
+
|
377 |
+
class LigerCrossEntropyFunction(torch.autograd.Function):
|
378 |
+
"""
|
379 |
+
This class implements a custom autograd function for the Liger Cross Entropy loss.
|
380 |
+
It overrides the forward and backward methods of the torch.autograd.Function class.
|
381 |
+
"""
|
382 |
+
|
383 |
+
@staticmethod
|
384 |
+
def forward(
|
385 |
+
ctx,
|
386 |
+
_input: torch.Tensor,
|
387 |
+
target: torch.Tensor,
|
388 |
+
weight: Optional[torch.FloatTensor],
|
389 |
+
ignore_index: int = -100,
|
390 |
+
lse_square_scale: float = 0.0,
|
391 |
+
label_smoothing: float = 0.0,
|
392 |
+
reduction: str = "mean",
|
393 |
+
softcap: Optional[float] = None,
|
394 |
+
return_z_loss: bool = False,
|
395 |
+
):
|
396 |
+
"""
|
397 |
+
The forward pass of the Liger Cross Entropy loss.
|
398 |
+
|
399 |
+
Parameters:
|
400 |
+
ctx : The context object.
|
401 |
+
_input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
|
402 |
+
target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
|
403 |
+
weight (Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
|
404 |
+
ignore_index (int): The index to ignore in the target.
|
405 |
+
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
|
406 |
+
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
|
407 |
+
reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
|
408 |
+
softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
|
409 |
+
return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
|
410 |
+
|
411 |
+
Returns:
|
412 |
+
tuple: A tuple of the computed loss and z loss. The elements are tensors or None.
|
413 |
+
"""
|
414 |
+
loss, z_loss, _input = cross_entropy_forward(
|
415 |
+
_input,
|
416 |
+
target,
|
417 |
+
weight,
|
418 |
+
ignore_index,
|
419 |
+
lse_square_scale,
|
420 |
+
label_smoothing,
|
421 |
+
reduction,
|
422 |
+
softcap,
|
423 |
+
return_z_loss,
|
424 |
+
)
|
425 |
+
# TODO: investigation
|
426 |
+
# If we don't detach the _input tensor, the memory will double
|
427 |
+
# Not sure why, but it seems that at some point both the grad and the value exist, in different locations
|
428 |
+
ctx.save_for_backward(_input.detach())
|
429 |
+
ctx.return_z_loss = return_z_loss
|
430 |
+
|
431 |
+
return loss, z_loss
|
432 |
+
|
433 |
+
@staticmethod
|
434 |
+
def backward(ctx, grad_output, grad_output2):
|
435 |
+
"""
|
436 |
+
The backward pass of the Liger Cross Entropy loss.
|
437 |
+
|
438 |
+
Parameters:
|
439 |
+
ctx : The context object with saved tensors.
|
440 |
+
grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
|
441 |
+
grad_output2 (tensor): Unused; z_loss is only for logging.
|
442 |
+
Returns:
|
443 |
+
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
|
444 |
+
"""
|
445 |
+
if ctx.return_z_loss:
|
446 |
+
del grad_output2  # z_loss is only for logging
|
447 |
+
|
448 |
+
(_input,) = ctx.saved_tensors
|
449 |
+
_input = cross_entropy_backward(_input, grad_output)
|
450 |
+
return (
|
451 |
+
_input,
|
452 |
+
None,
|
453 |
+
None,
|
454 |
+
None,
|
455 |
+
None,
|
456 |
+
None,
|
457 |
+
None,
|
458 |
+
None,
|
459 |
+
None,
|
460 |
+
)
|
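For reference, a minimal usage sketch of the LigerCrossEntropyFunction defined above. It is illustrative only and not part of the diff; the import path, the CUDA device, and all shapes/hyperparameters are assumptions made up for the example.

import torch
from liger_kernels import LigerCrossEntropyFunction  # assumed import path

BT, V = 8, 32000                       # (batch * seq_len, vocab size)
logits = torch.randn(BT, V, device="cuda", requires_grad=True)
target = torch.randint(0, V, (BT,), device="cuda")
target[0] = -100                       # ignored position (default ignore_index)

# Positional args: _input, target, weight, ignore_index, lse_square_scale,
# label_smoothing, reduction, softcap, return_z_loss
loss, z_loss = LigerCrossEntropyFunction.apply(
    logits, target, None, -100, 1e-4, 0.1, "mean", None, True
)
loss.backward()                        # the gradient was already computed in the forward kernel
print(loss.item(), z_loss.item(), logits.grad.shape)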
torch-ext/liger_kernels/dyt.py
ADDED
@@ -0,0 +1,225 @@
1 |
+
import operator
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import triton
|
5 |
+
import triton.language as tl
|
6 |
+
|
7 |
+
from utils import calculate_settings
|
8 |
+
from utils import compare_version
|
9 |
+
from utils import ensure_contiguous
|
10 |
+
from utils import infer_device
|
11 |
+
|
12 |
+
if compare_version("triton", operator.ge, "3.0.0"):
|
13 |
+
try:
|
14 |
+
# typical import path with dispatch available
|
15 |
+
from triton.language.extra.libdevice import tanh
|
16 |
+
except ModuleNotFoundError:
|
17 |
+
# for working with NGC containers
|
18 |
+
from triton.language.extra.cuda.libdevice import tanh
|
19 |
+
else:
|
20 |
+
from triton.language.math import tanh
|
21 |
+
|
22 |
+
|
23 |
+
@triton.jit
|
24 |
+
def _dyt_fwd_kernel(
|
25 |
+
x_ptr,
|
26 |
+
x_row_stride,
|
27 |
+
alpha_ptr,
|
28 |
+
gamma_ptr,
|
29 |
+
beta_ptr,
|
30 |
+
y_ptr,
|
31 |
+
y_row_stride,
|
32 |
+
n_cols,
|
33 |
+
BLOCK_SIZE: tl.constexpr,
|
34 |
+
):
|
35 |
+
"""
|
36 |
+
Reference:
|
37 |
+
https://arxiv.org/abs/2503.10622
|
38 |
+
|
39 |
+
Shapes:
|
40 |
+
- x: (BT, C)
|
41 |
+
- alpha: (1)
|
42 |
+
- gamma: (C)
|
43 |
+
- beta: (C)
|
44 |
+
"""
|
45 |
+
row_idx = tl.program_id(0)
|
46 |
+
offsets = tl.arange(0, BLOCK_SIZE)
|
47 |
+
mask = offsets < n_cols
|
48 |
+
|
49 |
+
x_ptr += row_idx * x_row_stride
|
50 |
+
y_ptr += row_idx * y_row_stride
|
51 |
+
|
52 |
+
alpha = tl.load(alpha_ptr)
|
53 |
+
gamma = tl.load(gamma_ptr + offsets, mask=mask)
|
54 |
+
beta = tl.load(beta_ptr + offsets, mask=mask)
|
55 |
+
x = tl.load(x_ptr + offsets, mask=mask)
|
56 |
+
y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
|
57 |
+
tl.store(y_ptr + offsets, y, mask=mask)
|
58 |
+
|
59 |
+
|
60 |
+
@triton.jit
|
61 |
+
def _dyt_bwd_kernel(
|
62 |
+
x_ptr,
|
63 |
+
x_row_stride,
|
64 |
+
dy_ptr,
|
65 |
+
dy_row_stride,
|
66 |
+
dx_ptr,
|
67 |
+
dx_row_stride,
|
68 |
+
alpha_ptr,
|
69 |
+
dalpha_ptr,
|
70 |
+
gamma_ptr,
|
71 |
+
dgamma_ptr,
|
72 |
+
dgamma_row_stride,
|
73 |
+
n_cols,
|
74 |
+
n_rows,
|
75 |
+
ROWS_PER_PROGRAM: tl.constexpr,
|
76 |
+
BLOCK_SIZE: tl.constexpr,
|
77 |
+
):
|
78 |
+
"""
|
79 |
+
Reference:
|
80 |
+
https://arxiv.org/abs/2503.10622
|
81 |
+
|
82 |
+
Shapes:
|
83 |
+
- x: (BT, C)
|
84 |
+
- alpha: (1)
|
85 |
+
- gamma: (C)
|
86 |
+
- dx: (BT, C)
|
87 |
+
- dy: (BT, C)
|
88 |
+
- dgamma: (sm_count, C)
|
89 |
+
- dalpha: (sm_count,)
|
90 |
+
"""
|
91 |
+
# d(gamma * tanh(alpha * x) + beta) / dx
|
92 |
+
# = gamma * (1 - tanh^2(alpha * x)) * alpha
|
93 |
+
# d(gamma * tanh(alpha * x) + beta) / dalpha
|
94 |
+
# = gamma * (1 - tanh^2(alpha * x)) * x
|
95 |
+
# d(gamma * tanh(alpha * x) + beta) / dgamma
|
96 |
+
# = tanh(alpha * x)
|
97 |
+
# d(gamma * tanh(alpha * x) + beta) / dbeta = 1
|
98 |
+
pid = tl.program_id(0)
|
99 |
+
|
100 |
+
row_start = pid * ROWS_PER_PROGRAM
|
101 |
+
row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
|
102 |
+
offsets = tl.arange(0, BLOCK_SIZE)
|
103 |
+
mask = offsets < n_cols
|
104 |
+
|
105 |
+
dalpha = 0.0
|
106 |
+
dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
107 |
+
|
108 |
+
x_ptr += row_start * x_row_stride
|
109 |
+
dx_ptr += row_start * dx_row_stride
|
110 |
+
dy_ptr += row_start * dy_row_stride
|
111 |
+
alpha = tl.load(alpha_ptr)
|
112 |
+
gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
|
113 |
+
|
114 |
+
for _ in tl.range(row_start, row_end):
|
115 |
+
dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
|
116 |
+
x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
|
117 |
+
tanh_ax = tanh((alpha * x).cast(tl.float32))
|
118 |
+
sech2_ax = 1 - tanh_ax * tanh_ax
|
119 |
+
|
120 |
+
dx = dy * gamma * sech2_ax * alpha
|
121 |
+
dalpha += tl.sum(dy * gamma * sech2_ax * x)
|
122 |
+
dgamma += dy * tanh_ax
|
123 |
+
tl.store(dx_ptr + offsets, dx, mask=mask)
|
124 |
+
|
125 |
+
dy_ptr += dy_row_stride
|
126 |
+
x_ptr += x_row_stride
|
127 |
+
dx_ptr += dx_row_stride
|
128 |
+
|
129 |
+
tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
|
130 |
+
tl.store(dalpha_ptr + pid, dalpha)
|
131 |
+
|
132 |
+
pass
|
133 |
+
|
134 |
+
|
135 |
+
def liger_dyt_fwd(x, alpha, gamma, beta):
|
136 |
+
shape = x.shape
|
137 |
+
dim = shape[-1]
|
138 |
+
x = x.view(-1, dim)
|
139 |
+
n_rows, n_cols = x.shape
|
140 |
+
y = torch.empty_like(x)
|
141 |
+
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
142 |
+
_dyt_fwd_kernel[(n_rows,)](
|
143 |
+
x_ptr=x,
|
144 |
+
alpha_ptr=alpha,
|
145 |
+
gamma_ptr=gamma,
|
146 |
+
beta_ptr=beta,
|
147 |
+
y_ptr=y,
|
148 |
+
x_row_stride=x.stride(0),
|
149 |
+
y_row_stride=y.stride(0),
|
150 |
+
n_cols=n_cols,
|
151 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
152 |
+
num_warps=num_warps,
|
153 |
+
)
|
154 |
+
return y.view(*shape)
|
155 |
+
|
156 |
+
|
157 |
+
def liger_dyt_bwd(dy, x, alpha, gamma):
|
158 |
+
shape = dy.shape
|
159 |
+
dtype = x.dtype
|
160 |
+
dim = shape[-1]
|
161 |
+
dy = dy.view(-1, dim)
|
162 |
+
x = x.view(-1, dim)
|
163 |
+
n_rows, n_cols = dy.shape
|
164 |
+
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
165 |
+
sm_count = 1
|
166 |
+
device = infer_device()
|
167 |
+
if device == "cuda":
|
168 |
+
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
169 |
+
elif device == "xpu":
|
170 |
+
sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
|
171 |
+
if n_cols > BLOCK_SIZE:
|
172 |
+
raise RuntimeError(
|
173 |
+
f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
|
174 |
+
)
|
175 |
+
|
176 |
+
dx = torch.empty_like(x, dtype=torch.float32)
|
177 |
+
_dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
|
178 |
+
_dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)
|
179 |
+
|
180 |
+
grid = (sm_count,)
|
181 |
+
rows_per_program = triton.cdiv(n_rows, sm_count)
|
182 |
+
_dyt_bwd_kernel[grid](
|
183 |
+
x_ptr=x,
|
184 |
+
x_row_stride=x.stride(0),
|
185 |
+
dy_ptr=dy,
|
186 |
+
dy_row_stride=dy.stride(0),
|
187 |
+
dx_ptr=dx,
|
188 |
+
dx_row_stride=dx.stride(0),
|
189 |
+
alpha_ptr=alpha,
|
190 |
+
dalpha_ptr=_dalpha,
|
191 |
+
gamma_ptr=gamma,
|
192 |
+
dgamma_ptr=_dgamma,
|
193 |
+
dgamma_row_stride=_dgamma.stride(0),
|
194 |
+
n_cols=n_cols,
|
195 |
+
n_rows=n_rows,
|
196 |
+
ROWS_PER_PROGRAM=rows_per_program,
|
197 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
198 |
+
num_warps=num_warps,
|
199 |
+
)
|
200 |
+
dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
|
201 |
+
dgamma = _dgamma.sum(dim=0).to(dtype)
|
202 |
+
dbeta = dy.sum(dim=0).to(dtype)
|
203 |
+
return dx.view(*shape), dalpha, dgamma, dbeta
|
204 |
+
|
205 |
+
|
206 |
+
class LigerDyTFunction(torch.autograd.Function):
|
207 |
+
@staticmethod
|
208 |
+
@ensure_contiguous
|
209 |
+
def forward(ctx, x, alpha, gamma, beta):
|
210 |
+
y = liger_dyt_fwd(x, alpha, gamma, beta)
|
211 |
+
ctx.save_for_backward(x, alpha, gamma)
|
212 |
+
return y
|
213 |
+
|
214 |
+
@staticmethod
|
215 |
+
@ensure_contiguous
|
216 |
+
def backward(ctx, grad_output):
|
217 |
+
x, alpha, gamma = ctx.saved_tensors
|
218 |
+
dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
|
219 |
+
grad_output,
|
220 |
+
x,
|
221 |
+
alpha,
|
222 |
+
gamma,
|
223 |
+
)
|
224 |
+
|
225 |
+
return (dx, dalpha, dgamma, dbeta)
|
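A minimal usage sketch for the LigerDyTFunction above (illustrative only, not part of the diff; the import path, device, and shapes are assumptions):

import torch
from liger_kernels import LigerDyTFunction  # assumed import path

B, T, C = 2, 16, 1024
x = torch.randn(B, T, C, device="cuda", requires_grad=True)
alpha = torch.full((1,), 0.5, device="cuda", requires_grad=True)  # scalar scale
gamma = torch.ones(C, device="cuda", requires_grad=True)          # per-channel scale
beta = torch.zeros(C, device="cuda", requires_grad=True)          # per-channel shift

# y = gamma * tanh(alpha * x) + beta, computed by the fused Triton kernels above
y = LigerDyTFunction.apply(x, alpha, gamma, beta)
y.sum().backward()
print(y.shape, x.grad.shape, alpha.grad.shape, gamma.grad.shape, beta.grad.shape)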
torch-ext/liger_kernels/fused_linear_cross_entropy.py
ADDED
@@ -0,0 +1,283 @@
1 |
+
import torch
|
2 |
+
import triton
|
3 |
+
|
4 |
+
from cross_entropy import liger_cross_entropy_kernel
|
5 |
+
from utils import amp_custom_bwd
|
6 |
+
from utils import amp_custom_fwd
|
7 |
+
from utils import element_mul_kernel
|
8 |
+
from utils import is_hip
|
9 |
+
|
10 |
+
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
|
11 |
+
# However, setting the limit to 65536, as in the LayerNorm tutorial, is faster because of less register spilling
|
12 |
+
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
|
13 |
+
MAX_FUSED_SIZE = 65536 // 2
|
14 |
+
|
15 |
+
|
16 |
+
def fused_linear_cross_entropy_forward(
|
17 |
+
_input,
|
18 |
+
weight,
|
19 |
+
target,
|
20 |
+
ce_weight=None,
|
21 |
+
bias=None,
|
22 |
+
ignore_index=-100,
|
23 |
+
lse_square_scale=0.0,
|
24 |
+
label_smoothing=0.0,
|
25 |
+
reduction="mean",
|
26 |
+
softcap=None,
|
27 |
+
return_z_loss=False,
|
28 |
+
):
|
29 |
+
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
|
30 |
+
device = _input.device
|
31 |
+
|
32 |
+
# inputs have shape: BT x H
|
33 |
+
# materialized activations will have shape: BT x V
|
34 |
+
# the increase in memory = BT x V
|
35 |
+
# reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
|
36 |
+
# for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
|
37 |
+
# inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
|
38 |
+
# for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
|
39 |
+
BT, H = _input.shape
|
40 |
+
V = weight.shape[0]
|
41 |
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
42 |
+
|
43 |
+
inc_factor = triton.cdiv(V, H) # (V + H - 1) // H
|
44 |
+
chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
|
45 |
+
num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size
|
46 |
+
|
47 |
+
grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
|
48 |
+
grad_input = torch.zeros_like(_input, device=device)
|
49 |
+
grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
|
50 |
+
# we use fp32 for loss accumulator
|
51 |
+
loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
|
52 |
+
z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
|
53 |
+
|
54 |
+
# TODO: evaluate how CUDA synchronization caused by .item() affects the speed
|
55 |
+
target_mask = target != ignore_index
|
56 |
+
total_n_non_ignore = target_mask.sum().item()
|
57 |
+
total_sum_non_ignore_ce_weight = total_n_non_ignore
|
58 |
+
ce_weight_sum = 0.0
|
59 |
+
if ce_weight is not None:
|
60 |
+
assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
|
61 |
+
assert torch.is_floating_point(ce_weight), (
|
62 |
+
f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
|
63 |
+
)
|
64 |
+
total_sum_non_ignore_ce_weight = (
|
65 |
+
torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
|
66 |
+
)
|
67 |
+
ce_weight_sum = ce_weight.sum().item()
|
68 |
+
if ce_weight.stride(-1) != 1:
|
69 |
+
ce_weight = ce_weight.contiguous()
|
70 |
+
|
71 |
+
for chunk_id in range(num_chunks):
|
72 |
+
start_idx = chunk_id * chunk_size
|
73 |
+
end_idx = min((chunk_id + 1) * chunk_size, BT)
|
74 |
+
_input_chunk = _input[start_idx:end_idx] # chunk_size x H
|
75 |
+
|
76 |
+
# when doing matmul, use the original precision
|
77 |
+
logits_chunk = _input_chunk @ weight.t() # chunk_size x V
|
78 |
+
if bias is not None:
|
79 |
+
logits_chunk = logits_chunk + bias
|
80 |
+
|
81 |
+
target_chunk = target[start_idx:end_idx] # chunk_size,
|
82 |
+
|
83 |
+
n_rows = logits_chunk.shape[0]
|
84 |
+
|
85 |
+
# unreduced loss
|
86 |
+
loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
|
87 |
+
z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
|
88 |
+
|
89 |
+
# ensure _input and target are contiguous
|
90 |
+
logits_chunk = logits_chunk.contiguous()
|
91 |
+
target_chunk = target_chunk.contiguous()
|
92 |
+
|
93 |
+
# Here we calculate the gradient of logits_chunk in place so we can save memory.
|
94 |
+
liger_cross_entropy_kernel[(n_rows,)](
|
95 |
+
X_ptr=logits_chunk,
|
96 |
+
X_stride=logits_chunk.stride(-2),
|
97 |
+
Y_ptr=target_chunk,
|
98 |
+
Y_stride=target_chunk.stride(-1), # always 1
|
99 |
+
weight_ptr=ce_weight,
|
100 |
+
loss_ptr=loss_1d_slice,
|
101 |
+
z_loss_ptr=z_loss_1d_slice,
|
102 |
+
loss_stride=loss_1d_slice.stride(-1), # always 1
|
103 |
+
n_cols=V,
|
104 |
+
n_non_ignore=total_n_non_ignore,
|
105 |
+
sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
|
106 |
+
weight_sum=ce_weight_sum,
|
107 |
+
ignore_index=ignore_index,
|
108 |
+
lse_square_scale=lse_square_scale,
|
109 |
+
label_smoothing=label_smoothing,
|
110 |
+
reduction=reduction,
|
111 |
+
softcap=softcap,
|
112 |
+
RETURN_Z_LOSS=return_z_loss,
|
113 |
+
HAS_WEIGHT=True if ce_weight is not None else False,
|
114 |
+
HAS_SOFTCAPPING=True if softcap is not None else False,
|
115 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
116 |
+
num_warps=32 if not is_hip() else 16,
|
117 |
+
)
|
118 |
+
|
119 |
+
loss_1d[start_idx:end_idx] = loss_1d_slice
|
120 |
+
if return_z_loss:
|
121 |
+
z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
|
122 |
+
grad_logits_chunk = logits_chunk # chunk_size x V
|
123 |
+
|
124 |
+
grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
|
125 |
+
|
126 |
+
if grad_weight is not None:
|
127 |
+
torch.addmm(
|
128 |
+
input=grad_weight,
|
129 |
+
mat1=logits_chunk.t().to(
|
130 |
+
_input_chunk.dtype
|
131 |
+
), # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
|
132 |
+
mat2=_input_chunk,
|
133 |
+
out=grad_weight,
|
134 |
+
alpha=1.0,
|
135 |
+
beta=1.0,
|
136 |
+
)
|
137 |
+
|
138 |
+
if bias is not None:
|
139 |
+
torch.add(
|
140 |
+
input=grad_bias,
|
141 |
+
other=logits_chunk.sum(dim=0),
|
142 |
+
out=grad_bias,
|
143 |
+
alpha=1.0,
|
144 |
+
)
|
145 |
+
|
146 |
+
# Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now.
|
147 |
+
# if reduction == "none":
|
148 |
+
# loss = loss_1d
|
149 |
+
# z_loss = z_loss_1d if return_z_loss else None
|
150 |
+
|
151 |
+
else:
|
152 |
+
loss = torch.sum(loss_1d)
|
153 |
+
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
|
154 |
+
return loss, z_loss, grad_input, grad_weight, grad_bias
|
155 |
+
|
156 |
+
|
157 |
+
def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
|
158 |
+
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
|
159 |
+
if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
|
160 |
+
# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
|
161 |
+
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
|
162 |
+
BT, H = grad_input.shape
|
163 |
+
n_rows = BT
|
164 |
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
|
165 |
+
|
166 |
+
element_mul_kernel[(n_rows,)](
|
167 |
+
grad_input,
|
168 |
+
grad_input.stride(-2),
|
169 |
+
grad_output,
|
170 |
+
H,
|
171 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
172 |
+
num_warps=32 if not is_hip() else 16,
|
173 |
+
)
|
174 |
+
|
175 |
+
# handle grad_weight
|
176 |
+
if grad_weight is not None:
|
177 |
+
V, H = grad_weight.shape
|
178 |
+
n_rows = V
|
179 |
+
|
180 |
+
element_mul_kernel[(n_rows,)](
|
181 |
+
grad_weight,
|
182 |
+
grad_weight.stride(-2),
|
183 |
+
grad_output,
|
184 |
+
H,
|
185 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
186 |
+
num_warps=32 if not is_hip() else 16,
|
187 |
+
)
|
188 |
+
|
189 |
+
if grad_bias is not None:
|
190 |
+
V = grad_bias.shape[0]
|
191 |
+
n_rows = V
|
192 |
+
|
193 |
+
element_mul_kernel[(n_rows,)](
|
194 |
+
grad_bias,
|
195 |
+
grad_bias.stride(-1),
|
196 |
+
grad_output,
|
197 |
+
1,
|
198 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
199 |
+
num_warps=32 if not is_hip() else 16,
|
200 |
+
)
|
201 |
+
return grad_input, grad_weight, grad_bias
|
202 |
+
|
203 |
+
|
204 |
+
class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
|
205 |
+
@staticmethod
|
206 |
+
@amp_custom_fwd
|
207 |
+
def forward(
|
208 |
+
ctx,
|
209 |
+
_input,
|
210 |
+
weight,
|
211 |
+
target,
|
212 |
+
bias=None,
|
213 |
+
ce_weight=None,
|
214 |
+
ignore_index=-100,
|
215 |
+
lse_square_scale=0.0,
|
216 |
+
label_smoothing=0.0,
|
217 |
+
reduction="mean",
|
218 |
+
softcap=None,
|
219 |
+
return_z_loss: bool = False,
|
220 |
+
):
|
221 |
+
"""
|
222 |
+
Fusing the last linear layer with cross-entropy loss
|
223 |
+
Reference: https://github.com/mgmalek/efficient_cross_entropy
|
224 |
+
|
225 |
+
Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
|
226 |
+
the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
|
227 |
+
compute the gradient at the forward pass. By doing so, we don't have to store the _input and target
|
228 |
+
for the backward pass.
|
229 |
+
|
230 |
+
_input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension.
|
231 |
+
target: (B*T) where each value is in [0, V-1]
|
232 |
+
weight: (V, H) where V is the number of classes
|
233 |
+
bias: (V) where V is the number of classes
|
234 |
+
ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
|
235 |
+
ignore_index: the index to ignore in the target
|
236 |
+
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
|
237 |
+
reduction: reduction to apply
|
238 |
+
"""
|
239 |
+
|
240 |
+
loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
|
241 |
+
_input=_input,
|
242 |
+
weight=weight,
|
243 |
+
target=target,
|
244 |
+
bias=bias,
|
245 |
+
ce_weight=ce_weight,
|
246 |
+
ignore_index=ignore_index,
|
247 |
+
lse_square_scale=lse_square_scale,
|
248 |
+
label_smoothing=label_smoothing,
|
249 |
+
reduction=reduction,
|
250 |
+
softcap=softcap,
|
251 |
+
return_z_loss=return_z_loss,
|
252 |
+
)
|
253 |
+
# downcast to dtype and store for backward
|
254 |
+
ctx.save_for_backward(
|
255 |
+
grad_input.detach(),
|
256 |
+
grad_weight.detach() if grad_weight is not None else None,
|
257 |
+
grad_bias.detach() if bias is not None else None,
|
258 |
+
)
|
259 |
+
ctx.return_z_loss = return_z_loss
|
260 |
+
return loss, z_loss
|
261 |
+
|
262 |
+
@staticmethod
|
263 |
+
@amp_custom_bwd
|
264 |
+
def backward(ctx, grad_output, grad_output2):
|
265 |
+
if ctx.return_z_loss:
|
266 |
+
del grad_output2 # z_loss is only for logging
|
267 |
+
(grad_input, grad_weight, grad_bias) = ctx.saved_tensors
|
268 |
+
grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
|
269 |
+
grad_output, grad_input, grad_weight, grad_bias
|
270 |
+
)
|
271 |
+
return (
|
272 |
+
grad_input,
|
273 |
+
grad_weight,
|
274 |
+
None,
|
275 |
+
grad_bias,
|
276 |
+
None,
|
277 |
+
None,
|
278 |
+
None,
|
279 |
+
None,
|
280 |
+
None,
|
281 |
+
None,
|
282 |
+
None,
|
283 |
+
)
|
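A minimal usage sketch for the LigerFusedLinearCrossEntropyFunction above (illustrative only, not part of the diff; the import path, device, and shapes are assumptions):

import torch
from liger_kernels import LigerFusedLinearCrossEntropyFunction  # assumed import path

BT, H, V = 64, 512, 32000
hidden = torch.randn(BT, H, device="cuda", requires_grad=True)          # last hidden states
lm_head_weight = torch.randn(V, H, device="cuda", requires_grad=True)   # lm_head weight
target = torch.randint(0, V, (BT,), device="cuda")

# Fuses hidden @ lm_head_weight.T with cross-entropy so the full (BT, V)
# logits tensor is only materialized chunk by chunk.
loss, _ = LigerFusedLinearCrossEntropyFunction.apply(hidden, lm_head_weight, target)
loss.backward()
print(loss.item(), hidden.grad.shape, lm_head_weight.grad.shape)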
torch-ext/liger_kernels/geglu.py
ADDED
@@ -0,0 +1,141 @@
1 |
+
import operator
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import triton
|
5 |
+
import triton.language as tl
|
6 |
+
|
7 |
+
from utils import calculate_settings
|
8 |
+
from utils import compare_version
|
9 |
+
from utils import ensure_contiguous
|
10 |
+
|
11 |
+
if compare_version("triton", operator.ge, "3.0.0"):
|
12 |
+
try:
|
13 |
+
# typical import path with dispatch available
|
14 |
+
from triton.language.extra.libdevice import tanh
|
15 |
+
except ModuleNotFoundError:
|
16 |
+
# for working with NGC containers
|
17 |
+
from triton.language.extra.cuda.libdevice import tanh
|
18 |
+
else:
|
19 |
+
from triton.language.math import tanh
|
20 |
+
|
21 |
+
|
22 |
+
@triton.jit
|
23 |
+
def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
|
24 |
+
program_id = tl.program_id(0).to(tl.int64)
|
25 |
+
|
26 |
+
# locate start index
|
27 |
+
a += program_id * stride
|
28 |
+
b += program_id * stride
|
29 |
+
c += program_id * stride
|
30 |
+
|
31 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
32 |
+
mask = col_offsets < n_cols
|
33 |
+
a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
|
34 |
+
b_row = tl.load(b + col_offsets, mask=mask, other=0)
|
35 |
+
|
36 |
+
# tanh approximation form of GELU is computed with:
|
37 |
+
# 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
|
38 |
+
sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
|
39 |
+
a_cubed = a_row * a_row * a_row
|
40 |
+
tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
|
41 |
+
tanh_result = tanh(tanh_arg)
|
42 |
+
geglu_a = 0.5 * a_row * (1 + tanh_result)
|
43 |
+
c_row = geglu_a * b_row
|
44 |
+
tl.store(c + col_offsets, c_row, mask=mask)
|
45 |
+
|
46 |
+
|
47 |
+
@triton.jit
|
48 |
+
def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
|
49 |
+
program_id = tl.program_id(0).to(tl.int64)
|
50 |
+
|
51 |
+
# locate start index
|
52 |
+
dc += program_id * stride
|
53 |
+
a += program_id * stride
|
54 |
+
b += program_id * stride
|
55 |
+
|
56 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
57 |
+
mask = col_offsets < n_cols
|
58 |
+
|
59 |
+
dc_row = tl.load(dc + col_offsets, mask=mask, other=0)
|
60 |
+
a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
|
61 |
+
b_row = tl.load(b + col_offsets, mask=mask, other=0)
|
62 |
+
|
63 |
+
# recomputation to save memory
|
64 |
+
sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
|
65 |
+
a_cubed = a_row * a_row * a_row
|
66 |
+
tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
|
67 |
+
tanh_result = tanh(tanh_arg)
|
68 |
+
geglu_a = 0.5 * a_row * (1 + tanh_result)
|
69 |
+
|
70 |
+
db_row = dc_row * geglu_a
|
71 |
+
|
72 |
+
# Gradient w.r.t. a can be computed with:
|
73 |
+
# b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
|
74 |
+
# where z = sqrt(2/pi) * (a + 0.044715 * a^3)
|
75 |
+
term1 = 0.5 * (1 + tanh_result)
|
76 |
+
tanh_sq = tanh_result * tanh_result
|
77 |
+
term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
|
78 |
+
da_row = dc_row * b_row * (term1 + term2)
|
79 |
+
|
80 |
+
tl.store(a + col_offsets, da_row, mask=mask)
|
81 |
+
tl.store(b + col_offsets, db_row, mask=mask)
|
82 |
+
|
83 |
+
|
84 |
+
def geglu_forward(a, b):
|
85 |
+
ori_shape = a.shape
|
86 |
+
|
87 |
+
n_cols = ori_shape[-1]
|
88 |
+
a = a.view(-1, n_cols)
|
89 |
+
b = b.view(-1, n_cols)
|
90 |
+
c = torch.empty_like(a)
|
91 |
+
n_rows = a.shape[0]
|
92 |
+
|
93 |
+
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
94 |
+
|
95 |
+
_geglu_tanh_forward_kernel[(n_rows,)](
|
96 |
+
a,
|
97 |
+
b,
|
98 |
+
c,
|
99 |
+
c.stride(-2),
|
100 |
+
n_cols=n_cols,
|
101 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
102 |
+
num_warps=num_warps,
|
103 |
+
)
|
104 |
+
return a, b, c.view(*ori_shape)
|
105 |
+
|
106 |
+
|
107 |
+
def geglu_backward(a, b, dc):
|
108 |
+
ori_shape = dc.shape
|
109 |
+
n_cols = ori_shape[-1]
|
110 |
+
dc = dc.view(-1, n_cols)
|
111 |
+
n_rows = dc.shape[0]
|
112 |
+
|
113 |
+
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
114 |
+
|
115 |
+
_geglu_tanh_backward_kernel[(n_rows,)](
|
116 |
+
dc,
|
117 |
+
a,
|
118 |
+
b,
|
119 |
+
dc.stride(-2),
|
120 |
+
n_cols=n_cols,
|
121 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
122 |
+
num_warps=num_warps,
|
123 |
+
)
|
124 |
+
|
125 |
+
return a.view(*ori_shape), b.view(*ori_shape)
|
126 |
+
|
127 |
+
|
128 |
+
class LigerGELUMulFunction(torch.autograd.Function):
|
129 |
+
@staticmethod
|
130 |
+
@ensure_contiguous
|
131 |
+
def forward(ctx, a, b):
|
132 |
+
a, b, c = geglu_forward(a, b)
|
133 |
+
ctx.save_for_backward(a, b)
|
134 |
+
return c
|
135 |
+
|
136 |
+
@staticmethod
|
137 |
+
@ensure_contiguous
|
138 |
+
def backward(ctx, dc):
|
139 |
+
a, b = ctx.saved_tensors
|
140 |
+
a, b = geglu_backward(a, b, dc)
|
141 |
+
return a, b
|
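A minimal usage sketch for the LigerGELUMulFunction above (illustrative only, not part of the diff; the import path, device, and shapes are assumptions):

import torch
from liger_kernels import LigerGELUMulFunction  # assumed import path

B, T, D = 2, 16, 4096
gate = torch.randn(B, T, D, device="cuda", requires_grad=True)  # gate projection output
up = torch.randn(B, T, D, device="cuda", requires_grad=True)    # up projection output

# out = GELU_tanh(gate) * up, as used in GEGLU MLP blocks
out = LigerGELUMulFunction.apply(gate, up)
out.sum().backward()
print(out.shape, gate.grad.shape, up.grad.shape)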
torch-ext/liger_kernels/group_norm.py
ADDED
@@ -0,0 +1,305 @@
1 |
+
import operator
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import triton
|
5 |
+
import triton.language as tl
|
6 |
+
|
7 |
+
from utils import compare_version
|
8 |
+
from utils import ensure_contiguous
|
9 |
+
|
10 |
+
if compare_version("triton", operator.ge, "3.0.0"):
|
11 |
+
try:
|
12 |
+
# typical import path with dispatch available
|
13 |
+
from triton.language.extra.libdevice import rsqrt
|
14 |
+
except ModuleNotFoundError:
|
15 |
+
# for working with NGC containers
|
16 |
+
from triton.language.extra.cuda.libdevice import rsqrt
|
17 |
+
else:
|
18 |
+
from triton.language.math import rsqrt
|
19 |
+
|
20 |
+
MAX_FUSED_SIZE = 65536
|
21 |
+
|
22 |
+
|
23 |
+
@triton.jit
|
24 |
+
def _group_norm_forward_kernel(
|
25 |
+
Y_ptr, # pointer to output, shape (n_rows, n_groups, hidden_size)
|
26 |
+
Y_row_stride, # stride of each row in output
|
27 |
+
Y_col_stride, # stride of each column in output
|
28 |
+
X_ptr, # pointer to input, shape (n_rows, n_groups, hidden_size)
|
29 |
+
X_row_stride, # stride of each row in input
|
30 |
+
X_col_stride, # stride of each column in input
|
31 |
+
Mean_ptr, # pointer to mean, shape (n_rows, n_groups)
|
32 |
+
Mean_row_stride, # stride of each row in mean
|
33 |
+
Mean_col_stride, # stride of each column in mean
|
34 |
+
RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups)
|
35 |
+
RSTD_row_stride, # stride of each row in rstd
|
36 |
+
RSTD_col_stride, # stride of each column in rstd
|
37 |
+
W_ptr, # pointer to W
|
38 |
+
B_ptr, # pointer to B
|
39 |
+
hidden_size, # hidden size of X
|
40 |
+
channels_per_group, # the number of channels per group
|
41 |
+
eps,
|
42 |
+
BLOCK_SIZE: tl.constexpr,
|
43 |
+
):
|
44 |
+
"""
|
45 |
+
References:
|
46 |
+
https://nn.labml.ai/normalization/group_norm/index.html
|
47 |
+
"""
|
48 |
+
batch_idx = tl.program_id(0)
|
49 |
+
group_idx = tl.program_id(1)
|
50 |
+
|
51 |
+
X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride
|
52 |
+
Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride
|
53 |
+
|
54 |
+
block_range = tl.arange(0, BLOCK_SIZE)
|
55 |
+
|
56 |
+
# Compute mean and variance using the online algorithm
|
57 |
+
s = 0.0
|
58 |
+
squared_sum = 0.0
|
59 |
+
for i in tl.range(0, hidden_size, BLOCK_SIZE):
|
60 |
+
hidden_size_offsets = i + block_range
|
61 |
+
mask = hidden_size_offsets < hidden_size
|
62 |
+
X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0)
|
63 |
+
s += tl.sum(X)
|
64 |
+
# X**2
|
65 |
+
squared_sum += tl.sum(X * X)
|
66 |
+
|
67 |
+
m = s / hidden_size
|
68 |
+
|
69 |
+
# variance = E[X**2] - E[X]**2
|
70 |
+
variance = (squared_sum / hidden_size) - (m * m)
|
71 |
+
|
72 |
+
# 1/std
|
73 |
+
rstd = rsqrt(variance + eps)
|
74 |
+
|
75 |
+
# Normalize
|
76 |
+
hidden_size_per_channel = hidden_size // channels_per_group
|
77 |
+
for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
|
78 |
+
W = tl.load(W_ptr + channel_idx)
|
79 |
+
B = tl.load(B_ptr + channel_idx)
|
80 |
+
for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
|
81 |
+
hidden_size_offsets = i + block_range
|
82 |
+
mask = hidden_size_offsets < hidden_size_per_channel
|
83 |
+
X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
|
84 |
+
Y = (X - m) * rstd * W + B
|
85 |
+
tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
|
86 |
+
|
87 |
+
X_ptr += hidden_size_per_channel
|
88 |
+
Y_ptr += hidden_size_per_channel
|
89 |
+
|
90 |
+
tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
|
91 |
+
tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
|
92 |
+
|
93 |
+
|
94 |
+
@triton.jit
|
95 |
+
def _group_norm_backward_kernel(
|
96 |
+
X_ptr, # pointer to input, shape (n_rows, n_channels, hidden_size)
|
97 |
+
X_row_stride, # stride of each row in input
|
98 |
+
X_col_stride, # stride of each column in input
|
99 |
+
W_ptr, # pointer to weights, shape (n_channels)
|
100 |
+
Mean_ptr, # pointer to mean, shape (n_rows, n_groups)
|
101 |
+
Mean_ptr_row_stride, # stride of each row in mean
|
102 |
+
Mean_ptr_col_stride, # stride of each column in mean
|
103 |
+
RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups)
|
104 |
+
DX_ptr, # pointer to input grad, shape (n_rows, n_groups, hidden_size)
|
105 |
+
DW_ptr, # pointer to weights grad, shape (n_channels)
|
106 |
+
DB_ptr, # pointer to bias grad, shape (n_channels)
|
107 |
+
UPSTREAM_ptr, # pointer to output grad, shape (n_rows, n_channels, hidden_size)
|
108 |
+
hidden_size: tl.constexpr, # hidden size
|
109 |
+
channels_per_group: tl.constexpr, # number of channels per group
|
110 |
+
BLOCK_SIZE: tl.constexpr,
|
111 |
+
dtype: tl.constexpr,
|
112 |
+
):
|
113 |
+
"""
|
114 |
+
References:
|
115 |
+
https://nn.labml.ai/normalization/group_norm/index.html
|
116 |
+
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
117 |
+
|
118 |
+
The backprop equations are the same for group_norm and layer_norm
|
119 |
+
the only difference here is that we load the Mean, Rstd corresponding to the
|
120 |
+
group we're computing gradients for and the mean and rstd are computed over n-channels
|
121 |
+
so the total number of elements we compute the mean over is num_channels_per_group * hidden_size
|
122 |
+
|
123 |
+
We also need to load the Weights corresponding to the current channel to compute the gradients.
|
124 |
+
"""
|
125 |
+
batch_idx = tl.program_id(0)
|
126 |
+
group_idx = tl.program_id(1)
|
127 |
+
|
128 |
+
# Move the pointers to the correct batch
|
129 |
+
X_ptr += batch_idx * X_row_stride
|
130 |
+
DX_ptr += batch_idx * X_row_stride
|
131 |
+
UPSTREAM_ptr += batch_idx * X_row_stride
|
132 |
+
|
133 |
+
# Mean and rstd are the same shape so have the same strides
|
134 |
+
mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
|
135 |
+
rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
|
136 |
+
|
137 |
+
c1 = 0.0
|
138 |
+
c2 = 0.0
|
139 |
+
block_range = tl.arange(0, BLOCK_SIZE)
|
140 |
+
|
141 |
+
# We need to compute the sum terms of the backprop equations across all channels in the group
|
142 |
+
for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
|
143 |
+
dW = 0.0
|
144 |
+
dB = 0.0
|
145 |
+
# Move the pointers to the correct channel
|
146 |
+
W = tl.load(W_ptr + channel_idx)
|
147 |
+
for i in tl.range(0, hidden_size, BLOCK_SIZE):
|
148 |
+
hidden_size_offsets = i + block_range
|
149 |
+
mask = hidden_size_offsets < hidden_size
|
150 |
+
X = tl.load(
|
151 |
+
X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
|
152 |
+
mask=mask,
|
153 |
+
other=0.0,
|
154 |
+
)
|
155 |
+
UPSTREAM_grad = tl.load(
|
156 |
+
UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
|
157 |
+
mask=mask,
|
158 |
+
other=0.0,
|
159 |
+
)
|
160 |
+
|
161 |
+
x_hat = (X - mean) * rstd
|
162 |
+
dW += tl.sum(UPSTREAM_grad * x_hat)
|
163 |
+
dB += tl.sum(UPSTREAM_grad)
|
164 |
+
|
165 |
+
wdy = W * UPSTREAM_grad
|
166 |
+
c1 += tl.sum(x_hat * wdy)
|
167 |
+
c2 += tl.sum(wdy)
|
168 |
+
|
169 |
+
# Need to ensure additions to the same channel are atomic
|
170 |
+
tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype))
|
171 |
+
tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype))
|
172 |
+
|
173 |
+
N = hidden_size * channels_per_group
|
174 |
+
c1 = c1 / N
|
175 |
+
c2 = c2 / N
|
176 |
+
|
177 |
+
for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
|
178 |
+
# Move the pointers to the correct channel
|
179 |
+
W = tl.load(W_ptr + channel_idx)
|
180 |
+
for i in range(0, hidden_size, BLOCK_SIZE):
|
181 |
+
hidden_size_offsets = i + block_range
|
182 |
+
mask = hidden_size_offsets < hidden_size
|
183 |
+
X = tl.load(
|
184 |
+
X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
|
185 |
+
mask=mask,
|
186 |
+
other=0.0,
|
187 |
+
)
|
188 |
+
UPSTREAM_grad = tl.load(
|
189 |
+
UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
|
190 |
+
mask=mask,
|
191 |
+
other=0.0,
|
192 |
+
)
|
193 |
+
|
194 |
+
x_hat = (X - mean) * rstd
|
195 |
+
wdy = W * UPSTREAM_grad
|
196 |
+
dx = (wdy - (x_hat * c1 + c2)) * rstd
|
197 |
+
tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)
|
198 |
+
|
199 |
+
|
200 |
+
def group_norm_forward(X, num_channels, num_groups, W, B, eps):
|
201 |
+
shape = X.shape
|
202 |
+
batch_size = shape[0]
|
203 |
+
channels_per_group = num_channels // num_groups
|
204 |
+
# Reshape X so that the mean and std are computed across the groups
|
205 |
+
X = X.view(batch_size, num_groups, -1).contiguous()
|
206 |
+
hidden_size = X.shape[-1]
|
207 |
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
|
208 |
+
Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
|
209 |
+
Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
|
210 |
+
RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
|
211 |
+
|
212 |
+
_group_norm_forward_kernel[(batch_size, num_groups)](
|
213 |
+
Y,
|
214 |
+
Y.stride(0),
|
215 |
+
Y.stride(1),
|
216 |
+
X,
|
217 |
+
X.stride(0),
|
218 |
+
X.stride(1),
|
219 |
+
Mean,
|
220 |
+
Mean.stride(0),
|
221 |
+
Mean.stride(1),
|
222 |
+
RSTD,
|
223 |
+
RSTD.stride(0),
|
224 |
+
RSTD.stride(1),
|
225 |
+
W,
|
226 |
+
B,
|
227 |
+
hidden_size,
|
228 |
+
channels_per_group,
|
229 |
+
eps,
|
230 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
231 |
+
)
|
232 |
+
# Return tensors in the original shape
|
233 |
+
return Y.view(*shape), X.view(*shape), Mean, RSTD, BLOCK_SIZE
|
234 |
+
|
235 |
+
|
236 |
+
def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
|
237 |
+
shape = dY.shape
|
238 |
+
batch_size = shape[0]
|
239 |
+
hidden_size = dY.shape[-1]
|
240 |
+
channels_per_group = num_channels // num_groups
|
241 |
+
dY = dY.view(batch_size, num_groups, -1)
|
242 |
+
DX = torch.empty(
|
243 |
+
(batch_size, num_groups, hidden_size * channels_per_group),
|
244 |
+
dtype=X.dtype,
|
245 |
+
device=X.device,
|
246 |
+
)
|
247 |
+
DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device)
|
248 |
+
DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device)
|
249 |
+
triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
|
250 |
+
|
251 |
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
|
252 |
+
_group_norm_backward_kernel[(batch_size, num_groups)](
|
253 |
+
X,
|
254 |
+
X.stride(0),
|
255 |
+
X.stride(1),
|
256 |
+
W,
|
257 |
+
Mean,
|
258 |
+
Mean.stride(0),
|
259 |
+
Mean.stride(1),
|
260 |
+
RSTD,
|
261 |
+
DX,
|
262 |
+
DW,
|
263 |
+
DB,
|
264 |
+
dY,
|
265 |
+
hidden_size,
|
266 |
+
channels_per_group,
|
267 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
268 |
+
dtype=triton_dtype,
|
269 |
+
)
|
270 |
+
|
271 |
+
# Return tensors in the original shape
|
272 |
+
return DX.view(*shape), DW, DB
|
273 |
+
|
274 |
+
|
275 |
+
class LigerGroupNormFunction(torch.autograd.Function):
|
276 |
+
@staticmethod
|
277 |
+
@ensure_contiguous
|
278 |
+
def forward(
|
279 |
+
ctx,
|
280 |
+
X,
|
281 |
+
affine_scaling_weight,
|
282 |
+
affine_shifting_bias,
|
283 |
+
num_channels,
|
284 |
+
num_groups,
|
285 |
+
eps,
|
286 |
+
):
|
287 |
+
Y, X, Mean, RSTD, BLOCK_SIZE = group_norm_forward(
|
288 |
+
X,
|
289 |
+
num_channels,
|
290 |
+
num_groups,
|
291 |
+
affine_scaling_weight,
|
292 |
+
affine_shifting_bias,
|
293 |
+
eps,
|
294 |
+
)
|
295 |
+
ctx.num_channels = num_channels
|
296 |
+
ctx.num_groups = num_groups
|
297 |
+
ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
|
298 |
+
return Y
|
299 |
+
|
300 |
+
@staticmethod
|
301 |
+
@ensure_contiguous
|
302 |
+
def backward(ctx, dY):
|
303 |
+
X, W, B, Mean, RSTD = ctx.saved_tensors
|
304 |
+
DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
|
305 |
+
return DX, DW, DB, None, None, None
|
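A minimal usage sketch for the LigerGroupNormFunction above (illustrative only, not part of the diff; the import path, device, and shapes are assumptions):

import torch
from liger_kernels import LigerGroupNormFunction  # assumed import path

N, C, HW = 4, 32, 256            # batch, channels, flattened spatial size
num_groups = 8
x = torch.randn(N, C, HW, device="cuda", requires_grad=True)
weight = torch.ones(C, device="cuda", requires_grad=True)   # per-channel scale
bias = torch.zeros(C, device="cuda", requires_grad=True)    # per-channel shift

# Args: X, affine_scaling_weight, affine_shifting_bias, num_channels, num_groups, eps
y = LigerGroupNormFunction.apply(x, weight, bias, C, num_groups, 1e-6)
y.sum().backward()
print(y.shape, x.grad.shape, weight.grad.shape, bias.grad.shape)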
torch-ext/liger_kernels/jsd.py
ADDED
@@ -0,0 +1,201 @@
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import triton
|
5 |
+
import triton.language as tl
|
6 |
+
|
7 |
+
from utils import ensure_contiguous
|
8 |
+
from utils import infer_device
|
9 |
+
|
10 |
+
|
11 |
+
@triton.jit
|
12 |
+
def _jsd_kernel(
|
13 |
+
X_ptr, # input in logspace, X = log Q
|
14 |
+
X_stride,
|
15 |
+
Y_ptr, # ground truth in logspace, Y = log P
|
16 |
+
Y_stride,
|
17 |
+
loss_ptr,
|
18 |
+
loss_stride,
|
19 |
+
dX_ptr,
|
20 |
+
dX_stride,
|
21 |
+
label_ptr,
|
22 |
+
beta: tl.constexpr,
|
23 |
+
n_non_ignore: int,
|
24 |
+
ignore_index: tl.constexpr,
|
25 |
+
n_cols,
|
26 |
+
BLOCK_SIZE: tl.constexpr,
|
27 |
+
HAS_LABEL: tl.constexpr,
|
28 |
+
):
|
29 |
+
# JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X)
|
30 |
+
# = sum(P * log P + Q * log Q - 2 * M * log M) / 2
|
31 |
+
# = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2
|
32 |
+
# grad_x_i = 0.5 * Q * (X - log_M)
|
33 |
+
pid = tl.program_id(0).to(tl.int64)
|
34 |
+
X_ptr += pid * X_stride
|
35 |
+
dX_ptr += pid * dX_stride
|
36 |
+
Y_ptr += pid * Y_stride
|
37 |
+
loss_ptr += pid * loss_stride
|
38 |
+
label_ptr += pid
|
39 |
+
|
40 |
+
if HAS_LABEL:
|
41 |
+
label = tl.load(label_ptr)
|
42 |
+
if label == ignore_index:
|
43 |
+
for i in range(0, n_cols, BLOCK_SIZE):
|
44 |
+
offsets = i + tl.arange(0, BLOCK_SIZE)
|
45 |
+
tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols)
|
46 |
+
return
|
47 |
+
|
48 |
+
for i in range(0, n_cols, BLOCK_SIZE):
|
49 |
+
offsets = i + tl.arange(0, BLOCK_SIZE)
|
50 |
+
mask = offsets < n_cols
|
51 |
+
X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
|
52 |
+
Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
|
53 |
+
|
54 |
+
if beta == 0.0: # forward KL
|
55 |
+
Y_max = tl.max(Y, axis=0)
|
56 |
+
Y_shifted = Y - Y_max
|
57 |
+
Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max) # Compensate for the shift
|
58 |
+
loss = Y_prob * (Y - X)
|
59 |
+
dX = -Y_prob
|
60 |
+
elif beta == 1.0: # reverse KL
|
61 |
+
X_max = tl.max(X, axis=0)
|
62 |
+
X_shifted = X - X_max
|
63 |
+
X_prob = tl.exp(X_shifted) * tl.exp(X_max) # Compensate for the shift
|
64 |
+
loss = X_prob * (X - Y)
|
65 |
+
dX = loss + X_prob
|
66 |
+
else:
|
67 |
+
max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
|
68 |
+
X_shifted = X - max_val
|
69 |
+
Y_shifted = Y - max_val
|
70 |
+
|
71 |
+
# Pre-compute exp(max_val) since it's used twice
|
72 |
+
exp_max = tl.exp(max_val)
|
73 |
+
|
74 |
+
# Compute exp terms with compensation
|
75 |
+
Q = tl.exp(X_shifted) * exp_max # = exp(X)
|
76 |
+
P = tl.exp(Y_shifted) * exp_max # = exp(Y)
|
77 |
+
|
78 |
+
# Pre-compute common terms
|
79 |
+
beta_P = beta * P
|
80 |
+
one_minus_beta_Q = (1 - beta) * Q
|
81 |
+
M = beta_P + one_minus_beta_Q
|
82 |
+
log_M = tl.log(M) # No need to compensate as M is already in original scale
|
83 |
+
|
84 |
+
loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
|
85 |
+
dX = one_minus_beta_Q * (X - log_M)
|
86 |
+
|
87 |
+
# Pre-compute scaling factor
|
88 |
+
scale = 1.0 / n_non_ignore
|
89 |
+
loss = loss * scale
|
90 |
+
dX = dX * scale
|
91 |
+
|
92 |
+
tl.store(loss_ptr + offsets, loss, mask=mask)
|
93 |
+
tl.store(dX_ptr + offsets, dX, mask=mask)
|
94 |
+
|
95 |
+
|
96 |
+
MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536
|
97 |
+
|
98 |
+
|
99 |
+
def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
|
100 |
+
BT, V = _input.shape
|
101 |
+
n_rows = BT
|
102 |
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
103 |
+
# non reduction loss
|
104 |
+
loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
|
105 |
+
dX = torch.empty_like(_input)
|
106 |
+
|
107 |
+
if has_label:
|
108 |
+
n_non_ignore = (shift_labels != ignore_index).sum().item()
|
109 |
+
else:
|
110 |
+
n_non_ignore = BT
|
111 |
+
|
112 |
+
_jsd_kernel[(n_rows,)](
|
113 |
+
X_ptr=_input, # input in logspace, X = log Q
|
114 |
+
X_stride=_input.stride(-2),
|
115 |
+
Y_ptr=target, # ground truth in logspace, Y = log P
|
116 |
+
Y_stride=target.stride(-2),
|
117 |
+
loss_ptr=loss,
|
118 |
+
loss_stride=loss.stride(-2),
|
119 |
+
dX_ptr=dX,
|
120 |
+
dX_stride=dX.stride(-2),
|
121 |
+
label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)), # dummy ptr if no label
|
122 |
+
beta=beta,
|
123 |
+
n_non_ignore=n_non_ignore,
|
124 |
+
ignore_index=ignore_index,
|
125 |
+
n_cols=V,
|
126 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
127 |
+
HAS_LABEL=has_label,
|
128 |
+
)
|
129 |
+
|
130 |
+
loss = torch.sum(loss)
|
131 |
+
return loss.to(_input.dtype), dX
|
132 |
+
|
133 |
+
|
134 |
+
def jsd_backward(dX, grad_output):
|
135 |
+
# If jsd is the last layer, grad_output is 1.0. Skip the mul to save time
|
136 |
+
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
|
137 |
+
return dX
|
138 |
+
else:
|
139 |
+
return grad_output * dX
|
140 |
+
|
141 |
+
|
142 |
+
class LigerJSDFunction(torch.autograd.Function):
|
143 |
+
r"""
|
144 |
+
This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.
|
145 |
+
.. math::
|
146 |
+
JSD(\beta)(P || Q)
|
147 |
+
= \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
|
148 |
+
|
149 |
+
.. note::
|
150 |
+
As all the other losses in PyTorch, this function expects the first argument,
|
151 |
+
:attr:`_input`, to be the predictions, the output of the student model, in log-space
|
152 |
+
and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space.
|
153 |
+
This differs from the standard mathematical notation :math:`JSD(P || Q)` where
|
154 |
+
:math:`P` denotes the teacher model and :math:`Q` denotes the student model.
|
155 |
+
"""
|
156 |
+
|
157 |
+
@staticmethod
|
158 |
+
@ensure_contiguous
|
159 |
+
def forward(
|
160 |
+
ctx,
|
161 |
+
_input: torch.Tensor,
|
162 |
+
target: torch.Tensor,
|
163 |
+
shift_labels: Optional[torch.Tensor] = None,
|
164 |
+
beta: float = 0.5,
|
165 |
+
ignore_index: int = -100,
|
166 |
+
) -> torch.Tensor:
|
167 |
+
"""
|
168 |
+
Args:
|
169 |
+
_input (torch.Tensor): predict values with shape (BT, V) in logspace
|
170 |
+
target (torch.Tensor): ground truth values with shape (BT, V) in logspace
|
171 |
+
shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
|
172 |
+
beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
|
173 |
+
ignore_index (int): the index to ignore. Default: -100
|
174 |
+
|
175 |
+
Returns:
|
176 |
+
loss (torch.Tensor): generalized JSD
|
177 |
+
"""
|
178 |
+
has_label = False
|
179 |
+
if shift_labels is not None:
|
180 |
+
assert shift_labels.shape == (_input.shape[0],), (
|
181 |
+
f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
|
182 |
+
)
|
183 |
+
shift_labels = shift_labels.contiguous()
|
184 |
+
has_label = True
|
185 |
+
|
186 |
+
loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
|
187 |
+
ctx.save_for_backward(dX)
|
188 |
+
return loss
|
189 |
+
|
190 |
+
@staticmethod
|
191 |
+
@ensure_contiguous
|
192 |
+
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
|
193 |
+
(dX,) = ctx.saved_tensors
|
194 |
+
dX = jsd_backward(dX, grad_output)
|
195 |
+
return (
|
196 |
+
dX,
|
197 |
+
None,
|
198 |
+
None,
|
199 |
+
None,
|
200 |
+
None,
|
201 |
+
)
|
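A minimal usage sketch for the LigerJSDFunction above (illustrative only, not part of the diff; the import path, device, and shapes are assumptions). Both arguments are expected in log-space, so log_softmax is applied first:

import torch
from liger_kernels import LigerJSDFunction  # assumed import path

BT, V = 32, 32000
student_logits = torch.randn(BT, V, device="cuda", requires_grad=True)
teacher_logits = torch.randn(BT, V, device="cuda")

log_q = torch.log_softmax(student_logits, dim=-1)  # X = log Q (student)
log_p = torch.log_softmax(teacher_logits, dim=-1)  # Y = log P (teacher)

loss = LigerJSDFunction.apply(log_q, log_p)        # beta defaults to 0.5
loss.backward()
print(loss.item(), student_logits.grad.shape)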
torch-ext/liger_kernels/kl_div.py
ADDED
@@ -0,0 +1,262 @@
from typing import Literal

import torch
import triton
import triton.language as tl

from utils import ensure_contiguous
from utils import is_hip
from utils import infer_device


def get_num_warps(BLOCK_SIZE):
    num_warps = 4
    if BLOCK_SIZE >= 32768:
        num_warps = 32 if not is_hip() else 16
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8

    return num_warps


MAX_FUSED_SIZE = 65536 // 4  # 65536 // 4 or 8 works the best

REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]

_REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0)
_REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1)
_REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2)
_REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3)

_str_to_reduction_mode = {
    "none": _REDUCTION_MODE_NONE.value,
    "sum": _REDUCTION_MODE_SUM.value,
    "mean": _REDUCTION_MODE_MEAN.value,
    "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
}


@triton.jit
def _kldiv_kernel_forward(
    y_ptr,  # [B, S], prediction ptr, the kernel expects the prediction in log-space
    y_stride,  # int, prediction stride
    gt_ptr,  # [B, S], ground truth ptr
    gt_stride,  # int, ground truth stride
    loss_ptr,  # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr
    loss_stride,  # int, output stride
    n_cols,  # int, number of columns in the input tensor
    eps,
    BLOCK_SIZE: tl.constexpr,
    log_target: tl.constexpr = False,
    reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
):
    pid = tl.program_id(0).to(tl.int64)
    y_ptr += pid * y_stride
    gt_ptr += pid * gt_stride
    loss_ptr += pid * loss_stride

    base_offsets = tl.arange(0, BLOCK_SIZE)

    loss_sum = 0.0
    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + base_offsets
        mask = offsets < n_cols
        y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
        y_true = tl.load(gt_ptr + offsets, mask=mask, other=0.0)

        # KL(y_true || y) = y_true * (log(y_true) - log(y))
        # We compute KL(y_true || y) with y in the log-space
        if not log_target:
            loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y)
        else:
            loss = tl.exp(y_true) * (y_true - y)

        if reduction == _REDUCTION_MODE_NONE:
            tl.store(loss_ptr + offsets, loss, mask=mask)
        else:
            loss_sum += tl.sum(loss, axis=0)

    if reduction != _REDUCTION_MODE_NONE:
        tl.store(loss_ptr, loss_sum)


@triton.jit
def _kldiv_kernel_backward(
    target_ptr,
    target_stride,
    new_grads_ptr,
    new_grads_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    log_target: tl.constexpr = False,
):
    pid = tl.program_id(0).to(tl.int64)

    target_ptr += pid * target_stride
    new_grads_ptr += pid * new_grads_stride

    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols

    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_cols

        target = tl.load(target_ptr + offsets, mask=mask, other=0.0)

        if not log_target:
            res = target * -1
        else:
            res = -tl.exp(target)

        tl.store(new_grads_ptr + offsets, res, mask=mask)


def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]
    BT, V = y_pred.shape
    BLOCK_SIZE = (
        min(8192, triton.next_power_of_2(V))
        if infer_device() == "xpu"
        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    )
    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)

    grid = (BT,)
    reduction = _str_to_reduction_mode[reduction]

    out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
    output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32)

    _kldiv_kernel_forward[grid](
        y_pred,
        y_pred.stride(0),
        y_true,
        y_true.stride(0),
        output_tensor,
        output_tensor.stride(0),
        V,
        eps=eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        log_target=log_target,
        reduction=reduction,
    )

    # calculated according to the reduction mode same as in Pytorch. In the later versions, `mean` will be changed to the same behavior as `batchmean`
    # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
    # https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372
    if reduction == _REDUCTION_MODE_BATCHMEAN.value:
        return output_tensor.sum() / BT
    elif reduction == _REDUCTION_MODE_SUM.value:
        return output_tensor.sum(dim=0)
    elif reduction == _REDUCTION_MODE_MEAN.value:
        return output_tensor.sum() / (BT * V)
    else:
        return output_tensor


def kldiv_backward_triton(target, grad_output, new_grads, log_target):
    BT, V = target.shape
    BLOCK_SIZE = (
        min(8192, triton.next_power_of_2(V))
        if infer_device() == "xpu"
        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    )
    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)

    grid = (BT,)

    # We store the gradients in-place in the input tensor
    _kldiv_kernel_backward[grid](
        target,
        target.stride(0),
        new_grads,
        new_grads.stride(0),
        V,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        log_target=log_target,
    )

    # If the KL-div loss is the last layer, grad_output is 1.0. Skip the mul then.
    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        return new_grads

    return new_grads * grad_output


class LigerKLDivLossFunction(torch.autograd.Function):
    """
    Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:
    ```python
    if log_target:
        loss = target.exp() * (target - input)
    else:
        loss = target * (target.log() - input)
    ```
    then the loss is reduced according to the `reduction` parameter,
    as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        y_pred: torch.Tensor,
        y_true: torch.Tensor,
        reduction: REDUCTION_LITERAL = "batchmean",
        log_target: bool = False,
        eps: float = 1e-10,
    ) -> torch.Tensor:
        """A forward pass for the KL Divergence Loss.

        Args:
            ctx: Torch autograd context
            y_pred (torch.Tensor): A tensor of shape (BT, V) containing the predicted values, expected to be log-probabilities.
            y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`.
            reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean".
            log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False.
            eps (float, optional): A small value to avoid division by zero. Defaults to 1e-10.

        Returns:
            torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar.
        """
        ctx.save_for_backward(y_true)
        ctx.reduction = reduction
        ctx.log_target = log_target
        return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """A backward pass for the KL Divergence Loss.

        Args:
            ctx: Torch autograd context
            grad_output (torch.Tensor): The gradient of the loss with respect to the output.

        Returns:
            tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
        """
        (y_true,) = ctx.saved_tensors

        new_grads = torch.empty_like(y_true)

        derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target)

        if ctx.reduction == "batchmean":
            derivative = derivative / y_true.shape[0]
        elif ctx.reduction == "sum" or ctx.reduction == "none":
            pass
        elif ctx.reduction == "mean":
            derivative = derivative / (y_true.shape[0] * y_true.shape[1])

        return (
            derivative,
            None,
            None,
            None,
            None,
        )
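A minimal usage sketch for `LigerKLDivLossFunction`, following its documented argument order (`y_pred`, `y_true`, `reduction`, `log_target`, `eps`). The shapes, device, and import path are illustrative assumptions.

import torch

from liger_kernels import LigerKLDivLossFunction  # assumed package import

# hypothetical sizes: (BT, V) = (4, 128); a CUDA device is assumed so the Triton kernels can launch
y_pred = torch.randn(4, 128, device="cuda", requires_grad=True).log_softmax(dim=-1)  # log-probabilities
y_true = torch.randn(4, 128, device="cuda").softmax(dim=-1)  # probabilities, since log_target=False

loss = LigerKLDivLossFunction.apply(y_pred, y_true, "batchmean", False, 1e-10)
loss.backward()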
torch-ext/liger_kernels/layer_norm.py
ADDED
@@ -0,0 +1,265 @@
1 |
+
import math
|
2 |
+
import operator
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import triton
|
6 |
+
import triton.language as tl
|
7 |
+
|
8 |
+
from utils import calculate_settings
|
9 |
+
from utils import compare_version
|
10 |
+
from utils import ensure_contiguous
|
11 |
+
|
12 |
+
if compare_version("triton", operator.ge, "3.0.0"):
|
13 |
+
try:
|
14 |
+
# typical import path with dispatch available
|
15 |
+
from triton.language.extra.libdevice import rsqrt
|
16 |
+
except ModuleNotFoundError:
|
17 |
+
# for working with NGC containers
|
18 |
+
from triton.language.extra.cuda.libdevice import rsqrt
|
19 |
+
else:
|
20 |
+
from triton.language.math import rsqrt
|
21 |
+
|
22 |
+
|
23 |
+
@triton.jit
|
24 |
+
def _layer_norm_forward_kernel(
|
25 |
+
Y_ptr, # pointer to output, shape (n_rows, n_cols)
|
26 |
+
Y_row_stride, # stride of each row in output
|
27 |
+
X_ptr, # pointer to input, shape (n_rows, n_cols)
|
28 |
+
X_row_stride, # stride of each row in input
|
29 |
+
W_ptr, # pointer to weights, shape (n_cols,)
|
30 |
+
W_row_stride, # stride of each row in weights
|
31 |
+
B_ptr, # pointer to bias, shape (n_cols,)
|
32 |
+
B_row_stride, # stride of each row in bias
|
33 |
+
Mean_ptr, # pointer to mean, shape (n_rows,)
|
34 |
+
Mean_row_stride, # stride of each row in mean
|
35 |
+
RSTD_ptr, # pointer to rstd, shape (n_rows,)
|
36 |
+
RSTD_row_stride, # stride of each row in rstd
|
37 |
+
n_cols,
|
38 |
+
eps,
|
39 |
+
BLOCK_SIZE: tl.constexpr,
|
40 |
+
):
|
41 |
+
"""
|
42 |
+
References:
|
43 |
+
https://arxiv.org/abs/1607.06450
|
44 |
+
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
45 |
+
"""
|
46 |
+
row_idx = tl.program_id(0)
|
47 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
48 |
+
mask = col_offsets < n_cols
|
49 |
+
|
50 |
+
Y_ptr += row_idx * Y_row_stride
|
51 |
+
X_ptr += row_idx * X_row_stride
|
52 |
+
Mean_ptr += row_idx * Mean_row_stride
|
53 |
+
RSTD_ptr += row_idx * RSTD_row_stride
|
54 |
+
|
55 |
+
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
|
56 |
+
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
|
57 |
+
B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
|
58 |
+
|
59 |
+
mean = tl.sum(X_row, axis=0) / n_cols
|
60 |
+
Xmm = tl.where(mask, X_row - mean, 0)
|
61 |
+
var = tl.sum(Xmm * Xmm, axis=0) / n_cols
|
62 |
+
rstd = rsqrt(var + eps)
|
63 |
+
|
64 |
+
tl.store(Mean_ptr, mean)
|
65 |
+
tl.store(RSTD_ptr, rstd)
|
66 |
+
|
67 |
+
Y_row = Xmm * rstd * W_row + B_row
|
68 |
+
|
69 |
+
tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
|
70 |
+
|
71 |
+
|
72 |
+
@triton.jit
|
73 |
+
def _layer_norm_backward_kernel(
|
74 |
+
X_ptr, # pointer to input, shape (n_rows, n_cols)
|
75 |
+
W_ptr, # pointer to weights, shape (n_cols,)
|
76 |
+
Mean_ptr, # pointer to mean, shape (n_rows,)
|
77 |
+
RSTD_ptr, # pointer to rstd, shape (n_rows,)
|
78 |
+
DX_ptr, # pointer to input grad, shape (n_rows, n_cols)
|
79 |
+
DW_ptr, # pointer to weights grad, shape (n_cols,)
|
80 |
+
DB_ptr, # pointer to bias grad, shape (n_cols,)
|
81 |
+
DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
|
82 |
+
stride_x, # stride of each row in input
|
83 |
+
stride_dx, # stride of each row in input grad
|
84 |
+
stride_dw, # stride of each row in weights grad
|
85 |
+
stride_db, # stride of each row in bias grad
|
86 |
+
stride_dy, # stride of each row in output grad
|
87 |
+
n_rows,
|
88 |
+
n_cols,
|
89 |
+
rows_per_program: tl.constexpr,
|
90 |
+
BLOCK_SIZE: tl.constexpr,
|
91 |
+
dtype: tl.constexpr,
|
92 |
+
):
|
93 |
+
"""
|
94 |
+
References:
|
95 |
+
https://arxiv.org/abs/1607.06450
|
96 |
+
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
97 |
+
https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
98 |
+
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
|
99 |
+
"""
|
100 |
+
row_block_id = tl.program_id(0)
|
101 |
+
row_start = row_block_id * rows_per_program
|
102 |
+
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
103 |
+
cols = tl.arange(0, BLOCK_SIZE)
|
104 |
+
mask = cols < n_cols
|
105 |
+
|
106 |
+
dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
107 |
+
db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
108 |
+
|
109 |
+
X_ptr += row_start * stride_x
|
110 |
+
Mean_ptr += row_start
|
111 |
+
RSTD_ptr += row_start
|
112 |
+
DX_ptr += row_start * stride_dx
|
113 |
+
DY_ptr += row_start * stride_dy
|
114 |
+
|
115 |
+
for _ in range(row_start, row_end):
|
116 |
+
x = tl.load(X_ptr + cols, mask=mask, other=0.0)
|
117 |
+
w = tl.load(W_ptr + cols, mask=mask, other=0.0)
|
118 |
+
dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)
|
119 |
+
mean = tl.load(Mean_ptr)
|
120 |
+
rstd = tl.load(RSTD_ptr)
|
121 |
+
|
122 |
+
x_hat = (x - mean) * rstd
|
123 |
+
wdy = w * dy
|
124 |
+
c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
|
125 |
+
c2 = tl.sum(wdy, axis=0) / n_cols
|
126 |
+
dx = (wdy - (x_hat * c1 + c2)) * rstd
|
127 |
+
tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)
|
128 |
+
|
129 |
+
dw_row += dy * x_hat
|
130 |
+
db_row += dy
|
131 |
+
|
132 |
+
X_ptr += stride_x
|
133 |
+
Mean_ptr += 1
|
134 |
+
RSTD_ptr += 1
|
135 |
+
DX_ptr += stride_dx
|
136 |
+
DY_ptr += stride_dy
|
137 |
+
|
138 |
+
tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)
|
139 |
+
tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)
|
140 |
+
|
141 |
+
|
142 |
+
def layer_norm_forward(X, W, B, eps):
|
143 |
+
shape = X.shape
|
144 |
+
dim = shape[-1]
|
145 |
+
X = X.view(-1, dim)
|
146 |
+
n_rows, n_cols = X.shape
|
147 |
+
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
148 |
+
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
149 |
+
Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
150 |
+
RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
151 |
+
if X.shape[1] != W.shape[0]:
|
152 |
+
raise ValueError(
|
153 |
+
f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
|
154 |
+
f"must match weight size (W.shape[0]={W.shape[0]})"
|
155 |
+
)
|
156 |
+
|
157 |
+
# XPU-specific optimization
|
158 |
+
kernel_args = {}
|
159 |
+
if X.device.type == "xpu":
|
160 |
+
kernel_args["grf_mode"] = "large"
|
161 |
+
|
162 |
+
_layer_norm_forward_kernel[(n_rows,)](
|
163 |
+
Y,
|
164 |
+
Y.stride(0),
|
165 |
+
X,
|
166 |
+
X.stride(0),
|
167 |
+
W,
|
168 |
+
W.stride(0),
|
169 |
+
B,
|
170 |
+
B.stride(0),
|
171 |
+
Mean,
|
172 |
+
Mean.stride(0),
|
173 |
+
RSTD,
|
174 |
+
RSTD.stride(0),
|
175 |
+
n_cols,
|
176 |
+
eps,
|
177 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
178 |
+
num_warps=num_warps,
|
179 |
+
**kernel_args, # XPU-specific optimization
|
180 |
+
)
|
181 |
+
return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
|
182 |
+
|
183 |
+
|
184 |
+
def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
185 |
+
shape = dY.shape
|
186 |
+
dim = shape[-1]
|
187 |
+
dY = dY.view(-1, dim)
|
188 |
+
n_rows, n_cols = dY.shape
|
189 |
+
|
190 |
+
sm_count = 1
|
191 |
+
if X.device.type == "cuda":
|
192 |
+
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
193 |
+
elif X.device.type == "xpu":
|
194 |
+
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
|
195 |
+
|
196 |
+
DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
197 |
+
_DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
|
198 |
+
_DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
|
199 |
+
|
200 |
+
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
201 |
+
if n_cols > BLOCK_SIZE:
|
202 |
+
raise RuntimeError(
|
203 |
+
f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
|
204 |
+
)
|
205 |
+
|
206 |
+
rows_per_program = math.ceil(n_rows / sm_count)
|
207 |
+
grid = (sm_count,)
|
208 |
+
triton_dtype = (
|
209 |
+
tl.float32
|
210 |
+
if X.dtype == torch.float32
|
211 |
+
else tl.bfloat16
|
212 |
+
if X.dtype == torch.bfloat16
|
213 |
+
else tl.float16
|
214 |
+
if X.dtype == torch.float16
|
215 |
+
else tl.float32 # fallback to float32 for other types
|
216 |
+
)
|
217 |
+
|
218 |
+
# XPU-specific optimization
|
219 |
+
kernel_args = {}
|
220 |
+
if X.device.type == "xpu":
|
221 |
+
kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
|
222 |
+
|
223 |
+
_layer_norm_backward_kernel[grid](
|
224 |
+
X,
|
225 |
+
W,
|
226 |
+
Mean,
|
227 |
+
RSTD,
|
228 |
+
DX,
|
229 |
+
_DW,
|
230 |
+
_DB,
|
231 |
+
dY,
|
232 |
+
X.stride(0),
|
233 |
+
DX.stride(0),
|
234 |
+
_DW.stride(0),
|
235 |
+
_DB.stride(0),
|
236 |
+
dY.stride(0),
|
237 |
+
n_rows,
|
238 |
+
n_cols,
|
239 |
+
rows_per_program,
|
240 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
241 |
+
dtype=triton_dtype,
|
242 |
+
**kernel_args, # XPU-specific optimization
|
243 |
+
)
|
244 |
+
|
245 |
+
DW = _DW.sum(dim=0).to(W.dtype)
|
246 |
+
DB = _DB.sum(dim=0).to(W.dtype)
|
247 |
+
|
248 |
+
DX = DX.view(*shape)
|
249 |
+
return DX, DW, DB
|
250 |
+
|
251 |
+
|
252 |
+
class LigerLayerNormFunction(torch.autograd.Function):
|
253 |
+
@staticmethod
|
254 |
+
@ensure_contiguous
|
255 |
+
def forward(ctx, X, W, B, eps):
|
256 |
+
Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps)
|
257 |
+
ctx.save_for_backward(X, W, B, Mean, RSTD)
|
258 |
+
return Y
|
259 |
+
|
260 |
+
@staticmethod
|
261 |
+
@ensure_contiguous
|
262 |
+
def backward(ctx, dY):
|
263 |
+
X, W, B, Mean, RSTD = ctx.saved_tensors
|
264 |
+
DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD)
|
265 |
+
return DX, DW, DB, None
|
torch-ext/liger_kernels/qwen2vl_mrope.py
ADDED
@@ -0,0 +1,222 @@
1 |
+
import torch
|
2 |
+
import triton
|
3 |
+
import triton.language as tl
|
4 |
+
|
5 |
+
|
6 |
+
@triton.jit
|
7 |
+
def _triton_qwen2vl_mrope(
|
8 |
+
q_ptr,
|
9 |
+
k_ptr,
|
10 |
+
cos,
|
11 |
+
sin,
|
12 |
+
sl,
|
13 |
+
bs: tl.constexpr,
|
14 |
+
n_qh: tl.constexpr,
|
15 |
+
n_kh: tl.constexpr,
|
16 |
+
hd: tl.constexpr,
|
17 |
+
pad_n_qh: tl.constexpr,
|
18 |
+
pad_n_kh: tl.constexpr,
|
19 |
+
pad_hd: tl.constexpr,
|
20 |
+
mrope_section_t: tl.constexpr,
|
21 |
+
mrope_section_h: tl.constexpr,
|
22 |
+
BLOCK_SIZE: tl.constexpr,
|
23 |
+
BACKWARD_PASS: tl.constexpr = False,
|
24 |
+
):
|
25 |
+
pid = tl.program_id(0)
|
26 |
+
|
27 |
+
# locate start address
|
28 |
+
q_ptr = q_ptr + pid * (n_qh * hd)
|
29 |
+
k_ptr = k_ptr + pid * (n_kh * hd)
|
30 |
+
|
31 |
+
# ####################################################################
|
32 |
+
# get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
|
33 |
+
# m of this program instance
|
34 |
+
# ####################################################################
|
35 |
+
|
36 |
+
# 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
|
37 |
+
# effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
|
38 |
+
# being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index
|
39 |
+
# and pid % sl to get the sequence index.
|
40 |
+
# 2. We only need the left half of cos and sin matrix because the right half is just
|
41 |
+
# a clone of the left half.
|
42 |
+
t_end = mrope_section_t
|
43 |
+
h_end = t_end + mrope_section_h
|
44 |
+
|
45 |
+
t_cos = cos + pid * hd
|
46 |
+
h_cos = t_cos + bs * sl * hd
|
47 |
+
w_cos = h_cos + bs * sl * hd
|
48 |
+
t_sin = sin + pid * hd
|
49 |
+
h_sin = t_sin + bs * sl * hd
|
50 |
+
w_sin = h_sin + bs * sl * hd
|
51 |
+
|
52 |
+
cos_offsets = tl.arange(0, pad_hd // 2)
|
53 |
+
t_mask = cos_offsets < t_end
|
54 |
+
h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
|
55 |
+
w_mask = (h_end <= cos_offsets) & (cos_offsets < hd // 2)
|
56 |
+
t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
|
57 |
+
h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
|
58 |
+
w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
|
59 |
+
t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
|
60 |
+
h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
|
61 |
+
w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)
|
62 |
+
cos_row = t_cos_row + h_cos_row + w_cos_row
|
63 |
+
sin_row = t_sin_row + h_sin_row + w_sin_row
|
64 |
+
|
65 |
+
# ####################################################################
|
66 |
+
# Load the left and right half of q and k for the current
|
67 |
+
# program instance (i.e. for the current token) separately
|
68 |
+
# ####################################################################
|
69 |
+
# left half of the head
|
70 |
+
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
|
71 |
+
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
|
72 |
+
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
|
73 |
+
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
|
74 |
+
q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
|
75 |
+
k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
|
76 |
+
|
77 |
+
# right half of the head
|
78 |
+
second_half_q_offsets = first_half_q_offsets + (hd // 2)
|
79 |
+
second_half_k_offsets = first_half_k_offsets + (hd // 2)
|
80 |
+
second_q_mask = first_q_mask
|
81 |
+
second_k_mask = first_k_mask
|
82 |
+
q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
|
83 |
+
k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
|
84 |
+
|
85 |
+
if not BACKWARD_PASS:
|
86 |
+
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
|
87 |
+
new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
|
88 |
+
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
|
89 |
+
new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
|
90 |
+
tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
|
91 |
+
|
92 |
+
new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
|
93 |
+
tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
|
94 |
+
new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
|
95 |
+
tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
|
96 |
+
else:
|
97 |
+
# with some math, we can get:
|
98 |
+
# dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
|
99 |
+
new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
|
100 |
+
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
|
101 |
+
new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
|
102 |
+
tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
|
103 |
+
|
104 |
+
new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
|
105 |
+
tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
|
106 |
+
new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
|
107 |
+
tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
|
108 |
+
|
109 |
+
|
110 |
+
def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
|
111 |
+
# transpose it back to the physical shape because Triton looks at the physical storage
|
112 |
+
# note: q and k are non-contiguous before the transformation and will become contiguous after the transpose
|
113 |
+
q = q.transpose(1, 2)
|
114 |
+
k = k.transpose(1, 2)
|
115 |
+
|
116 |
+
batch_size, seq_len, n_q_head, head_dim = q.shape
|
117 |
+
n_kv_head = k.shape[2]
|
118 |
+
pad_hd = triton.next_power_of_2(head_dim)
|
119 |
+
pad_n_q_head = triton.next_power_of_2(n_q_head)
|
120 |
+
pad_n_kv_head = triton.next_power_of_2(n_kv_head)
|
121 |
+
BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
|
122 |
+
|
123 |
+
n_row = batch_size * seq_len
|
124 |
+
|
125 |
+
# ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
|
126 |
+
q = q.contiguous()
|
127 |
+
k = k.contiguous()
|
128 |
+
cos = cos.contiguous()
|
129 |
+
sin = sin.contiguous()
|
130 |
+
|
131 |
+
_triton_qwen2vl_mrope[(n_row,)](
|
132 |
+
q,
|
133 |
+
k,
|
134 |
+
cos,
|
135 |
+
sin,
|
136 |
+
seq_len,
|
137 |
+
batch_size,
|
138 |
+
n_q_head,
|
139 |
+
n_kv_head,
|
140 |
+
head_dim,
|
141 |
+
pad_n_q_head,
|
142 |
+
pad_n_kv_head,
|
143 |
+
pad_hd,
|
144 |
+
mrope_section[0],
|
145 |
+
mrope_section[1],
|
146 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
147 |
+
BACKWARD_PASS=False,
|
148 |
+
)
|
149 |
+
return q.transpose(1, 2), k.transpose(1, 2), cos, sin
|
150 |
+
|
151 |
+
|
152 |
+
def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
|
153 |
+
dq = dq.transpose(1, 2)
|
154 |
+
dk = dk.transpose(1, 2)
|
155 |
+
|
156 |
+
batch_size, seq_len, n_q_head, head_dim = dq.shape
|
157 |
+
n_kv_head = dk.shape[2]
|
158 |
+
pad_hd = triton.next_power_of_2(head_dim)
|
159 |
+
pad_n_q_head = triton.next_power_of_2(n_q_head)
|
160 |
+
pad_n_kv_head = triton.next_power_of_2(n_kv_head)
|
161 |
+
BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
|
162 |
+
|
163 |
+
n_row = batch_size * seq_len
|
164 |
+
|
165 |
+
# ensure dq and dk are contiguous
|
166 |
+
dq = dq.contiguous()
|
167 |
+
dk = dk.contiguous()
|
168 |
+
|
169 |
+
# backward is similar to forward except swapping few ops
|
170 |
+
_triton_qwen2vl_mrope[(n_row,)](
|
171 |
+
dq,
|
172 |
+
dk,
|
173 |
+
cos,
|
174 |
+
sin,
|
175 |
+
seq_len,
|
176 |
+
batch_size,
|
177 |
+
n_q_head,
|
178 |
+
n_kv_head,
|
179 |
+
head_dim,
|
180 |
+
pad_n_q_head,
|
181 |
+
pad_n_kv_head,
|
182 |
+
pad_hd,
|
183 |
+
mrope_section[0],
|
184 |
+
mrope_section[1],
|
185 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
186 |
+
BACKWARD_PASS=True,
|
187 |
+
)
|
188 |
+
return dq.transpose(1, 2), dk.transpose(1, 2)
|
189 |
+
|
190 |
+
|
191 |
+
class LigerQwen2VLMRopeFunction(torch.autograd.Function):
|
192 |
+
"""
|
193 |
+
Triton implementation of the Qwen2VL Multimodal Rotary Positional Embedding (M-RoPE) operation.
|
194 |
+
|
195 |
+
Please find the corresponding HuggingFace implementation here:
|
196 |
+
https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
|
197 |
+
"""
|
198 |
+
|
199 |
+
@staticmethod
|
200 |
+
def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
201 |
+
"""
|
202 |
+
q size: (bsz, n_q_head, seq_len, head_dim)
|
203 |
+
k size: (bsz, n_kv_head, seq_len, head_dim)
|
204 |
+
cos size: (3, bsz, seq_len, head_dim)
|
205 |
+
sin size: (3, bsz, seq_len, head_dim)
|
206 |
+
"""
|
207 |
+
q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
|
208 |
+
ctx.save_for_backward(cos, sin)
|
209 |
+
ctx.mrope_section = mrope_section
|
210 |
+
return q, k
|
211 |
+
|
212 |
+
def backward(ctx, dq, dk):
|
213 |
+
"""
|
214 |
+
dq size: (bsz, n_q_head, seq_len, head_dim)
|
215 |
+
dk size: (bsz, n_kv_head, seq_len, head_dim)
|
216 |
+
cos size: (3, bsz, seq_len, head_dim)
|
217 |
+
sin size: (3, bsz, seq_len, head_dim)
|
218 |
+
"""
|
219 |
+
cos, sin = ctx.saved_tensors
|
220 |
+
mrope_section = ctx.mrope_section
|
221 |
+
dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
|
222 |
+
return dq, dk, None, None, None, None
|
torch-ext/liger_kernels/rms_norm.py
ADDED
@@ -0,0 +1,365 @@
1 |
+
"""
|
2 |
+
This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
|
3 |
+
See the original Unsloth repository at https://github.com/unslothai/unsloth.
|
4 |
+
|
5 |
+
The following line
|
6 |
+
https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/rms_norm.py#L30
|
7 |
+
is based on code from Unsloth, located at:
|
8 |
+
https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
|
9 |
+
|
10 |
+
Modifications made by Yanning Chen, 2024.
|
11 |
+
"""
|
12 |
+
|
13 |
+
import math
|
14 |
+
import operator
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import triton
|
18 |
+
import triton.language as tl
|
19 |
+
|
20 |
+
from utils import calculate_settings
|
21 |
+
from utils import compare_version
|
22 |
+
from utils import ensure_contiguous
|
23 |
+
from utils import torch_to_triton_dtype
|
24 |
+
|
25 |
+
if compare_version("triton", operator.ge, "3.0.0"):
|
26 |
+
try:
|
27 |
+
# typical import path with dispatch available
|
28 |
+
from triton.language.extra.libdevice import rsqrt
|
29 |
+
except ModuleNotFoundError:
|
30 |
+
# for working with NGC containers
|
31 |
+
from triton.language.extra.cuda.libdevice import rsqrt
|
32 |
+
else:
|
33 |
+
from triton.language.math import rsqrt
|
34 |
+
|
35 |
+
|
36 |
+
_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1)
|
37 |
+
_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
|
38 |
+
_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)
|
39 |
+
|
40 |
+
|
41 |
+
@triton.jit
|
42 |
+
def _rms_norm_forward_kernel(
|
43 |
+
Y_ptr,
|
44 |
+
Y_row_stride,
|
45 |
+
X_ptr,
|
46 |
+
X_row_stride,
|
47 |
+
W_ptr,
|
48 |
+
W_row_stride,
|
49 |
+
RSTD_ptr,
|
50 |
+
RSTD_row_stride,
|
51 |
+
n_cols,
|
52 |
+
eps,
|
53 |
+
offset,
|
54 |
+
casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
|
55 |
+
BLOCK_SIZE: tl.constexpr,
|
56 |
+
):
|
57 |
+
"""
|
58 |
+
y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)
|
59 |
+
|
60 |
+
Reference:
|
61 |
+
1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
62 |
+
2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
|
63 |
+
3. https://arxiv.org/pdf/1910.07467
|
64 |
+
"""
|
65 |
+
|
66 |
+
row_idx = tl.program_id(0)
|
67 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
68 |
+
mask = col_offsets < n_cols
|
69 |
+
|
70 |
+
Y_ptr += row_idx * Y_row_stride
|
71 |
+
X_ptr += row_idx * X_row_stride
|
72 |
+
RSTD_ptr += row_idx * RSTD_row_stride
|
73 |
+
|
74 |
+
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
|
75 |
+
X_row_dtype = X_row.dtype
|
76 |
+
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
|
77 |
+
|
78 |
+
# On Llama, only rstd is computed on fp32
|
79 |
+
if casting_mode == _CASTING_MODE_LLAMA:
|
80 |
+
X_row = X_row.to(tl.float32)
|
81 |
+
|
82 |
+
# Gemma computes everything on fp32, and then casts back the output to the original dtype
|
83 |
+
if casting_mode == _CASTING_MODE_GEMMA:
|
84 |
+
W_row = W_row.to(tl.float32)
|
85 |
+
X_row = X_row.to(tl.float32)
|
86 |
+
|
87 |
+
if casting_mode == _CASTING_MODE_NONE:
|
88 |
+
eps = eps.to(X_row_dtype)
|
89 |
+
offset = offset.to(X_row_dtype)
|
90 |
+
|
91 |
+
mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
|
92 |
+
rstd = rsqrt(mean_square + eps)
|
93 |
+
|
94 |
+
# We can save time by caching rms with minimal memory overhead
|
95 |
+
# because rms is much smaller compared to X_row, as rms is for each row.
|
96 |
+
# However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
|
97 |
+
tl.store(RSTD_ptr, rstd)
|
98 |
+
|
99 |
+
X_row = X_row * rstd
|
100 |
+
|
101 |
+
# On Llama, the multiplication with the weight is done on the original dtype
|
102 |
+
if casting_mode == _CASTING_MODE_LLAMA:
|
103 |
+
X_row = X_row.to(X_row_dtype)
|
104 |
+
|
105 |
+
Y_row = X_row * (offset + W_row)
|
106 |
+
|
107 |
+
if casting_mode == _CASTING_MODE_GEMMA:
|
108 |
+
Y_row = Y_row.to(X_row_dtype)
|
109 |
+
|
110 |
+
tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
|
111 |
+
|
112 |
+
|
113 |
+
@triton.jit
|
114 |
+
def _rms_norm_backward_kernel(
|
115 |
+
dY_ptr,
|
116 |
+
dY_row_stride,
|
117 |
+
dX_ptr,
|
118 |
+
dX_row_stride,
|
119 |
+
X_ptr,
|
120 |
+
X_row_stride,
|
121 |
+
X_dtype: tl.constexpr,
|
122 |
+
W_ptr,
|
123 |
+
W_row_stride,
|
124 |
+
RSTD_ptr,
|
125 |
+
RSTD_row_stride,
|
126 |
+
dW_ptr,
|
127 |
+
dW_row_stride,
|
128 |
+
n_rows,
|
129 |
+
n_cols,
|
130 |
+
offset,
|
131 |
+
rows_per_program: tl.constexpr,
|
132 |
+
casting_mode: tl.constexpr,
|
133 |
+
BLOCK_SIZE: tl.constexpr,
|
134 |
+
):
|
135 |
+
"""
|
136 |
+
dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
|
137 |
+
dw = sum(dy * (x / RMS)). summation over BxT dimension
|
138 |
+
"""
|
139 |
+
|
140 |
+
row_block_id = tl.program_id(0)
|
141 |
+
row_start = row_block_id * rows_per_program
|
142 |
+
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
143 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
144 |
+
mask = col_offsets < n_cols
|
145 |
+
|
146 |
+
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
147 |
+
|
148 |
+
dY_ptr += row_start * dY_row_stride
|
149 |
+
dX_ptr += row_start * dX_row_stride
|
150 |
+
|
151 |
+
X_ptr += row_start * X_row_stride
|
152 |
+
RSTD_ptr += row_start
|
153 |
+
|
154 |
+
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
|
155 |
+
W_row = W_row + offset
|
156 |
+
|
157 |
+
for _ in range(row_start, row_end):
|
158 |
+
dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
|
159 |
+
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
|
160 |
+
|
161 |
+
# Get cached rms
|
162 |
+
rstd_row = tl.load(RSTD_ptr)
|
163 |
+
|
164 |
+
X_row = X_row.to(tl.float32)
|
165 |
+
|
166 |
+
# Different backward graphs for different casting modes
|
167 |
+
if casting_mode == _CASTING_MODE_LLAMA:
|
168 |
+
m = (dY_row * W_row).to(tl.float32)
|
169 |
+
|
170 |
+
elif casting_mode == _CASTING_MODE_GEMMA:
|
171 |
+
dY_row = dY_row.to(tl.float32)
|
172 |
+
m = dY_row * W_row
|
173 |
+
else:
|
174 |
+
m = dY_row * W_row
|
175 |
+
|
176 |
+
dX_row = rstd_row * m
|
177 |
+
|
178 |
+
dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
|
179 |
+
|
180 |
+
# calculate the gradient of W
|
181 |
+
if casting_mode == _CASTING_MODE_LLAMA:
|
182 |
+
dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
|
183 |
+
else:
|
184 |
+
# here X_row is already in fp32 (see previous if block)
|
185 |
+
dW_row += dY_row * (X_row * rstd_row)
|
186 |
+
|
187 |
+
tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
|
188 |
+
|
189 |
+
dY_ptr += dY_row_stride
|
190 |
+
dX_ptr += dX_row_stride
|
191 |
+
X_ptr += X_row_stride
|
192 |
+
RSTD_ptr += RSTD_row_stride
|
193 |
+
|
194 |
+
tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
|
195 |
+
|
196 |
+
|
197 |
+
_str_to_casting_mode = {
|
198 |
+
"llama": _CASTING_MODE_LLAMA.value,
|
199 |
+
"gemma": _CASTING_MODE_GEMMA.value,
|
200 |
+
"none": _CASTING_MODE_NONE.value,
|
201 |
+
}
|
202 |
+
|
203 |
+
|
204 |
+
def rms_norm_forward(X, W, eps, offset, casting_mode):
|
205 |
+
if not isinstance(casting_mode, int):
|
206 |
+
assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
|
207 |
+
casting_mode = _str_to_casting_mode[casting_mode]
|
208 |
+
else:
|
209 |
+
assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"
|
210 |
+
|
211 |
+
shape = X.shape
|
212 |
+
dim = shape[-1]
|
213 |
+
X = X.view(-1, dim)
|
214 |
+
n_rows, n_cols = X.shape
|
215 |
+
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
216 |
+
|
217 |
+
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
218 |
+
# RSTD is to cache rstd for each row
|
219 |
+
# RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
|
220 |
+
rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
|
221 |
+
RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
|
222 |
+
|
223 |
+
# Check constraints.
|
224 |
+
assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
|
225 |
+
|
226 |
+
# XPU-specific optimization
|
227 |
+
kernel_args = {}
|
228 |
+
if X.device.type == "xpu":
|
229 |
+
kernel_args["grf_mode"] = "large"
|
230 |
+
_rms_norm_forward_kernel[(n_rows,)](
|
231 |
+
Y,
|
232 |
+
Y.stride(0),
|
233 |
+
X,
|
234 |
+
X.stride(0),
|
235 |
+
W,
|
236 |
+
W.stride(0),
|
237 |
+
RSTD,
|
238 |
+
RSTD.stride(0),
|
239 |
+
n_cols,
|
240 |
+
eps,
|
241 |
+
offset,
|
242 |
+
casting_mode,
|
243 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
244 |
+
num_warps=num_warps,
|
245 |
+
**kernel_args, # XPU-specific optimization
|
246 |
+
)
|
247 |
+
return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
|
248 |
+
|
249 |
+
|
250 |
+
def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
|
251 |
+
shape = dY.shape
|
252 |
+
dim = shape[-1]
|
253 |
+
dY = dY.view(-1, dim)
|
254 |
+
n_rows, n_cols = dY.shape
|
255 |
+
|
256 |
+
sm_count = 1
|
257 |
+
if X.device.type == "cuda":
|
258 |
+
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
259 |
+
elif X.device.type == "xpu":
|
260 |
+
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
|
261 |
+
|
262 |
+
# fp32 for numerical stability especially.
|
263 |
+
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
264 |
+
|
265 |
+
if n_cols > BLOCK_SIZE:
|
266 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
267 |
+
rows_per_program = math.ceil(n_rows / sm_count)
|
268 |
+
grid = (sm_count,)
|
269 |
+
|
270 |
+
if in_place is True:
|
271 |
+
dX = dY
|
272 |
+
else:
|
273 |
+
dX = torch.zeros_like(dY)
|
274 |
+
|
275 |
+
# XPU-specific optimization
|
276 |
+
kernel_args = {}
|
277 |
+
if X.device.type == "xpu":
|
278 |
+
kernel_args["grf_mode"] = "large"
|
279 |
+
|
280 |
+
_rms_norm_backward_kernel[grid](
|
281 |
+
dY,
|
282 |
+
dY.stride(0),
|
283 |
+
dX,
|
284 |
+
dX.stride(0),
|
285 |
+
X,
|
286 |
+
X.stride(0),
|
287 |
+
torch_to_triton_dtype[X.dtype],
|
288 |
+
W,
|
289 |
+
W.stride(0),
|
290 |
+
RSTD,
|
291 |
+
RSTD.stride(0),
|
292 |
+
_dW,
|
293 |
+
_dW.stride(0),
|
294 |
+
n_rows,
|
295 |
+
n_cols,
|
296 |
+
offset,
|
297 |
+
rows_per_program,
|
298 |
+
casting_mode,
|
299 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
300 |
+
num_warps=num_warps,
|
301 |
+
**kernel_args, # XPU-specific optimization
|
302 |
+
)
|
303 |
+
dX = dX.view(*shape)
|
304 |
+
dW = _dW.sum(dim=0).to(W.dtype)
|
305 |
+
|
306 |
+
return dX, dW
|
307 |
+
|
308 |
+
|
309 |
+
class LigerRMSNormFunction(torch.autograd.Function):
|
310 |
+
"""
|
311 |
+
Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the
|
312 |
+
weight tensor `W`, with an optional offset and casting mode.
|
313 |
+
|
314 |
+
Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
|
315 |
+
uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
|
316 |
+
`(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.
|
317 |
+
|
318 |
+
In addition, different models cast their inputs at different places during RMSNorm computation. For
|
319 |
+
example, Gemma casts everything to fp32 before starting the computation, while Llama casts only the
|
320 |
+
inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
|
321 |
+
support the following casting modes (they match HuggingFace Transformers' implementations):
|
322 |
+
- 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
|
323 |
+
- 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
|
324 |
+
- 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
|
325 |
+
|
326 |
+
The `in_place` option controls whether dY is modified in place to store dX. It defaults to `True` to save memory. However, in certain cases, it can produce incorrect inputs.
|
327 |
+
For example, gemma2 uses two rmsnorm sequentially with residual in between. The residual part needs dY so it cannot be modified in-place.
|
328 |
+
Therefore, for the patching of RMSNorm in gemma2, we set `in_place` to `False`
|
329 |
+
"""
|
330 |
+
|
331 |
+
@staticmethod
|
332 |
+
@ensure_contiguous
|
333 |
+
def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
|
334 |
+
"""
|
335 |
+
X: (B, T, H) or (BxT, H)
|
336 |
+
W: (H,)
|
337 |
+
"""
|
338 |
+
Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
|
339 |
+
ctx.offset = offset
|
340 |
+
ctx.casting_mode = casting_mode
|
341 |
+
ctx.in_place = in_place
|
342 |
+
ctx.BLOCK_SIZE = BLOCK_SIZE
|
343 |
+
ctx.num_warps = num_warps
|
344 |
+
ctx.save_for_backward(X, W, RSTD)
|
345 |
+
return Y
|
346 |
+
|
347 |
+
@staticmethod
|
348 |
+
@ensure_contiguous
|
349 |
+
def backward(ctx, dY):
|
350 |
+
"""
|
351 |
+
Y: (B, T, H) or (BxT, H)
|
352 |
+
"""
|
353 |
+
X, W, RSTD = ctx.saved_tensors
|
354 |
+
dX, dW = rms_norm_backward(
|
355 |
+
dY,
|
356 |
+
X,
|
357 |
+
W,
|
358 |
+
RSTD,
|
359 |
+
ctx.offset,
|
360 |
+
ctx.casting_mode,
|
361 |
+
ctx.BLOCK_SIZE,
|
362 |
+
ctx.num_warps,
|
363 |
+
ctx.in_place,
|
364 |
+
)
|
365 |
+
return dX, dW, None, None, None, None
|
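A minimal usage sketch for `LigerRMSNormFunction` as described in its docstring. The Gemma-style call with `offset=1.0`, `casting_mode="gemma"`, and `in_place=False` is one illustrative configuration; the shapes, dtype, device, and import path are assumptions.

import torch

from liger_kernels import LigerRMSNormFunction  # assumed package import

# hypothetical shapes: (B, T, H) = (2, 16, 64); bf16 inputs on an assumed CUDA device
X = torch.randn(2, 16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
W = torch.ones(64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Gemma-style: weight offset of 1.0, fp32 compute, dY left untouched for a residual branch
Y = LigerRMSNormFunction.apply(X, W, 1e-6, 1.0, "gemma", False)
Y.sum().backward()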
torch-ext/liger_kernels/rope.py
ADDED
@@ -0,0 +1,239 @@
1 |
+
import torch
|
2 |
+
import triton
|
3 |
+
import triton.language as tl
|
4 |
+
|
5 |
+
|
6 |
+
@triton.jit
|
7 |
+
def _triton_rope(
|
8 |
+
q_ptr,
|
9 |
+
q_row_stride,
|
10 |
+
k_ptr,
|
11 |
+
k_row_stride,
|
12 |
+
cos,
|
13 |
+
cos_row_stride,
|
14 |
+
sin,
|
15 |
+
sin_row_stride,
|
16 |
+
sl,
|
17 |
+
bs: tl.constexpr,
|
18 |
+
cos_bs: tl.constexpr,
|
19 |
+
n_qh: tl.constexpr,
|
20 |
+
n_kh: tl.constexpr,
|
21 |
+
hd: tl.constexpr,
|
22 |
+
pad_n_qh: tl.constexpr,
|
23 |
+
pad_n_kh: tl.constexpr,
|
24 |
+
pad_hd: tl.constexpr,
|
25 |
+
BLOCK_SIZE: tl.constexpr,
|
26 |
+
BACKWARD_PASS: tl.constexpr = False,
|
27 |
+
):
|
28 |
+
# q size: (bsz, seq_len, num_q_heads, head_dim)
|
29 |
+
# q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1)
|
30 |
+
# k size: (bsz, seq_len, num_kv_heads, head_dim)
|
31 |
+
# k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)
|
32 |
+
|
33 |
+
# cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
34 |
+
# stride: (seq_len * head_dim, head_dim, 1)
|
35 |
+
pid = tl.program_id(0)
|
36 |
+
|
37 |
+
# locate start address
|
38 |
+
q_ptr = q_ptr + pid * q_row_stride
|
39 |
+
k_ptr = k_ptr + pid * k_row_stride
|
40 |
+
|
41 |
+
# ####################################################################
|
42 |
+
# get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
|
43 |
+
# m of this program instance
|
44 |
+
# ####################################################################
|
45 |
+
|
46 |
+
# 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
|
47 |
+
# effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
|
48 |
+
# being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index
|
49 |
+
# and pid % sl to get the sequence index.
|
50 |
+
# 2. We only need the left half of cos and sin matrix because the right half is just
|
51 |
+
# a clone of the left half.
|
52 |
+
batch_idx = pid // sl
|
53 |
+
cos_row_idx = pid % sl
|
54 |
+
cos = cos + tl.where(
|
55 |
+
cos_bs == 1,
|
56 |
+
cos_row_idx * cos_row_stride,
|
57 |
+
batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
|
58 |
+
)
|
59 |
+
sin = sin + tl.where(
|
60 |
+
cos_bs == 1,
|
61 |
+
cos_row_idx * sin_row_stride,
|
62 |
+
batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
|
63 |
+
)
|
64 |
+
|
65 |
+
cos_offsets = tl.arange(0, pad_hd // 2)
|
66 |
+
cos_mask = cos_offsets < hd // 2
|
67 |
+
cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
|
68 |
+
sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0)
|
69 |
+
|
70 |
+
# ####################################################################
|
71 |
+
# Load the left and right half of q and k for the current
|
72 |
+
# program instance (i.e. for the current token) separately
|
73 |
+
# ####################################################################
|
74 |
+
# left half of the head
|
75 |
+
first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
|
76 |
+
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
|
77 |
+
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
|
78 |
+
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
|
79 |
+
q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
|
80 |
+
k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
|
81 |
+
|
82 |
+
# right half of the head
|
83 |
+
second_half_q_offsets = first_half_q_offsets + (hd // 2)
|
84 |
+
second_half_k_offsets = first_half_k_offsets + (hd // 2)
|
85 |
+
second_q_mask = first_q_mask
|
86 |
+
second_k_mask = first_k_mask
|
87 |
+
q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
|
88 |
+
k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
|
89 |
+
|
90 |
+
if not BACKWARD_PASS:
|
91 |
+
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
|
92 |
+
new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
|
93 |
+
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
|
94 |
+
new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
|
95 |
+
tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
|
96 |
+
|
97 |
+
new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
|
98 |
+
tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
|
99 |
+
new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
|
100 |
+
tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
|
101 |
+
else:
|
102 |
+
# with some math, we can get:
|
103 |
+
# dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
|
104 |
+
new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
|
105 |
+
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
|
106 |
+
new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
|
107 |
+
tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
|
108 |
+
|
109 |
+
new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
|
110 |
+
tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
|
111 |
+
new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
|
112 |
+
tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
|
113 |
+
|
114 |
+
|
115 |
+
def rope_forward(q, k, cos, sin):
|
116 |
+
# transpose it back to the physical shape because Triton looks at the physical storage
|
117 |
+
# note: q and k are non-contiguous before the transformation and will become contiguous after the transpose
|
118 |
+
q = q.transpose(1, 2)
|
119 |
+
k = k.transpose(1, 2)
|
120 |
+
|
121 |
+
batch_size, seq_len, n_q_head, head_dim = q.shape
|
122 |
+
n_kv_head = k.shape[2]
|
123 |
+
pad_hd = triton.next_power_of_2(head_dim)
|
124 |
+
pad_n_q_head = triton.next_power_of_2(n_q_head)
|
125 |
+
pad_n_kv_head = triton.next_power_of_2(n_kv_head)
|
126 |
+
BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
|
127 |
+
|
128 |
+
n_row = batch_size * seq_len
|
129 |
+
|
130 |
+
# ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
|
131 |
+
q = q.contiguous()
|
132 |
+
k = k.contiguous()
|
133 |
+
cos = cos.contiguous()
|
134 |
+
sin = sin.contiguous()
|
135 |
+
cos_batch_size = cos.shape[0]
|
136 |
+
|
137 |
+
_triton_rope[(n_row,)](
|
138 |
+
q,
|
139 |
+
q.stride(1),
|
140 |
+
k,
|
141 |
+
k.stride(1),
|
142 |
+
cos,
|
143 |
+
cos.stride(-2),
|
144 |
+
sin,
|
145 |
+
sin.stride(-2),
|
146 |
+
seq_len,
|
147 |
+
batch_size,
|
148 |
+
cos_batch_size,
|
149 |
+
n_q_head,
|
150 |
+
n_kv_head,
|
151 |
+
head_dim,
|
152 |
+
pad_n_q_head,
|
153 |
+
pad_n_kv_head,
|
154 |
+
pad_hd,
|
155 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
156 |
+
BACKWARD_PASS=False,
|
157 |
+
)
|
158 |
+
return q.transpose(1, 2), k.transpose(1, 2), cos, sin
|
159 |
+
|
160 |
+
|
161 |
+
def rope_backward(dq, dk, cos, sin):
|
162 |
+
dq = dq.transpose(1, 2)
|
163 |
+
dk = dk.transpose(1, 2)
|
164 |
+
|
165 |
+
batch_size, seq_len, n_q_head, head_dim = dq.shape
|
166 |
+
cos_batch_size = cos.shape[0]
|
167 |
+
n_kv_head = dk.shape[2]
|
168 |
+
pad_hd = triton.next_power_of_2(head_dim)
|
169 |
+
pad_n_q_head = triton.next_power_of_2(n_q_head)
|
170 |
+
pad_n_kv_head = triton.next_power_of_2(n_kv_head)
|
171 |
+
BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
|
172 |
+
|
173 |
+
n_row = batch_size * seq_len
|
174 |
+
|
175 |
+
# ensure dq and dk are contiguous
|
176 |
+
dq = dq.contiguous()
|
177 |
+
dk = dk.contiguous()
|
178 |
+
|
179 |
+
# backward is similar to forward except swapping few ops
|
180 |
+
_triton_rope[(n_row,)](
|
181 |
+
dq,
|
182 |
+
dq.stride(1),
|
183 |
+
dk,
|
184 |
+
dk.stride(1),
|
185 |
+
cos,
|
186 |
+
cos.stride(-2),
|
187 |
+
sin,
|
188 |
+
sin.stride(-2),
|
189 |
+
seq_len,
|
190 |
+
batch_size,
|
191 |
+
cos_batch_size,
|
192 |
+
n_q_head,
|
193 |
+
n_kv_head,
|
194 |
+
head_dim,
|
195 |
+
pad_n_q_head,
|
196 |
+
pad_n_kv_head,
|
197 |
+
pad_hd,
|
198 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
199 |
+
BACKWARD_PASS=True,
|
200 |
+
)
|
201 |
+
return dq.transpose(1, 2), dk.transpose(1, 2)
|
202 |
+
|
203 |
+
|
204 |
+
class LigerRopeFunction(torch.autograd.Function):
|
205 |
+
"""
|
206 |
+
Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that
|
207 |
+
this implements the HuggingFace Llama & Mistral version, whose rotation matrix is slightly different
|
208 |
+
than the original RoPE paper.
|
209 |
+
|
210 |
+
Please find the corresponding HuggingFace implementation here:
|
211 |
+
https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184
|
212 |
+
|
213 |
+
For more details about the rotation matrix used here, please refer to:
|
214 |
+
https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2
|
215 |
+
"""
|
216 |
+
|
217 |
+
@staticmethod
|
218 |
+
def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
219 |
+
"""
|
220 |
+
q size: (bsz, n_q_head, seq_len, head_dim)
|
221 |
+
k size: (bsz, n_kv_head, seq_len, head_dim)
|
222 |
+
cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
223 |
+
sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
224 |
+
"""
|
225 |
+
q, k, cos, sin = rope_forward(q, k, cos, sin)
|
226 |
+
ctx.save_for_backward(cos, sin)
|
227 |
+
return q, k
|
228 |
+
|
229 |
+
def backward(ctx, dq, dk):
|
230 |
+
"""
|
231 |
+
dq size: (bsz, n_q_head, seq_len, head_dim)
|
232 |
+
dk size: (bsz, n_kv_head, seq_len, head_dim)
|
233 |
+
cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
234 |
+
sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
235 |
+
"""
|
236 |
+
|
237 |
+
cos, sin = ctx.saved_tensors
|
238 |
+
dq, dk = rope_backward(dq, dk, cos, sin)
|
239 |
+
return dq, dk, None, None, None, None
|
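The snippet below is a minimal usage sketch (not part of the uploaded file) showing how LigerRopeFunction can be exercised end to end, assuming a Triton-capable GPU and the tensor shapes documented in forward; the random cos/sin values stand in for whatever rotary-embedding module the model actually provides.

import torch

bsz, n_q_head, n_kv_head, seq_len, head_dim = 2, 8, 2, 128, 64
q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda", requires_grad=True)
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda", requires_grad=True)

# cos/sin normally come from the model's rotary embedding; random values with the
# documented (1, seq_len, head_dim) shape are used here purely for illustration.
cos = torch.randn(1, seq_len, head_dim, device="cuda")
sin = torch.randn(1, seq_len, head_dim, device="cuda")

q_rot, k_rot = LigerRopeFunction.apply(q, k, cos, sin)
(q_rot.sum() + k_rot.sum()).backward()  # drives rope_backward through autograd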
torch-ext/liger_kernels/swiglu.py
ADDED
@@ -0,0 +1,116 @@
import torch
import triton
import triton.language as tl

from utils import calculate_settings
from utils import ensure_contiguous


@triton.jit
def silu(x):
    return x * tl.sigmoid(x)


@triton.jit
def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    program_id = tl.program_id(0).to(tl.int64)

    # locate start index
    a_ptr += program_id * stride
    b_ptr += program_id * stride
    c_ptr += program_id * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # sigmoid requires type float32
    a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
    c_row = silu(a_row) * b_row
    tl.store(c_ptr + col_offsets, c_row, mask=mask)


@triton.jit
def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    program_id = tl.program_id(0).to(tl.int64)

    # locate start index
    dc_ptr += program_id * stride
    a_ptr += program_id * stride
    b_ptr += program_id * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)
    # sigmoid requires type float32
    a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)

    # recomputation to save memory
    sig_a = tl.sigmoid(a_row)
    silu_a = a_row * sig_a
    db_row = dc_row * silu_a
    da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row

    tl.store(a_ptr + col_offsets, da_row, mask=mask)
    tl.store(b_ptr + col_offsets, db_row, mask=mask)


def swiglu_forward(a, b):
    ori_shape = a.shape

    n_cols = ori_shape[-1]
    a = a.view(-1, n_cols)
    b = b.view(-1, n_cols)
    c = torch.empty_like(a)
    n_rows = a.shape[0]

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    _swiglu_forward_kernel[(n_rows,)](
        a,
        b,
        c,
        c.stride(-2),
        n_cols=n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return a, b, c.view(*ori_shape)


def swiglu_backward(a, b, dc):
    ori_shape = dc.shape
    n_cols = ori_shape[-1]
    dc = dc.view(-1, n_cols)
    n_rows = dc.shape[0]

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    _swiglu_backward_kernel[(n_rows,)](
        dc,
        a,
        b,
        dc.stride(-2),
        n_cols=n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return a.view(*ori_shape), b.view(*ori_shape)


class LigerSiLUMulFunction(torch.autograd.Function):
    @staticmethod
    @ensure_contiguous
    def forward(ctx, a, b):
        a, b, c = swiglu_forward(a, b)
        ctx.save_for_backward(a, b)
        return c

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dc):
        a, b = ctx.saved_tensors
        a, b = swiglu_backward(a, b, dc)
        return a, b
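A small sanity-check sketch (again not part of the file) for LigerSiLUMulFunction: the fused kernel should match the eager silu(a) * b reference; the shapes are arbitrary and a CUDA device is assumed.

import torch
import torch.nn.functional as F

a = torch.randn(4, 128, 4096, device="cuda", dtype=torch.float32)
b = torch.randn(4, 128, 4096, device="cuda", dtype=torch.float32)

out = LigerSiLUMulFunction.apply(a, b)  # fused Triton path
ref = F.silu(a) * b                     # eager reference
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)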
torch-ext/liger_kernels/tvd.py
ADDED
@@ -0,0 +1,207 @@
from typing import Literal
from typing import Optional

import torch
import triton
import triton.language as tl

from utils import ensure_contiguous

MAX_FUSED_SIZE = 65536 // 4

REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]

_REDUCTION_MODE_NONE = tl.constexpr(0)
_REDUCTION_MODE_SUM = tl.constexpr(1)
_REDUCTION_MODE_MEAN = tl.constexpr(2)
_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)

_str_to_reduction_mode = {
    "none": _REDUCTION_MODE_NONE.value,
    "sum": _REDUCTION_MODE_SUM.value,
    "mean": _REDUCTION_MODE_MEAN.value,
    "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
}


def get_num_warps(BLOCK_SIZE):
    num_warps = 4
    if BLOCK_SIZE >= 32768:
        num_warps = 32
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8

    return num_warps


@triton.jit
def _tv_distance_kernel(
    p_ptr,
    p_stride,
    q_ptr,
    q_stride,
    loss_ptr,
    loss_stride,
    grads_ptr,
    grads_stride,
    label_ptr,
    ignore_index: tl.constexpr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    HAS_LABEL: tl.constexpr,
    reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
):
    pid = tl.program_id(0).to(tl.int64)
    p_ptr += pid * p_stride
    q_ptr += pid * q_stride
    loss_ptr += pid * loss_stride
    grads_ptr += pid * grads_stride
    label_ptr += pid

    base_offsets = tl.arange(0, BLOCK_SIZE)

    if HAS_LABEL:
        label = tl.load(label_ptr)
        if label == ignore_index:
            for i in range(0, n_cols, BLOCK_SIZE):
                offsets = i + base_offsets
                mask = offsets < n_cols
                tl.store(grads_ptr + offsets, 0.0, mask=mask)
                if reduction == _REDUCTION_MODE_NONE:
                    tl.store(loss_ptr + offsets, 0.0, mask=mask)
            return

    loss_sum = 0.0
    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + base_offsets
        mask = offsets < n_cols

        p = tl.load(p_ptr + offsets, mask=mask, other=0.0)
        q = tl.load(q_ptr + offsets, mask=mask, other=0.0)

        # TVD(P || Q) = 0.5 * |P - Q|
        tv_loss = 0.5 * tl.abs(p - q)

        grad_res = tl.where(p > q, 0.5, -0.5)

        tl.store(grads_ptr + offsets, grad_res, mask=mask)

        if reduction == _REDUCTION_MODE_NONE:
            tl.store(loss_ptr + offsets, tv_loss, mask=mask)
        else:
            loss_sum += tl.sum(tv_loss, axis=0)

    if reduction != _REDUCTION_MODE_NONE:
        tl.store(loss_ptr, loss_sum)


def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
    BT, V = p.shape

    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    num_warps = get_num_warps(BLOCK_SIZE)

    grid = (BT,)

    reduction = _str_to_reduction_mode[reduction]

    out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
    output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
    grads = torch.empty_like(p)

    n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT

    _tv_distance_kernel[grid](
        p,
        p.stride(0),
        q,
        q.stride(0),
        output_tensor,
        output_tensor.stride(0),
        grads,
        grads.stride(0),
        shift_labels if has_label else torch.empty(1, device=p.device),
        ignore_index,
        V,
        BLOCK_SIZE=BLOCK_SIZE,
        HAS_LABEL=has_label,
        num_warps=num_warps,
        reduction=reduction,
    )

    if reduction == _REDUCTION_MODE_BATCHMEAN.value:
        return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
    elif reduction == _REDUCTION_MODE_SUM.value:
        return output_tensor.sum(dim=0), grads
    elif reduction == _REDUCTION_MODE_MEAN.value:
        return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
    else:
        return output_tensor, grads


def tvd_backward_triton(grad_output, grads):
    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then.
    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        return grads

    return grads * grad_output


class LigerTVDLossFunction(torch.autograd.Function):
    """
    Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        p: torch.Tensor,
        q: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
        reduction: REDUCTION_LITERAL = "batchmean",
        ignore_index: int = -100,
    ) -> torch.Tensor:
        """A forward pass for the Total Variation Distance Loss.

        Args:
            ctx: Torch autograd context
            p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
            q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
            shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
            reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
            ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.

        Returns:
            torch.Tensor: The computed Total Variation Distance Loss.
        """
        has_label = False
        if shift_labels is not None:
            assert shift_labels.shape == (p.shape[0],), (
                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
            )
            shift_labels = shift_labels.contiguous()
            has_label = True

        loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
        ctx.save_for_backward(grads)
        return loss

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """A backward pass for the Total Variation Distance Loss.

        Args:
            ctx: Torch autograd context
            grad_output (torch.Tensor): The gradient of the loss with respect to the output.

        Returns:
            tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs.
        """
        (grads,) = ctx.saved_tensors
        grads = tvd_backward_triton(grad_output, grads)

        return grads, None, None, None, None
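As a hedged illustration (not part of the file), LigerTVDLossFunction with the default "batchmean" reduction can be checked against the eager definition 0.5 * |p - q| summed over the vocabulary and averaged over the batch dimension:

import torch

BT, V = 8, 1000
p = torch.randn(BT, V, device="cuda").softmax(dim=-1)
q = torch.randn(BT, V, device="cuda").softmax(dim=-1)

loss = LigerTVDLossFunction.apply(p, q)  # reduction defaults to "batchmean"
ref = (0.5 * (p - q).abs()).sum() / BT   # eager batchmean reference
torch.testing.assert_close(loss, ref, rtol=1e-4, atol=1e-5)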
torch-ext/liger_kernels/utils.py
ADDED
@@ -0,0 +1,135 @@
"""
This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
See the original Unsloth repository at https://github.com/unslothai/unsloth.

The following line
https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23
is based on code from Unsloth, located at:
https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43

Modifications made by Yanning Chen, 2024.
"""

import functools
import importlib
import operator

from typing import Callable

import torch
import triton
import triton.language as tl

from packaging.version import Version


def infer_device():
    """
    Get current device name based on available devices
    """
    if torch.cuda.is_available():  # Works for both Nvidia and AMD
        return "cuda"
    elif torch.xpu.is_available():
        return "xpu"
    else:
        return "cpu"


def is_hip() -> bool:
    return torch.version.hip is not None


def ensure_contiguous(fn):
    @functools.wraps(fn)
    def wrapper(ctx, *args, **kwargs):
        def maybe_to_contiguous(x):
            return x.contiguous() if isinstance(x, torch.Tensor) else x

        args = [maybe_to_contiguous(arg) for arg in args]
        kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()}
        return fn(ctx, *args, **kwargs)

    return wrapper


def calculate_settings(n):
    # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43

    MAX_FUSED_SIZE = 65536
    BLOCK_SIZE = triton.next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(
            f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}."
        )

    num_warps = 4
    if BLOCK_SIZE >= 32768:
        num_warps = 32 if not is_hip() else 16
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8
    return BLOCK_SIZE, num_warps


def compare_version(package: str, operator: Callable, target: str):
    try:
        pkg = importlib.import_module(package)
    except ImportError:
        return False
    pkg_version = Version(pkg.__version__)
    return operator(pkg_version, Version(target))


def get_amp_custom_fwd_bwd() -> Callable:
    device = infer_device()
    if compare_version("torch", operator.ge, "2.4.0"):
        return (
            functools.partial(torch.amp.custom_fwd, device_type=device),
            functools.partial(torch.amp.custom_bwd, device_type=device),
        )
    return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd


amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()


torch_to_triton_dtype = {
    torch.float32: tl.float32,
    torch.float16: tl.float16,
    torch.bfloat16: tl.bfloat16,
}


@triton.jit
def element_mul_kernel(
    X_ptr,
    X_stride,
    grad_output_ptr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
    The multiplication is performed in-place on the tensor pointed by X_ptr.

    Parameters:
    X_ptr: Pointer to the input tensor.
    X_stride (int): The stride of the input tensor.
    grad_output_ptr: Pointer to the gradient output value.
    n_cols (int): The number of columns in the input tensor.
    BLOCK_SIZE (int): The block size for Triton operations.
    """

    # Get the program ID and convert it to int64 to avoid overflow
    program_id = tl.program_id(0).to(tl.int64)

    # Locate the start index
    X_ptr += program_id * X_stride

    # Load the gradient output value
    grad_output = tl.load(grad_output_ptr)

    # Perform the element-wise multiplication
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
        tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
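To show how the helpers above fit together (illustrative only; the tensor sizes and scale value are made up), calculate_settings picks a power-of-two block size and matching warp count for a row-wise launch, and element_mul_kernel then scales each row in place by a scalar, which is how grad_output is folded into saved gradients elsewhere in this package (e.g. the cross-entropy backward):

import torch
import triton

x = torch.randn(512, 4096, device="cuda")
n_rows, n_cols = x.shape

# next power-of-two block size plus a warp count tuned to it
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
assert BLOCK_SIZE == triton.next_power_of_2(n_cols)

# multiply every row of x in place by the scalar grad_output
grad_output = torch.tensor(0.5, device="cuda")
element_mul_kernel[(n_rows,)](
    x,
    x.stride(0),
    grad_output,
    n_cols,
    BLOCK_SIZE=BLOCK_SIZE,
    num_warps=num_warps,
)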