medmekk HF Staff commited on
Commit
fd5a517
·
verified ·
1 Parent(s): a990c21

Upload custom kernels

Browse files
Files changed (35) hide show
  1. README.md +0 -0
  2. build.toml +5 -0
  3. build/torch-universal/liger_kernels/__init__.py +29 -0
  4. build/torch-universal/liger_kernels/_ops.py +8 -0
  5. build/torch-universal/liger_kernels/cross_entropy.py +460 -0
  6. build/torch-universal/liger_kernels/dyt.py +225 -0
  7. build/torch-universal/liger_kernels/fused_linear_cross_entropy.py +283 -0
  8. build/torch-universal/liger_kernels/geglu.py +141 -0
  9. build/torch-universal/liger_kernels/group_norm.py +305 -0
  10. build/torch-universal/liger_kernels/jsd.py +201 -0
  11. build/torch-universal/liger_kernels/kl_div.py +262 -0
  12. build/torch-universal/liger_kernels/layer_norm.py +265 -0
  13. build/torch-universal/liger_kernels/qwen2vl_mrope.py +222 -0
  14. build/torch-universal/liger_kernels/rms_norm.py +365 -0
  15. build/torch-universal/liger_kernels/rope.py +239 -0
  16. build/torch-universal/liger_kernels/swiglu.py +116 -0
  17. build/torch-universal/liger_kernels/tvd.py +207 -0
  18. build/torch-universal/liger_kernels/utils.py +135 -0
  19. flake.lock +117 -0
  20. flake.nix +17 -0
  21. torch-ext/liger_kernels/__init__.py +29 -0
  22. torch-ext/liger_kernels/cross_entropy.py +460 -0
  23. torch-ext/liger_kernels/dyt.py +225 -0
  24. torch-ext/liger_kernels/fused_linear_cross_entropy.py +283 -0
  25. torch-ext/liger_kernels/geglu.py +141 -0
  26. torch-ext/liger_kernels/group_norm.py +305 -0
  27. torch-ext/liger_kernels/jsd.py +201 -0
  28. torch-ext/liger_kernels/kl_div.py +262 -0
  29. torch-ext/liger_kernels/layer_norm.py +265 -0
  30. torch-ext/liger_kernels/qwen2vl_mrope.py +222 -0
  31. torch-ext/liger_kernels/rms_norm.py +365 -0
  32. torch-ext/liger_kernels/rope.py +239 -0
  33. torch-ext/liger_kernels/swiglu.py +116 -0
  34. torch-ext/liger_kernels/tvd.py +207 -0
  35. torch-ext/liger_kernels/utils.py +135 -0
README.md ADDED
File without changes
build.toml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [general]
2
+ name = "liger_kernels"
3
+
4
+ [torch]
5
+ universal = true
build/torch-universal/liger_kernels/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from cross_entropy import LigerCrossEntropyFunction
2
+ from fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
3
+ from dyt import LigerDyTFunction
4
+ from geglu import LigerGELUMulFunction
5
+ from group_norm import LigerGroupNormFunction
6
+ from kl_div import LigerKLDivLossFunction
7
+ from layer_norm import LigerLayerNormFunction
8
+ from qwen2vl_mrope import LigerQwen2VLMRopeFunction
9
+ from rms_norm import LigerRMSNormFunction
10
+ from jsd import LigerJSDFunction
11
+ from rope import LigerRopeFunction
12
+ from swiglu import LigerSiLUMulFunction
13
+ from tvd import LigerTVDLossFunction
14
+
15
+ __all__ = [
16
+ "LigerCrossEntropyFunction",
17
+ "LigerFusedLinearCrossEntropyFunction",
18
+ "LigerDyTFunction",
19
+ "LigerGELUMulFunction",
20
+ "LigerGroupNormFunction",
21
+ "LigerKLDivLossFunction",
22
+ "LigerLayerNormFunction",
23
+ "LigerQwen2VLMRopeFunction",
24
+ "LigerRMSNormFunction",
25
+ "LigerJSDFunction",
26
+ "LigerRopeFunction",
27
+ "LigerSiLUMulFunction",
28
+ "LigerTVDLossFunction",
29
+ ]
build/torch-universal/liger_kernels/_ops.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ ops = torch.ops._liger_kernels_20250505094412
3
+
4
+ def add_op_namespace_prefix(op_name: str):
5
+ """
6
+ Prefix op by namespace.
7
+ """
8
+ return f"_liger_kernels_20250505094412::{op_name}"
build/torch-universal/liger_kernels/cross_entropy.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import operator
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import triton
7
+ import triton.language as tl
8
+
9
+ from utils import compare_version
10
+ from utils import element_mul_kernel
11
+ from utils import is_hip
12
+ from utils import infer_device
13
+
14
+ if compare_version("triton", operator.ge, "3.0.0"):
15
+ try:
16
+ # typical import path with dispatch available
17
+ from triton.language.extra.libdevice import tanh
18
+ except ModuleNotFoundError:
19
+ # for working with NGC containers
20
+ from triton.language.extra.cuda.libdevice import tanh
21
+ else:
22
+ from triton.language.math import tanh
23
+
24
+
25
+ @triton.jit
26
+ def liger_cross_entropy_kernel(
27
+ X_ptr,
28
+ X_stride,
29
+ Y_ptr,
30
+ Y_stride,
31
+ weight_ptr,
32
+ loss_ptr,
33
+ z_loss_ptr,
34
+ loss_stride,
35
+ n_cols,
36
+ n_non_ignore,
37
+ sum_non_ignore_weight,
38
+ weight_sum,
39
+ ignore_index,
40
+ lse_square_scale: tl.constexpr,
41
+ label_smoothing: tl.constexpr,
42
+ reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
43
+ softcap,
44
+ RETURN_Z_LOSS: tl.constexpr,
45
+ BLOCK_SIZE: tl.constexpr,
46
+ HAS_WEIGHT: tl.constexpr,
47
+ HAS_SOFTCAPPING: tl.constexpr,
48
+ ):
49
+ """
50
+ This kernel computes both cross entropy loss and the gradient of the input.
51
+ We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.
52
+
53
+ Parameters:
54
+ X_ptr: Pointer to input tensor.
55
+ X_stride (int): The stride of the input tensor.
56
+ Y_ptr: Pointer to target tensor.
57
+ Y_stride (int): The stride of the target tensor.
58
+ weight_ptr: Pointer to weight tensor.
59
+ loss_ptr: Pointer to tensor to store the loss.
60
+ z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
61
+ loss_stride (int): The stride of the loss tensor.
62
+ n_cols (int): The number of columns in the input tensor.
63
+ n_non_ignore (float): The number of non-ignored elements in the batch.
64
+ sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
65
+ weight_sum (float): The sum of weight tensor.
66
+ ignore_index (int): The index to ignore in the target.
67
+ label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
68
+ lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
69
+ reduction (str): The string for the reduction to apply
70
+ softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
71
+ RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
72
+ BLOCK_SIZE (int): The block size for Triton operations.
73
+ HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
74
+ HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
75
+ """
76
+
77
+ # https://github.com/triton-lang/triton/issues/1058
78
+ # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64
79
+ program_id = tl.program_id(0).to(tl.int64)
80
+
81
+ # 1. Load Y_ptr first because if the target is ignore_index, we can return right away
82
+ Y_ptr += program_id * Y_stride
83
+ y = tl.load(Y_ptr)
84
+
85
+ # 2. locate the start index
86
+ X_ptr += program_id * X_stride
87
+
88
+ if y == ignore_index:
89
+ # set all X_ptr as 0
90
+ for i in range(0, n_cols, BLOCK_SIZE):
91
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
92
+ tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
93
+ return
94
+
95
+ loss_ptr += program_id * loss_stride
96
+ if RETURN_Z_LOSS:
97
+ z_loss_ptr += program_id * loss_stride
98
+
99
+ if HAS_WEIGHT:
100
+ weight_y = tl.load(weight_ptr + y).cast(tl.float32)
101
+
102
+ # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
103
+ # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
104
+
105
+ # 3. [Online softmax] first pass: find max + sum
106
+ m = float("-inf") # m is the max value. use the notation from the paper
107
+ d = 0.0 # d is the sum. use the notation from the paper
108
+ ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
109
+ if HAS_SOFTCAPPING:
110
+ ori_X_y = softcap * tanh(ori_X_y / softcap)
111
+
112
+ # Label smoothing is a general case of normal cross entropy
113
+ # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
114
+ scaled_x_sum = 0.0
115
+ eps = label_smoothing / n_cols
116
+
117
+ for i in range(0, n_cols, BLOCK_SIZE):
118
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
119
+ X_block = tl.load(
120
+ X_ptr + X_offsets,
121
+ mask=X_offsets < n_cols,
122
+ other=float("-inf"),
123
+ # Ensure float32 precision for softmax calculation
124
+ ).cast(tl.float32)
125
+ if HAS_SOFTCAPPING:
126
+ X_block = softcap * tanh(X_block / softcap)
127
+ block_max = tl.max(X_block)
128
+ if label_smoothing > 0:
129
+ # scale X beforehand to avoid overflow
130
+ if HAS_WEIGHT:
131
+ weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
132
+ scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0))
133
+ else:
134
+ scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
135
+ m_new = tl.maximum(m, block_max)
136
+ d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
137
+ m = m_new
138
+
139
+ # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X)))))
140
+ # = log (e^(max(X)) * sum(e ^ (X_i - max(X))))
141
+ # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
142
+ lse = m + tl.log(d)
143
+
144
+ # 4. [Online Softmax] Second pass: compute gradients
145
+ # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
146
+ # dx_y = (softmax(x_y) - 1) / N
147
+ # dx_i = softmax(x_i) / N, i != y
148
+ # For label smoothing:
149
+ # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
150
+ # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
151
+ # = dx_i - (1 - label_smoothing) / N
152
+ # With Z loss:
153
+ # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y
154
+ # dx_y = dx_i - (1 - label_smoothing) / N
155
+ # For 'sum' reduction, no normalization is applied:
156
+ # dx_y = softmax(x_y) - 1
157
+ # dx_i = softmax(x_i), for i ≠ y
158
+
159
+ for i in range(0, n_cols, BLOCK_SIZE):
160
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
161
+ X_block = tl.load(
162
+ X_ptr + X_offsets,
163
+ mask=X_offsets < n_cols,
164
+ other=float("-inf"),
165
+ # Ensure float32 precision for softmax calculation
166
+ ).cast(tl.float32)
167
+ if HAS_SOFTCAPPING:
168
+ intermediate = tanh(X_block / softcap)
169
+ X_block = softcap * intermediate
170
+
171
+ if not HAS_WEIGHT:
172
+ # softmax(x_i)
173
+ X_block = tl.exp(X_block - m) / d
174
+ # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
175
+ X_block += 2 * lse_square_scale * lse * X_block
176
+ # smoothing term
177
+ X_block += -eps
178
+ # special handle dx_y
179
+ X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
180
+ # reduction scale
181
+ if reduction == "mean":
182
+ X_block = X_block / n_non_ignore
183
+ else:
184
+ weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
185
+ softmax_X = tl.exp(X_block - m) / d
186
+ # derivative of original_loss
187
+ dloss_ori = (1 - label_smoothing) * softmax_X
188
+ # specially handle dx_y
189
+ dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
190
+ dloss_ori = dloss_ori * weight_y
191
+ # derivative of smooth_loss
192
+ dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
193
+ # derivative of z-loss
194
+ dz_loss = 2 * lse_square_scale * lse * softmax_X
195
+ # reduction scale
196
+ if reduction == "mean":
197
+ dloss_ori = dloss_ori / sum_non_ignore_weight
198
+ dloss_smooth = dloss_smooth / sum_non_ignore_weight
199
+ # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
200
+ dz_loss = dz_loss / n_non_ignore
201
+ # derivative of total_loss
202
+ X_block = dloss_ori + dloss_smooth + dz_loss
203
+
204
+ # chain rule softcapping
205
+ # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
206
+ if HAS_SOFTCAPPING:
207
+ X_block = X_block * (1 - intermediate * intermediate)
208
+
209
+ tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
210
+
211
+ # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
212
+ # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
213
+ tl.debug_barrier()
214
+
215
+ # 5. Calculate the loss
216
+
217
+ # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
218
+ # = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
219
+ # = X_y - m - log d = X_y - lse
220
+ # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
221
+ # So we can safely calculate log (softmax(X_y)) without overflow
222
+ loss = lse - ori_X_y
223
+ if HAS_WEIGHT:
224
+ loss = weight_y * loss
225
+
226
+ # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
227
+ # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
228
+ # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
229
+ # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
230
+ # = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
231
+ # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
232
+ # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
233
+ # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
234
+ if label_smoothing > 0:
235
+ if HAS_WEIGHT:
236
+ smooth_loss = scaled_x_sum + eps * lse * weight_sum
237
+ else:
238
+ smooth_loss = scaled_x_sum + label_smoothing * lse
239
+ loss = loss * (1 - label_smoothing) + smooth_loss
240
+
241
+ # An auxiliary loss, z_loss
242
+ # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
243
+ z_loss = lse_square_scale * lse * lse
244
+ # Normalize the loss by the number of non-ignored elements if reduction is "mean"
245
+ if reduction == "mean":
246
+ if HAS_WEIGHT:
247
+ loss = loss / sum_non_ignore_weight
248
+ else:
249
+ loss = loss / n_non_ignore
250
+ # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
251
+ z_loss = z_loss / n_non_ignore
252
+ loss += z_loss
253
+
254
+ tl.store(loss_ptr, loss)
255
+ if RETURN_Z_LOSS:
256
+ tl.store(z_loss_ptr, z_loss)
257
+
258
+
259
+ # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
260
+ # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
261
+ # The optimal maximum block size depends on your hardware, your kernel, and your dtype
262
+ MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2 # the best size we found by manually tuning
263
+
264
+
265
+ def cross_entropy_forward(
266
+ _input,
267
+ target,
268
+ weight,
269
+ ignore_index,
270
+ lse_square_scale,
271
+ label_smoothing,
272
+ reduction,
273
+ softcap,
274
+ return_z_loss,
275
+ ):
276
+ assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
277
+
278
+ BT, V = _input.shape
279
+ n_rows = BT
280
+
281
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
282
+
283
+ # unreduced loss
284
+ loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
285
+ z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
286
+
287
+ target_mask = target != ignore_index
288
+ n_non_ignore = target_mask.sum().item()
289
+ assert (target * target_mask).max() < _input.shape[-1], (
290
+ f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}"
291
+ )
292
+ assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0"
293
+ sum_non_ignore_weight = n_non_ignore
294
+ weight_sum = 0.0
295
+ if weight is not None:
296
+ assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
297
+ assert torch.is_floating_point(weight), (
298
+ f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
299
+ )
300
+ sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
301
+ weight_sum = weight.sum().item()
302
+ # ensure weight is contiguous
303
+ if weight.stride(-1) != 1:
304
+ weight = weight.contiguous()
305
+
306
+ # ensure _input and target are contiguous in the last dimension
307
+ if _input.stride(-1) != 1:
308
+ _input = _input.contiguous()
309
+ if target.stride(-1) != 1:
310
+ target = target.contiguous()
311
+
312
+ # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
313
+ liger_cross_entropy_kernel[(n_rows,)](
314
+ X_ptr=_input,
315
+ X_stride=_input.stride(-2),
316
+ Y_ptr=target,
317
+ Y_stride=target.stride(-1), # always 1
318
+ weight_ptr=weight, # dummy if None
319
+ loss_ptr=loss_1d,
320
+ z_loss_ptr=z_loss_1d,
321
+ loss_stride=loss_1d.stride(-1), # always 1
322
+ n_cols=V,
323
+ n_non_ignore=n_non_ignore,
324
+ sum_non_ignore_weight=sum_non_ignore_weight,
325
+ ignore_index=ignore_index,
326
+ weight_sum=weight_sum,
327
+ lse_square_scale=lse_square_scale,
328
+ label_smoothing=label_smoothing,
329
+ reduction=reduction,
330
+ softcap=softcap,
331
+ RETURN_Z_LOSS=return_z_loss,
332
+ BLOCK_SIZE=BLOCK_SIZE,
333
+ HAS_WEIGHT=True if weight is not None else False,
334
+ HAS_SOFTCAPPING=True if softcap is not None else False,
335
+ # TODO: 32 seems to give the best performance
336
+ # Performance is quite sensitive to num_warps
337
+ num_warps=32 if not is_hip() else 16,
338
+ )
339
+
340
+ if reduction == "none":
341
+ loss = loss_1d
342
+ z_loss = z_loss_1d if return_z_loss else None
343
+ else:
344
+ loss = torch.sum(loss_1d)
345
+ z_loss = torch.sum(z_loss_1d) if return_z_loss else None
346
+
347
+ return loss, z_loss, _input
348
+
349
+
350
+ def cross_entropy_backward(_input, grad_output):
351
+ # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
352
+ if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
353
+ pass
354
+ # If reduction is 'none'
355
+ elif grad_output.ndim > 0:
356
+ _input = _input * grad_output.unsqueeze(dim=1)
357
+ # If reduction is ['mean', 'sum'], grad_output is just a scalar
358
+ # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
359
+ # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
360
+ else:
361
+ BT, V = _input.shape
362
+ n_rows = BT
363
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
364
+
365
+ element_mul_kernel[(n_rows,)](
366
+ _input,
367
+ _input.stride(-2),
368
+ grad_output,
369
+ V,
370
+ BLOCK_SIZE=BLOCK_SIZE,
371
+ num_warps=32 if not is_hip() else 16,
372
+ )
373
+
374
+ return _input
375
+
376
+
377
+ class LigerCrossEntropyFunction(torch.autograd.Function):
378
+ """
379
+ This class implements a custom autograd function for the Liger Cross Entropy loss.
380
+ It overrides the forward and backward methods of the torch.autograd.Function class.
381
+ """
382
+
383
+ @staticmethod
384
+ def forward(
385
+ ctx,
386
+ _input: torch.Tensor,
387
+ target: torch.Tensor,
388
+ weight: Optional[torch.FloatTensor],
389
+ ignore_index: int = -100,
390
+ lse_square_scale: float = 0.0,
391
+ label_smoothing: float = 0.0,
392
+ reduction: str = "mean",
393
+ softcap: Optional[float] = None,
394
+ return_z_loss: bool = False,
395
+ ):
396
+ """
397
+ The forward pass of the Liger Cross Entropy loss.
398
+
399
+ Parameters:
400
+ ctx : The context object.
401
+ _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
402
+ target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
403
+ weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
404
+ ignore_index (int): The index to ignore in the target.
405
+ lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
406
+ label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
407
+ reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
408
+ softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
409
+ return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
410
+
411
+ Returns:
412
+ tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
413
+ """
414
+ loss, z_loss, _input = cross_entropy_forward(
415
+ _input,
416
+ target,
417
+ weight,
418
+ ignore_index,
419
+ lse_square_scale,
420
+ label_smoothing,
421
+ reduction,
422
+ softcap,
423
+ return_z_loss,
424
+ )
425
+ # TODO: investigation
426
+ # If we don't detach the _input tensor, the memory will double
427
+ # Not sure why but seems that there will be a time both grad and value exist but in different location
428
+ ctx.save_for_backward(_input.detach())
429
+ ctx.return_z_loss = return_z_loss
430
+
431
+ return loss, z_loss
432
+
433
+ @staticmethod
434
+ def backward(ctx, grad_output, grad_ouput2):
435
+ """
436
+ The backward pass of the Liger Cross Entropy loss.
437
+
438
+ Parameters:
439
+ ctx : The context object with saved tensors.
440
+ grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
441
+ grad_output2 (tenosr): No use.
442
+ Returns:
443
+ tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
444
+ """
445
+ if ctx.return_z_loss:
446
+ del grad_ouput2 # z_loss is only for logging
447
+
448
+ (_input,) = ctx.saved_tensors
449
+ _input = cross_entropy_backward(_input, grad_output)
450
+ return (
451
+ _input,
452
+ None,
453
+ None,
454
+ None,
455
+ None,
456
+ None,
457
+ None,
458
+ None,
459
+ None,
460
+ )
build/torch-universal/liger_kernels/dyt.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import operator
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from utils import calculate_settings
8
+ from utils import compare_version
9
+ from utils import ensure_contiguous
10
+ from utils import infer_device
11
+
12
+ if compare_version("triton", operator.ge, "3.0.0"):
13
+ try:
14
+ # typical import path with dispatch available
15
+ from triton.language.extra.libdevice import tanh
16
+ except ModuleNotFoundError:
17
+ # for working with NGC containers
18
+ from triton.language.extra.cuda.libdevice import tanh
19
+ else:
20
+ from triton.language.math import tanh
21
+
22
+
23
+ @triton.jit
24
+ def _dyt_fwd_kernel(
25
+ x_ptr,
26
+ x_row_stride,
27
+ alpha_ptr,
28
+ gamma_ptr,
29
+ beta_ptr,
30
+ y_ptr,
31
+ y_row_stride,
32
+ n_cols,
33
+ BLOCK_SIZE: tl.constexpr,
34
+ ):
35
+ """
36
+ Reference:
37
+ https://arxiv.org/abs/2503.10622
38
+
39
+ Shapes:
40
+ - x: (BT, C)
41
+ - alpha: (1)
42
+ - gamma: (C)
43
+ - beta: (C)
44
+ """
45
+ row_idx = tl.program_id(0)
46
+ offsets = tl.arange(0, BLOCK_SIZE)
47
+ mask = offsets < n_cols
48
+
49
+ x_ptr += row_idx * x_row_stride
50
+ y_ptr += row_idx * y_row_stride
51
+
52
+ alpha = tl.load(alpha_ptr)
53
+ gamma = tl.load(gamma_ptr + offsets, mask=mask)
54
+ beta = tl.load(beta_ptr + offsets, mask=mask)
55
+ x = tl.load(x_ptr + offsets, mask=mask)
56
+ y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
57
+ tl.store(y_ptr + offsets, y, mask=mask)
58
+
59
+
60
+ @triton.jit
61
+ def _dyt_bwd_kernel(
62
+ x_ptr,
63
+ x_row_stride,
64
+ dy_ptr,
65
+ dy_row_stride,
66
+ dx_ptr,
67
+ dx_row_stride,
68
+ alpha_ptr,
69
+ dalpha_ptr,
70
+ gamma_ptr,
71
+ dgamma_ptr,
72
+ dgamma_row_stride,
73
+ n_cols,
74
+ n_rows,
75
+ ROWS_PER_PROGRAM: tl.constexpr,
76
+ BLOCK_SIZE: tl.constexpr,
77
+ ):
78
+ """
79
+ Reference:
80
+ https://arxiv.org/abs/2503.10622
81
+
82
+ Shapes:
83
+ - x: (BT, C)
84
+ - alpha: (1)
85
+ - gamma: (C)
86
+ - dx: (BT, C)
87
+ - dy: (BT, C)
88
+ - dgamma: (sm_count, C)
89
+ - dalpha: (sm_count,)
90
+ """
91
+ # d(gamma * tanh(alpha * x) + beta) / dx
92
+ # = gamma * (1 - tanh^2(alpha * x)) * alpha
93
+ # d(gamma * tanh(alpha * x) + beta) / dalpha
94
+ # = gamma * (1 - tanh^2(alpha * x)) * x
95
+ # d(gamma * tanh(alpha * x) + beta) / dgamma
96
+ # = tanh(alpha * x)
97
+ # d(gamma * tanh(alpha * x)) / dbeta = 1
98
+ pid = tl.program_id(0)
99
+
100
+ row_start = pid * ROWS_PER_PROGRAM
101
+ row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
102
+ offsets = tl.arange(0, BLOCK_SIZE)
103
+ mask = offsets < n_cols
104
+
105
+ dalpha = 0.0
106
+ dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
107
+
108
+ x_ptr += row_start * x_row_stride
109
+ dx_ptr += row_start * dx_row_stride
110
+ dy_ptr += row_start * dy_row_stride
111
+ alpha = tl.load(alpha_ptr)
112
+ gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
113
+
114
+ for _ in tl.range(row_start, row_end):
115
+ dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
116
+ x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
117
+ tanh_ax = tanh((alpha * x).cast(tl.float32))
118
+ sech2_ax = 1 - tanh_ax * tanh_ax
119
+
120
+ dx = dy * gamma * sech2_ax * alpha
121
+ dalpha += tl.sum(dy * gamma * sech2_ax * x)
122
+ dgamma += dy * tanh_ax
123
+ tl.store(dx_ptr + offsets, dx, mask=mask)
124
+
125
+ dy_ptr += dy_row_stride
126
+ x_ptr += x_row_stride
127
+ dx_ptr += dx_row_stride
128
+
129
+ tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
130
+ tl.store(dalpha_ptr + pid, dalpha)
131
+
132
+ pass
133
+
134
+
135
+ def liger_dyt_fwd(x, alpha, gamma, beta):
136
+ shape = x.shape
137
+ dim = shape[-1]
138
+ x = x.view(-1, dim)
139
+ n_rows, n_cols = x.shape
140
+ y = torch.empty_like(x)
141
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
142
+ _dyt_fwd_kernel[(n_rows,)](
143
+ x_ptr=x,
144
+ alpha_ptr=alpha,
145
+ gamma_ptr=gamma,
146
+ beta_ptr=beta,
147
+ y_ptr=y,
148
+ x_row_stride=x.stride(0),
149
+ y_row_stride=y.stride(0),
150
+ n_cols=n_cols,
151
+ BLOCK_SIZE=BLOCK_SIZE,
152
+ num_warps=num_warps,
153
+ )
154
+ return y.view(*shape)
155
+
156
+
157
+ def liger_dyt_bwd(dy, x, alpha, gamma):
158
+ shape = dy.shape
159
+ dtype = x.dtype
160
+ dim = shape[-1]
161
+ dy = dy.view(-1, dim)
162
+ x = x.view(-1, dim)
163
+ n_rows, n_cols = dy.shape
164
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
165
+ sm_count = 1
166
+ device = infer_device()
167
+ if device == "cuda":
168
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
169
+ elif device == "xpu":
170
+ sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
171
+ if n_cols > BLOCK_SIZE:
172
+ raise RuntimeError(
173
+ f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
174
+ )
175
+
176
+ dx = torch.empty_like(x, dtype=torch.float32)
177
+ _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
178
+ _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)
179
+
180
+ grid = (sm_count,)
181
+ rows_per_program = triton.cdiv(n_rows, sm_count)
182
+ _dyt_bwd_kernel[grid](
183
+ x_ptr=x,
184
+ x_row_stride=x.stride(0),
185
+ dy_ptr=dy,
186
+ dy_row_stride=dy.stride(0),
187
+ dx_ptr=dx,
188
+ dx_row_stride=dx.stride(0),
189
+ alpha_ptr=alpha,
190
+ dalpha_ptr=_dalpha,
191
+ gamma_ptr=gamma,
192
+ dgamma_ptr=_dgamma,
193
+ dgamma_row_stride=_dgamma.stride(0),
194
+ n_cols=n_cols,
195
+ n_rows=n_rows,
196
+ ROWS_PER_PROGRAM=rows_per_program,
197
+ BLOCK_SIZE=BLOCK_SIZE,
198
+ num_warps=num_warps,
199
+ )
200
+ dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
201
+ dgamma = _dgamma.sum(dim=0).to(dtype)
202
+ dbeta = dy.sum(dim=0).to(dtype)
203
+ return dx.view(*shape), dalpha, dgamma, dbeta
204
+
205
+
206
+ class LigerDyTFunction(torch.autograd.Function):
207
+ @staticmethod
208
+ @ensure_contiguous
209
+ def forward(ctx, x, alpha, gamma, beta):
210
+ y = liger_dyt_fwd(x, alpha, gamma, beta)
211
+ ctx.save_for_backward(x, alpha, gamma)
212
+ return y
213
+
214
+ @staticmethod
215
+ @ensure_contiguous
216
+ def backward(ctx, grad_output):
217
+ x, alpha, gamma = ctx.saved_tensors
218
+ dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
219
+ grad_output,
220
+ x,
221
+ alpha,
222
+ gamma,
223
+ )
224
+
225
+ return (dx, dalpha, dgamma, dbeta)
build/torch-universal/liger_kernels/fused_linear_cross_entropy.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import triton
3
+
4
+ from cross_entropy import liger_cross_entropy_kernel
5
+ from utils import amp_custom_bwd
6
+ from utils import amp_custom_fwd
7
+ from utils import element_mul_kernel
8
+ from utils import is_hip
9
+
10
+ # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
11
+ # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
12
+ # The optimal maximum block size depends on your hardware, your kernel, and your dtype
13
+ MAX_FUSED_SIZE = 65536 // 2
14
+
15
+
16
+ def fused_linear_cross_entropy_forward(
17
+ _input,
18
+ weight,
19
+ target,
20
+ ce_weight=None,
21
+ bias=None,
22
+ ignore_index=-100,
23
+ lse_square_scale=0.0,
24
+ label_smoothing=0.0,
25
+ reduction="mean",
26
+ softcap=None,
27
+ return_z_loss=False,
28
+ ):
29
+ assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
30
+ device = _input.device
31
+
32
+ # inputs have shape: BT x H
33
+ # materialized activations will have shape: BT x V
34
+ # the increase in memory = BT x V
35
+ # reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
36
+ # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
37
+ # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
38
+ # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
39
+ BT, H = _input.shape
40
+ V = weight.shape[0]
41
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
42
+
43
+ inc_factor = triton.cdiv(V, H) # (V + H - 1) // H
44
+ chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
45
+ num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size
46
+
47
+ grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
48
+ grad_input = torch.zeros_like(_input, device=device)
49
+ grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
50
+ # we use fp32 for loss accumulator
51
+ loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
52
+ z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
53
+
54
+ # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
55
+ target_mask = target != ignore_index
56
+ total_n_non_ignore = target_mask.sum().item()
57
+ total_sum_non_ignore_ce_weight = total_n_non_ignore
58
+ ce_weight_sum = 0.0
59
+ if ce_weight is not None:
60
+ assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
61
+ assert torch.is_floating_point(ce_weight), (
62
+ f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
63
+ )
64
+ total_sum_non_ignore_ce_weight = (
65
+ torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
66
+ )
67
+ ce_weight_sum = ce_weight.sum().item()
68
+ if ce_weight.stride(-1) != 1:
69
+ ce_weight = ce_weight.contiguous()
70
+
71
+ for chunk_id in range(num_chunks):
72
+ start_idx = chunk_id * chunk_size
73
+ end_idx = min((chunk_id + 1) * chunk_size, BT)
74
+ _input_chunk = _input[start_idx:end_idx] # chunk_size x H
75
+
76
+ # when doing matmul, use the original precision
77
+ logits_chunk = _input_chunk @ weight.t() # chunk_size x V
78
+ if bias is not None:
79
+ logits_chunk = logits_chunk + bias
80
+
81
+ target_chunk = target[start_idx:end_idx] # chunk_size,
82
+
83
+ n_rows = logits_chunk.shape[0]
84
+
85
+ # unreduced loss
86
+ loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
87
+ z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
88
+
89
+ # ensure _input and target are contiguous
90
+ logits_chunk = logits_chunk.contiguous()
91
+ target_chunk = target_chunk.contiguous()
92
+
93
+ # Here we calculate the gradient of logits_chunk in place so we can save memory.
94
+ liger_cross_entropy_kernel[(n_rows,)](
95
+ X_ptr=logits_chunk,
96
+ X_stride=logits_chunk.stride(-2),
97
+ Y_ptr=target_chunk,
98
+ Y_stride=target_chunk.stride(-1), # always 1
99
+ weight_ptr=ce_weight,
100
+ loss_ptr=loss_1d_slice,
101
+ z_loss_ptr=z_loss_1d_slice,
102
+ loss_stride=loss_1d_slice.stride(-1), # always 1
103
+ n_cols=V,
104
+ n_non_ignore=total_n_non_ignore,
105
+ sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
106
+ weight_sum=ce_weight_sum,
107
+ ignore_index=ignore_index,
108
+ lse_square_scale=lse_square_scale,
109
+ label_smoothing=label_smoothing,
110
+ reduction=reduction,
111
+ softcap=softcap,
112
+ RETURN_Z_LOSS=return_z_loss,
113
+ HAS_WEIGHT=True if ce_weight is not None else False,
114
+ HAS_SOFTCAPPING=True if softcap is not None else False,
115
+ BLOCK_SIZE=BLOCK_SIZE,
116
+ num_warps=32 if not is_hip() else 16,
117
+ )
118
+
119
+ loss_1d[start_idx:end_idx] = loss_1d_slice
120
+ if return_z_loss:
121
+ z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
122
+ grad_logits_chunk = logits_chunk # chunk_size x V
123
+
124
+ grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
125
+
126
+ if grad_weight is not None:
127
+ torch.addmm(
128
+ input=grad_weight,
129
+ mat1=logits_chunk.t().to(
130
+ _input_chunk.dtype
131
+ ), # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
132
+ mat2=_input_chunk,
133
+ out=grad_weight,
134
+ alpha=1.0,
135
+ beta=1.0,
136
+ )
137
+
138
+ if bias is not None:
139
+ torch.add(
140
+ input=grad_bias,
141
+ other=logits_chunk.sum(dim=0),
142
+ out=grad_bias,
143
+ alpha=1.0,
144
+ )
145
+
146
+ # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now.
147
+ # if reduction == "none":
148
+ # loss = loss_1d
149
+ # z_loss = z_loss_1d if return_z_loss else None
150
+
151
+ else:
152
+ loss = torch.sum(loss_1d)
153
+ z_loss = torch.sum(z_loss_1d) if return_z_loss else None
154
+ return loss, z_loss, grad_input, grad_weight, grad_bias
155
+
156
+
157
+ def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
158
+ # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
159
+ if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
160
+ # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
161
+ # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
162
+ BT, H = grad_input.shape
163
+ n_rows = BT
164
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
165
+
166
+ element_mul_kernel[(n_rows,)](
167
+ grad_input,
168
+ grad_input.stride(-2),
169
+ grad_output,
170
+ H,
171
+ BLOCK_SIZE=BLOCK_SIZE,
172
+ num_warps=32 if not is_hip() else 16,
173
+ )
174
+
175
+ # handle grad_weight
176
+ if grad_weight is not None:
177
+ V, H = grad_weight.shape
178
+ n_rows = V
179
+
180
+ element_mul_kernel[(n_rows,)](
181
+ grad_weight,
182
+ grad_weight.stride(-2),
183
+ grad_output,
184
+ H,
185
+ BLOCK_SIZE=BLOCK_SIZE,
186
+ num_warps=32 if not is_hip() else 16,
187
+ )
188
+
189
+ if grad_bias is not None:
190
+ V = grad_bias.shape[0]
191
+ n_rows = V
192
+
193
+ element_mul_kernel[(n_rows,)](
194
+ grad_bias,
195
+ grad_bias.stride(-1),
196
+ grad_output,
197
+ 1,
198
+ BLOCK_SIZE=BLOCK_SIZE,
199
+ num_warps=32 if not is_hip() else 16,
200
+ )
201
+ return grad_input, grad_weight, grad_bias
202
+
203
+
204
+ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
205
+ @staticmethod
206
+ @amp_custom_fwd
207
+ def forward(
208
+ ctx,
209
+ _input,
210
+ weight,
211
+ target,
212
+ bias=None,
213
+ ce_weight=None,
214
+ ignore_index=-100,
215
+ lse_square_scale=0.0,
216
+ label_smoothing=0.0,
217
+ reduction="mean",
218
+ softcap=None,
219
+ return_z_loss: bool = False,
220
+ ):
221
+ """
222
+ Fusing the last linear layer with cross-entropy loss
223
+ Reference: https://github.com/mgmalek/efficient_cross_entropy
224
+
225
+ Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
226
+ the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
227
+ compute the gradient at the forward pass. By doing so, we don't have to store the _input and target
228
+ for the backward pass.
229
+
230
+ _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension.
231
+ target: (B*T) where each value is in [0, V-1]
232
+ weight: (V, H) where V is the number of classes
233
+ bias: (V) where V is the number of classes
234
+ ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
235
+ ignore_index: the index to ignore in the target
236
+ label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
237
+ reduction: reduction to apply
238
+ """
239
+
240
+ loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
241
+ _input=_input,
242
+ weight=weight,
243
+ target=target,
244
+ bias=bias,
245
+ ce_weight=ce_weight,
246
+ ignore_index=ignore_index,
247
+ lse_square_scale=lse_square_scale,
248
+ label_smoothing=label_smoothing,
249
+ reduction=reduction,
250
+ softcap=softcap,
251
+ return_z_loss=return_z_loss,
252
+ )
253
+ # downcast to dtype and store for backward
254
+ ctx.save_for_backward(
255
+ grad_input.detach(),
256
+ grad_weight.detach() if grad_weight is not None else None,
257
+ grad_bias.detach() if bias is not None else None,
258
+ )
259
+ ctx.return_z_loss = return_z_loss
260
+ return loss, z_loss
261
+
262
+ @staticmethod
263
+ @amp_custom_bwd
264
+ def backward(ctx, grad_output, grad_output2):
265
+ if ctx.return_z_loss:
266
+ del grad_output2 # z_loss is only for logging
267
+ (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
268
+ grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
269
+ grad_output, grad_input, grad_weight, grad_bias
270
+ )
271
+ return (
272
+ grad_input,
273
+ grad_weight,
274
+ None,
275
+ grad_bias,
276
+ None,
277
+ None,
278
+ None,
279
+ None,
280
+ None,
281
+ None,
282
+ None,
283
+ )
build/torch-universal/liger_kernels/geglu.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import operator
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from utils import calculate_settings
8
+ from utils import compare_version
9
+ from utils import ensure_contiguous
10
+
11
+ if compare_version("triton", operator.ge, "3.0.0"):
12
+ try:
13
+ # typical import path with dispatch available
14
+ from triton.language.extra.libdevice import tanh
15
+ except ModuleNotFoundError:
16
+ # for working with NGC containers
17
+ from triton.language.extra.cuda.libdevice import tanh
18
+ else:
19
+ from triton.language.math import tanh
20
+
21
+
22
+ @triton.jit
23
+ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
24
+ program_id = tl.program_id(0).to(tl.int64)
25
+
26
+ # locate start index
27
+ a += program_id * stride
28
+ b += program_id * stride
29
+ c += program_id * stride
30
+
31
+ col_offsets = tl.arange(0, BLOCK_SIZE)
32
+ mask = col_offsets < n_cols
33
+ a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
34
+ b_row = tl.load(b + col_offsets, mask=mask, other=0)
35
+
36
+ # tanh approximation form of GELU is computed with:
37
+ # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
38
+ sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
39
+ a_cubed = a_row * a_row * a_row
40
+ tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
41
+ tanh_result = tanh(tanh_arg)
42
+ geglu_a = 0.5 * a_row * (1 + tanh_result)
43
+ c_row = geglu_a * b_row
44
+ tl.store(c + col_offsets, c_row, mask=mask)
45
+
46
+
47
+ @triton.jit
48
+ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
49
+ program_id = tl.program_id(0).to(tl.int64)
50
+
51
+ # locate start index
52
+ dc += program_id * stride
53
+ a += program_id * stride
54
+ b += program_id * stride
55
+
56
+ col_offsets = tl.arange(0, BLOCK_SIZE)
57
+ mask = col_offsets < n_cols
58
+
59
+ dc_row = tl.load(dc + col_offsets, mask=mask, other=0)
60
+ a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
61
+ b_row = tl.load(b + col_offsets, mask=mask, other=0)
62
+
63
+ # recomputation to save memory
64
+ sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
65
+ a_cubed = a_row * a_row * a_row
66
+ tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
67
+ tanh_result = tanh(tanh_arg)
68
+ geglu_a = 0.5 * a_row * (1 + tanh_result)
69
+
70
+ db_row = dc_row * geglu_a
71
+
72
+ # Gradient w.r.t. a can be computed with:
73
+ # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
74
+ # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
75
+ term1 = 0.5 * (1 + tanh_result)
76
+ tanh_sq = tanh_result * tanh_result
77
+ term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
78
+ da_row = dc_row * b_row * (term1 + term2)
79
+
80
+ tl.store(a + col_offsets, da_row, mask=mask)
81
+ tl.store(b + col_offsets, db_row, mask=mask)
82
+
83
+
84
+ def geglu_forward(a, b):
85
+ ori_shape = a.shape
86
+
87
+ n_cols = ori_shape[-1]
88
+ a = a.view(-1, n_cols)
89
+ b = b.view(-1, n_cols)
90
+ c = torch.empty_like(a)
91
+ n_rows = a.shape[0]
92
+
93
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
94
+
95
+ _geglu_tanh_forward_kernel[(n_rows,)](
96
+ a,
97
+ b,
98
+ c,
99
+ c.stride(-2),
100
+ n_cols=n_cols,
101
+ BLOCK_SIZE=BLOCK_SIZE,
102
+ num_warps=num_warps,
103
+ )
104
+ return a, b, c.view(*ori_shape)
105
+
106
+
107
+ def geglu_backward(a, b, dc):
108
+ ori_shape = dc.shape
109
+ n_cols = ori_shape[-1]
110
+ dc = dc.view(-1, n_cols)
111
+ n_rows = dc.shape[0]
112
+
113
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
114
+
115
+ _geglu_tanh_backward_kernel[(n_rows,)](
116
+ dc,
117
+ a,
118
+ b,
119
+ dc.stride(-2),
120
+ n_cols=n_cols,
121
+ BLOCK_SIZE=BLOCK_SIZE,
122
+ num_warps=num_warps,
123
+ )
124
+
125
+ return a.view(*ori_shape), b.view(*ori_shape)
126
+
127
+
128
+ class LigerGELUMulFunction(torch.autograd.Function):
129
+ @staticmethod
130
+ @ensure_contiguous
131
+ def forward(ctx, a, b):
132
+ a, b, c = geglu_forward(a, b)
133
+ ctx.save_for_backward(a, b)
134
+ return c
135
+
136
+ @staticmethod
137
+ @ensure_contiguous
138
+ def backward(ctx, dc):
139
+ a, b = ctx.saved_tensors
140
+ a, b = geglu_backward(a, b, dc)
141
+ return a, b
build/torch-universal/liger_kernels/group_norm.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import operator
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from utils import compare_version
8
+ from utils import ensure_contiguous
9
+
10
+ if compare_version("triton", operator.ge, "3.0.0"):
11
+ try:
12
+ # typical import path with dispatch available
13
+ from triton.language.extra.libdevice import rsqrt
14
+ except ModuleNotFoundError:
15
+ # for working with NGC containers
16
+ from triton.language.extra.cuda.libdevice import rsqrt
17
+ else:
18
+ from triton.language.math import rsqrt
19
+
20
+ MAX_FUSED_SIZE = 65536
21
+
22
+
23
+ @triton.jit
24
+ def _group_norm_forward_kernel(
25
+ Y_ptr, # pointer to output, shape (n_rows, n_groups, hidden_size)
26
+ Y_row_stride, # stride of each row in output
27
+ Y_col_stride, # stride of each column in output
28
+ X_ptr, # pointer to input, shape (n_rows, n_groups, hidden_size)
29
+ X_row_stride, # stride of each row in input
30
+ X_col_stride, # stride of each column in input
31
+ Mean_ptr, # pointer to mean, shape (n_rows, n_groups)
32
+ Mean_row_stride, # stride of each row in mean
33
+ Mean_col_stride, # stride of each column in mean
34
+ RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups)
35
+ RSTD_row_stride, # stride of each row in rstd
36
+ RSTD_col_stride, # stride of each column in rstd
37
+ W_ptr, # pointer to W
38
+ B_ptr, # pointer to B
39
+ hidden_size, # hidden size of X
40
+ channels_per_group, # the number of channels per group
41
+ eps,
42
+ BLOCK_SIZE: tl.constexpr,
43
+ ):
44
+ """
45
+ References:
46
+ https://nn.labml.ai/normalization/group_norm/index.html
47
+ """
48
+ batch_idx = tl.program_id(0)
49
+ group_idx = tl.program_id(1)
50
+
51
+ X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride
52
+ Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride
53
+
54
+ block_range = tl.arange(0, BLOCK_SIZE)
55
+
56
+ # Compute mean and variance using the online algorithm
57
+ s = 0.0
58
+ squared_sum = 0.0
59
+ for i in tl.range(0, hidden_size, BLOCK_SIZE):
60
+ hidden_size_offsets = i + block_range
61
+ mask = hidden_size_offsets < hidden_size
62
+ X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0)
63
+ s += tl.sum(X)
64
+ # X**2
65
+ squared_sum += tl.sum(X * X)
66
+
67
+ m = s / hidden_size
68
+
69
+ # variance = E[X**2] - E[X]**2
70
+ variance = (squared_sum / hidden_size) - (m * m)
71
+
72
+ # 1/std
73
+ rstd = rsqrt(variance + eps)
74
+
75
+ # Normalize
76
+ hidden_size_per_channel = hidden_size // channels_per_group
77
+ for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
78
+ W = tl.load(W_ptr + channel_idx)
79
+ B = tl.load(B_ptr + channel_idx)
80
+ for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
81
+ hidden_size_offsets = i + block_range
82
+ mask = hidden_size_offsets < hidden_size_per_channel
83
+ X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
84
+ Y = (X - m) * rstd * W + B
85
+ tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
86
+
87
+ X_ptr += hidden_size_per_channel
88
+ Y_ptr += hidden_size_per_channel
89
+
90
+ tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
91
+ tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
92
+
93
+
94
+ @triton.jit
95
+ def _group_norm_backward_kernel(
96
+ X_ptr, # pointer to input, shape (n_rows, n_channels, hidden_size)
97
+ X_row_stride, # stride of each row in input
98
+ X_col_stride, # stride of each column in input
99
+ W_ptr, # pointer to weights, shape (n_channels)
100
+ Mean_ptr, # pointer to mean, shape (n_rows, n_groups)
101
+ Mean_ptr_row_stride, # stride of each column in mean
102
+ Mean_ptr_col_stride, # stride of each column in mean
103
+ RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups)
104
+ DX_ptr, # pointer to input grad, shape (n_rows, n_groups, hidden_size)
105
+ DW_ptr, # pointer to weights grad, shape (n_channels)
106
+ DB_ptr, # pointer to bias grad, shape (n_channels)
107
+ UPSTREAM_ptr, # pointer to output grad, shape (n_rows, n_channels, hidden_size)
108
+ hidden_size: tl.constexpr, # hidden size
109
+ channels_per_group: tl.constexpr, # number of groups in group norm
110
+ BLOCK_SIZE: tl.constexpr,
111
+ dtype: tl.constexpr,
112
+ ):
113
+ """
114
+ References:
115
+ https://nn.labml.ai/normalization/group_norm/index.html
116
+ https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
117
+
118
+ The backprop equations are the same for group_norm and layer_norm
119
+ the only difference here is that we load the Mean, Rstd corresponding to the
120
+ group we're computing gradients for and the mean and rstd are computed over n-channels
121
+ so the total number of elements we compute the mean over is num_channels_per_group * hidden_size
122
+
123
+ We also need to load the Weights corresponding to the current channel to compute the gradients.
124
+ """
125
+ batch_idx = tl.program_id(0)
126
+ group_idx = tl.program_id(1)
127
+
128
+ # Move the pointers to the correct batch
129
+ X_ptr += batch_idx * X_row_stride
130
+ DX_ptr += batch_idx * X_row_stride
131
+ UPSTREAM_ptr += batch_idx * X_row_stride
132
+
133
+ # Mean and rstd are the same shape so have the same strides
134
+ mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
135
+ rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
136
+
137
+ c1 = 0.0
138
+ c2 = 0.0
139
+ block_range = tl.arange(0, BLOCK_SIZE)
140
+
141
+ # We need to compute the sum terms of the backprop equations across all channels in the group
142
+ for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
143
+ dW = 0.0
144
+ dB = 0.0
145
+ # Move the pointers to the correct channel
146
+ W = tl.load(W_ptr + channel_idx)
147
+ for i in tl.range(0, hidden_size, BLOCK_SIZE):
148
+ hidden_size_offsets = i + block_range
149
+ mask = hidden_size_offsets < hidden_size
150
+ X = tl.load(
151
+ X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
152
+ mask=mask,
153
+ other=0.0,
154
+ )
155
+ UPSTREAM_grad = tl.load(
156
+ UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
157
+ mask=mask,
158
+ other=0.0,
159
+ )
160
+
161
+ x_hat = (X - mean) * rstd
162
+ dW += tl.sum(UPSTREAM_grad * x_hat)
163
+ dB += tl.sum(UPSTREAM_grad)
164
+
165
+ wdy = W * UPSTREAM_grad
166
+ c1 += tl.sum(x_hat * wdy)
167
+ c2 += tl.sum(wdy)
168
+
169
+ # Need to ensure additions to the same channel are atomic
170
+ tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype))
171
+ tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype))
172
+
173
+ N = hidden_size * channels_per_group
174
+ c1 = c1 / N
175
+ c2 = c2 / N
176
+
177
+ for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
178
+ # Move the pointers to the correct channel
179
+ W = tl.load(W_ptr + channel_idx)
180
+ for i in range(0, hidden_size, BLOCK_SIZE):
181
+ hidden_size_offsets = i + block_range
182
+ mask = hidden_size_offsets < hidden_size
183
+ X = tl.load(
184
+ X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
185
+ mask=mask,
186
+ other=0.0,
187
+ )
188
+ UPSTREAM_grad = tl.load(
189
+ UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
190
+ mask=mask,
191
+ other=0.0,
192
+ )
193
+
194
+ x_hat = (X - mean) * rstd
195
+ wdy = W * UPSTREAM_grad
196
+ dx = (wdy - (x_hat * c1 + c2)) * rstd
197
+ tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)
198
+
199
+
200
+ def group_norm_forward(X, num_channels, num_groups, W, B, eps):
201
+ shape = X.shape
202
+ batch_size = shape[0]
203
+ channels_per_group = num_channels // num_groups
204
+ # Reshape X so that the mean and std are computed across the groups
205
+ X = X.view(batch_size, num_groups, -1).contiguous()
206
+ hidden_size = X.shape[-1]
207
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
208
+ Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
209
+ Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
210
+ RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
211
+
212
+ _group_norm_forward_kernel[(batch_size, num_groups)](
213
+ Y,
214
+ Y.stride(0),
215
+ Y.stride(1),
216
+ X,
217
+ X.stride(0),
218
+ X.stride(1),
219
+ Mean,
220
+ Mean.stride(0),
221
+ Mean.stride(1),
222
+ RSTD,
223
+ RSTD.stride(0),
224
+ RSTD.stride(1),
225
+ W,
226
+ B,
227
+ hidden_size,
228
+ channels_per_group,
229
+ eps,
230
+ BLOCK_SIZE=BLOCK_SIZE,
231
+ )
232
+ # Return tensors in the original shape
233
+ return Y.view(*shape), X.view(*shape), Mean, RSTD, BLOCK_SIZE
234
+
235
+
236
+ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
237
+ shape = dY.shape
238
+ batch_size = shape[0]
239
+ hidden_size = dY.shape[-1]
240
+ channels_per_group = num_channels // num_groups
241
+ dY = dY.view(batch_size, num_groups, -1)
242
+ DX = torch.empty(
243
+ (batch_size, num_groups, hidden_size * channels_per_group),
244
+ dtype=X.dtype,
245
+ device=X.device,
246
+ )
247
+ DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device)
248
+ DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device)
249
+ triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
250
+
251
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
252
+ _group_norm_backward_kernel[(batch_size, num_groups)](
253
+ X,
254
+ X.stride(0),
255
+ X.stride(1),
256
+ W,
257
+ Mean,
258
+ Mean.stride(0),
259
+ Mean.stride(1),
260
+ RSTD,
261
+ DX,
262
+ DW,
263
+ DB,
264
+ dY,
265
+ hidden_size,
266
+ channels_per_group,
267
+ BLOCK_SIZE=BLOCK_SIZE,
268
+ dtype=triton_dtype,
269
+ )
270
+
271
+ # Return tensors in the original shape
272
+ return DX.view(*shape), DW, DB
273
+
274
+
275
+ class LigerGroupNormFunction(torch.autograd.Function):
276
+ @staticmethod
277
+ @ensure_contiguous
278
+ def forward(
279
+ ctx,
280
+ X,
281
+ affine_scaling_weight,
282
+ affine_shifting_bias,
283
+ num_channels,
284
+ num_groups,
285
+ eps,
286
+ ):
287
+ Y, X, Mean, RSTD, BLOCK_SIZE = group_norm_forward(
288
+ X,
289
+ num_channels,
290
+ num_groups,
291
+ affine_scaling_weight,
292
+ affine_shifting_bias,
293
+ eps,
294
+ )
295
+ ctx.num_channels = num_channels
296
+ ctx.num_groups = num_groups
297
+ ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
298
+ return Y
299
+
300
+ @staticmethod
301
+ @ensure_contiguous
302
+ def backward(ctx, dY):
303
+ X, W, B, Mean, RSTD = ctx.saved_tensors
304
+ DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
305
+ return DX, DW, DB, None, None, None
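A minimal usage sketch for LigerGroupNormFunction, assuming the built package is importable as `liger_kernels` and a Triton-capable GPU is available; shapes and epsilon are illustrative only:
import torch
from liger_kernels import LigerGroupNormFunction  # assumed import path for the built package

batch, channels, groups, spatial = 4, 8, 2, 256
x = torch.randn(batch, channels, spatial, device="cuda", requires_grad=True)
w = torch.ones(channels, device="cuda", requires_grad=True)   # affine scaling weight, one per channel
b = torch.zeros(channels, device="cuda", requires_grad=True)  # affine shifting bias, one per channel
y = LigerGroupNormFunction.apply(x, w, b, channels, groups, 1e-6)
y.sum().backward()  # exercises group_norm_backward; gradients land in x.grad, w.grad, b.grad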
build/torch-universal/liger_kernels/jsd.py ADDED
@@ -0,0 +1,201 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from utils import ensure_contiguous
8
+ from utils import infer_device
9
+
10
+
11
+ @triton.jit
12
+ def _jsd_kernel(
13
+ X_ptr, # input in logspace, X = log Q
14
+ X_stride,
15
+ Y_ptr, # ground truth in logspace, Y = log P
16
+ Y_stride,
17
+ loss_ptr,
18
+ loss_stride,
19
+ dX_ptr,
20
+ dX_stride,
21
+ label_ptr,
22
+ beta: tl.constexpr,
23
+ n_non_ignore: int,
24
+ ignore_index: tl.constexpr,
25
+ n_cols,
26
+ BLOCK_SIZE: tl.constexpr,
27
+ HAS_LABEL: tl.constexpr,
28
+ ):
29
+ # JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X)
30
+ # = sum(P * log P + Q * log Q - 2 * M * log M) / 2
31
+ # = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2
32
+ # grad_x_i = 0.5 * Q * (X - log_M)
33
+ pid = tl.program_id(0).to(tl.int64)
34
+ X_ptr += pid * X_stride
35
+ dX_ptr += pid * dX_stride
36
+ Y_ptr += pid * Y_stride
37
+ loss_ptr += pid * loss_stride
38
+ label_ptr += pid
39
+
40
+ if HAS_LABEL:
41
+ label = tl.load(label_ptr)
42
+ if label == ignore_index:
43
+ for i in range(0, n_cols, BLOCK_SIZE):
44
+ offsets = i + tl.arange(0, BLOCK_SIZE)
45
+ tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols)
46
+ return
47
+
48
+ for i in range(0, n_cols, BLOCK_SIZE):
49
+ offsets = i + tl.arange(0, BLOCK_SIZE)
50
+ mask = offsets < n_cols
51
+ X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
52
+ Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
53
+
54
+ if beta == 0.0: # forward KL
55
+ Y_max = tl.max(Y, axis=0)
56
+ Y_shifted = Y - Y_max
57
+ Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max) # Compensate for the shift
58
+ loss = Y_prob * (Y - X)
59
+ dX = -Y_prob
60
+ elif beta == 1.0: # reverse KL
61
+ X_max = tl.max(X, axis=0)
62
+ X_shifted = X - X_max
63
+ X_prob = tl.exp(X_shifted) * tl.exp(X_max) # Compensate for the shift
64
+ loss = X_prob * (X - Y)
65
+ dX = loss + X_prob
66
+ else:
67
+ max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
68
+ X_shifted = X - max_val
69
+ Y_shifted = Y - max_val
70
+
71
+ # Pre-compute exp(max_val) since it's used twice
72
+ exp_max = tl.exp(max_val)
73
+
74
+ # Compute exp terms with compensation
75
+ Q = tl.exp(X_shifted) * exp_max # = exp(X)
76
+ P = tl.exp(Y_shifted) * exp_max # = exp(Y)
77
+
78
+ # Pre-compute common terms
79
+ beta_P = beta * P
80
+ one_minus_beta_Q = (1 - beta) * Q
81
+ M = beta_P + one_minus_beta_Q
82
+ log_M = tl.log(M) # No need to compensate as M is already in original scale
83
+
84
+ loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
85
+ dX = one_minus_beta_Q * (X - log_M)
86
+
87
+ # Pre-compute scaling factor
88
+ scale = 1.0 / n_non_ignore
89
+ loss = loss * scale
90
+ dX = dX * scale
91
+
92
+ tl.store(loss_ptr + offsets, loss, mask=mask)
93
+ tl.store(dX_ptr + offsets, dX, mask=mask)
94
+
95
+
96
+ MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536
97
+
98
+
99
+ def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
100
+ BT, V = _input.shape
101
+ n_rows = BT
102
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
103
+ # non reduction loss
104
+ loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
105
+ dX = torch.empty_like(_input)
106
+
107
+ if has_label:
108
+ n_non_ignore = (shift_labels != ignore_index).sum().item()
109
+ else:
110
+ n_non_ignore = BT
111
+
112
+ _jsd_kernel[(n_rows,)](
113
+ X_ptr=_input, # input in logspace, X = log Q
114
+ X_stride=_input.stride(-2),
115
+ Y_ptr=target, # ground truth in logspace, Y = log P
116
+ Y_stride=target.stride(-2),
117
+ loss_ptr=loss,
118
+ loss_stride=loss.stride(-2),
119
+ dX_ptr=dX,
120
+ dX_stride=dX.stride(-2),
121
+ label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)), # dummy ptr if no label
122
+ beta=beta,
123
+ n_non_ignore=n_non_ignore,
124
+ ignore_index=ignore_index,
125
+ n_cols=V,
126
+ BLOCK_SIZE=BLOCK_SIZE,
127
+ HAS_LABEL=has_label,
128
+ )
129
+
130
+ loss = torch.sum(loss)
131
+ return loss.to(_input.dtype), dX
132
+
133
+
134
+ def jsd_backward(dX, grad_output):
135
+ # If jsd is the last layer, grad_output is 1.0. Skip the mul to save time
136
+ if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
137
+ return dX
138
+ else:
139
+ return grad_output * dX
140
+
141
+
142
+ class LigerJSDFunction(torch.autograd.Function):
143
+ r"""
144
+ This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.
145
+ .. math::
146
+ JSD(\beta)(P || Q)
147
+ = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
148
+
149
+ .. note::
150
+ As with all the other losses in PyTorch, this function expects the first argument,
151
+ :attr:`_input`, to be the predictions, the output of the student model, in log-space
152
+ and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space.
153
+ This differs from the standard mathematical notation :math:`JSD(P || Q)` where
154
+ :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
155
+ """
156
+
157
+ @staticmethod
158
+ @ensure_contiguous
159
+ def forward(
160
+ ctx,
161
+ _input: torch.Tensor,
162
+ target: torch.Tensor,
163
+ shift_labels: Optional[torch.Tensor] = None,
164
+ beta: float = 0.5,
165
+ ignore_index: int = -100,
166
+ ) -> torch.Tensor:
167
+ """
168
+ Args:
169
+ _input (torch.Tensor): predict values with shape (BT, V) in logspace
170
+ target (torch.Tensor): ground truth values with shape (BT, V) in logspace
171
+ shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
172
+ beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
173
+ ignore_index (int): the index to ignore. Default: -100
174
+
175
+ Returns:
176
+ loss (torch.Tensor): generalized JSD
177
+ """
178
+ has_label = False
179
+ if shift_labels is not None:
180
+ assert shift_labels.shape == (_input.shape[0],), (
181
+ f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
182
+ )
183
+ shift_labels = shift_labels.contiguous()
184
+ has_label = True
185
+
186
+ loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
187
+ ctx.save_for_backward(dX)
188
+ return loss
189
+
190
+ @staticmethod
191
+ @ensure_contiguous
192
+ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
193
+ (dX,) = ctx.saved_tensors
194
+ dX = jsd_backward(dX, grad_output)
195
+ return (
196
+ dX,
197
+ None,
198
+ None,
199
+ None,
200
+ None,
201
+ )
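LigerJSDFunction consumes student and teacher log-probabilities of shape (BT, V). A hedged sketch under the same import and device assumptions as the sketch above:
import torch
from liger_kernels import LigerJSDFunction

BT, V = 8, 128  # (batch * sequence length, vocabulary size)
student_log_probs = torch.randn(BT, V, device="cuda").log_softmax(dim=-1).requires_grad_()
teacher_log_probs = torch.randn(BT, V, device="cuda").log_softmax(dim=-1)
loss = LigerJSDFunction.apply(student_log_probs, teacher_log_probs)  # beta defaults to 0.5
loss.backward()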
build/torch-universal/liger_kernels/kl_div.py ADDED
@@ -0,0 +1,262 @@
1
+ from typing import Literal
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from utils import ensure_contiguous
8
+ from utils import is_hip
9
+ from utils import infer_device
10
+
11
+
12
+ def get_num_warps(BLOCK_SIZE):
13
+ num_warps = 4
14
+ if BLOCK_SIZE >= 32768:
15
+ num_warps = 32 if not is_hip() else 16
16
+ elif BLOCK_SIZE >= 8192:
17
+ num_warps = 16
18
+ elif BLOCK_SIZE >= 2048:
19
+ num_warps = 8
20
+
21
+ return num_warps
22
+
23
+
24
+ MAX_FUSED_SIZE = 65536 // 4  # empirically, 65536 // 4 or // 8 works best
25
+
26
+ REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
27
+
28
+ _REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0)
29
+ _REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1)
30
+ _REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2)
31
+ _REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3)
32
+
33
+ _str_to_reduction_mode = {
34
+ "none": _REDUCTION_MODE_NONE.value,
35
+ "sum": _REDUCTION_MODE_SUM.value,
36
+ "mean": _REDUCTION_MODE_MEAN.value,
37
+ "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
38
+ }
39
+
40
+
41
+ @triton.jit
42
+ def _kldiv_kernel_forward(
43
+ y_ptr, # [B, S], prediction ptr, the kernel expects the prediction in log-space
44
+ y_stride, # int, prediction stride
45
+ gt_ptr, # [B, S], ground truth ptr
46
+ gt_stride, # int, ground truth stride
47
+ loss_ptr, # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr
48
+ loss_stride, # int, output stride
49
+ n_cols, # int, number of columns in the input tensor
50
+ eps,
51
+ BLOCK_SIZE: tl.constexpr,
52
+ log_target: tl.constexpr = False,
53
+ reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
54
+ ):
55
+ pid = tl.program_id(0).to(tl.int64)
56
+ y_ptr += pid * y_stride
57
+ gt_ptr += pid * gt_stride
58
+ loss_ptr += pid * loss_stride
59
+
60
+ base_offsets = tl.arange(0, BLOCK_SIZE)
61
+
62
+ loss_sum = 0.0
63
+ for i in range(0, n_cols, BLOCK_SIZE):
64
+ offsets = i + base_offsets
65
+ mask = offsets < n_cols
66
+ y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
67
+ y_true = tl.load(gt_ptr + offsets, mask=mask, other=0.0)
68
+
69
+ # KL(y_true || y) = y_true * (log(y_true) - log(y))
70
+ # We compute KL(y_true || y) with y in the log-space
71
+ if not log_target:
72
+ loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y)
73
+ else:
74
+ loss = tl.exp(y_true) * (y_true - y)
75
+
76
+ if reduction == _REDUCTION_MODE_NONE:
77
+ tl.store(loss_ptr + offsets, loss, mask=mask)
78
+ else:
79
+ loss_sum += tl.sum(loss, axis=0)
80
+
81
+ if reduction != _REDUCTION_MODE_NONE:
82
+ tl.store(loss_ptr, loss_sum)
83
+
84
+
85
+ @triton.jit
86
+ def _kldiv_kernel_backward(
87
+ target_ptr,
88
+ target_stride,
89
+ new_grads_ptr,
90
+ new_grads_stride,
91
+ n_cols,
92
+ BLOCK_SIZE: tl.constexpr,
93
+ log_target: tl.constexpr = False,
94
+ ):
95
+ pid = tl.program_id(0).to(tl.int64)
96
+
97
+ target_ptr += pid * target_stride
98
+ new_grads_ptr += pid * new_grads_stride
99
+
100
+ offsets = tl.arange(0, BLOCK_SIZE)
101
+ mask = offsets < n_cols
102
+
103
+ for i in range(0, n_cols, BLOCK_SIZE):
104
+ offsets = i + tl.arange(0, BLOCK_SIZE)
105
+ mask = offsets < n_cols
106
+
107
+ target = tl.load(target_ptr + offsets, mask=mask, other=0.0)
108
+
109
+ if not log_target:
110
+ res = target * -1
111
+ else:
112
+ res = -tl.exp(target)
113
+
114
+ tl.store(new_grads_ptr + offsets, res, mask=mask)
115
+
116
+
117
+ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
118
+ BT, V = y_pred.shape
119
+ BLOCK_SIZE = (
120
+ min(8192, triton.next_power_of_2(V))
121
+ if infer_device() == "xpu"
122
+ else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
123
+ )
124
+ num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
125
+
126
+ grid = (BT,)
127
+ reduction = _str_to_reduction_mode[reduction]
128
+
129
+ out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
130
+ output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32)
131
+
132
+ _kldiv_kernel_forward[grid](
133
+ y_pred,
134
+ y_pred.stride(0),
135
+ y_true,
136
+ y_true.stride(0),
137
+ output_tensor,
138
+ output_tensor.stride(0),
139
+ V,
140
+ eps=eps,
141
+ BLOCK_SIZE=BLOCK_SIZE,
142
+ num_warps=num_warps,
143
+ log_target=log_target,
144
+ reduction=reduction,
145
+ )
146
+
147
+ # reduce according to the reduction mode, matching PyTorch. In later PyTorch versions, `mean` will be changed to the same behavior as `batchmean`
148
+ # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
149
+ # https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372
150
+ if reduction == _REDUCTION_MODE_BATCHMEAN.value:
151
+ return output_tensor.sum() / BT
152
+ elif reduction == _REDUCTION_MODE_SUM.value:
153
+ return output_tensor.sum(dim=0)
154
+ elif reduction == _REDUCTION_MODE_MEAN.value:
155
+ return output_tensor.sum() / (BT * V)
156
+ else:
157
+ return output_tensor
158
+
159
+
160
+ def kldiv_backward_triton(target, grad_output, new_grads, log_target):
161
+ BT, V = target.shape
162
+ BLOCK_SIZE = (
163
+ min(8192, triton.next_power_of_2(V))
164
+ if infer_device() == "xpu"
165
+ else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
166
+ )
167
+ num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
168
+
169
+ grid = (BT,)
170
+
171
+ # We store the gradients in-place in the input tensor
172
+ _kldiv_kernel_backward[grid](
173
+ target,
174
+ target.stride(0),
175
+ new_grads,
176
+ new_grads.stride(0),
177
+ V,
178
+ BLOCK_SIZE=BLOCK_SIZE,
179
+ num_warps=num_warps,
180
+ log_target=log_target,
181
+ )
182
+
183
+ # If the KL divergence loss is the last layer, grad_output is 1.0. Skip the mul then.
184
+ if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
185
+ return new_grads
186
+
187
+ return new_grads * grad_output
188
+
189
+
190
+ class LigerKLDivLossFunction(torch.autograd.Function):
191
+ """
192
+ Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:
193
+ ```python
194
+ if log_target:
195
+ loss = target.exp() * (target - input)
196
+ else:
197
+ loss = target * (target.log() - input)
198
+ ```
199
+ then the loss is reduced according to the `reduction` parameter.
200
+ as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
201
+ """
202
+
203
+ @staticmethod
204
+ @ensure_contiguous
205
+ def forward(
206
+ ctx,
207
+ y_pred: torch.Tensor,
208
+ y_true: torch.Tensor,
209
+ reduction: REDUCTION_LITERAL = "batchmean",
210
+ log_target: bool = False,
211
+ eps: float = 1e-10,
212
+ ) -> torch.Tensor:
213
+ """A forward pass for the KL Divergence Loss.
214
+
215
+ Args:
216
+ ctx: Torch autograd context
217
+ y_pred (torch.Tensor): A tensor of shape (BT, V) containing the predicted values, expected to be log-probabilities.
218
+ y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`.
219
+ reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean".
220
+ log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False.
221
+ eps (float, optional): A small value to avoid division by zero. Defaults to 1e-10.
222
+
223
+ Returns:
224
+ torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar.
225
+ """
226
+ ctx.save_for_backward(y_true)
227
+ ctx.reduction = reduction
228
+ ctx.log_target = log_target
229
+ return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)
230
+
231
+ @staticmethod
232
+ @ensure_contiguous
233
+ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
234
+ """A backward pass for the KL Divergence Loss.
235
+
236
+ Args:
237
+ ctx: Torch autograd context
238
+ grad_output (torch.Tensor): The gradient of the loss with respect to the output.
239
+
240
+ Returns:
241
+ tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
242
+ """
243
+ (y_true,) = ctx.saved_tensors
244
+
245
+ new_grads = torch.empty_like(y_true)
246
+
247
+ derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target)
248
+
249
+ if ctx.reduction == "batchmean":
250
+ derivative = derivative / y_true.shape[0]
251
+ elif ctx.reduction == "sum" or ctx.reduction == "none":
252
+ pass
253
+ elif ctx.reduction == "mean":
254
+ derivative = derivative / (y_true.shape[0] * y_true.shape[1])
255
+
256
+ return (
257
+ derivative,
258
+ None,
259
+ None,
260
+ None,
261
+ None,
262
+ )
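A short sketch of calling LigerKLDivLossFunction with log-space predictions and probability-space targets (log_target=False); the import path, device, and toy shapes are again assumptions:
import torch
from liger_kernels import LigerKLDivLossFunction

BT, V = 8, 128
y_pred = torch.randn(BT, V, device="cuda").log_softmax(dim=-1).requires_grad_()  # log-probabilities
y_true = torch.randn(BT, V, device="cuda").softmax(dim=-1)                       # probabilities
loss = LigerKLDivLossFunction.apply(y_pred, y_true, "batchmean", False, 1e-10)
loss.backward()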
build/torch-universal/liger_kernels/layer_norm.py ADDED
@@ -0,0 +1,265 @@
1
+ import math
2
+ import operator
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ from utils import calculate_settings
9
+ from utils import compare_version
10
+ from utils import ensure_contiguous
11
+
12
+ if compare_version("triton", operator.ge, "3.0.0"):
13
+ try:
14
+ # typical import path with dispatch available
15
+ from triton.language.extra.libdevice import rsqrt
16
+ except ModuleNotFoundError:
17
+ # for working with NGC containers
18
+ from triton.language.extra.cuda.libdevice import rsqrt
19
+ else:
20
+ from triton.language.math import rsqrt
21
+
22
+
23
+ @triton.jit
24
+ def _layer_norm_forward_kernel(
25
+ Y_ptr, # pointer to output, shape (n_rows, n_cols)
26
+ Y_row_stride, # stride of each row in output
27
+ X_ptr, # pointer to input, shape (n_rows, n_cols)
28
+ X_row_stride, # stride of each row in input
29
+ W_ptr, # pointer to weights, shape (n_cols,)
30
+ W_row_stride, # stride of each row in weights
31
+ B_ptr, # pointer to bias, shape (n_cols,)
32
+ B_row_stride, # stride of each row in bias
33
+ Mean_ptr, # pointer to mean, shape (n_rows,)
34
+ Mean_row_stride, # stride of each row in mean
35
+ RSTD_ptr, # pointer to rstd, shape (n_rows,)
36
+ RSTD_row_stride, # stride of each row in rstd
37
+ n_cols,
38
+ eps,
39
+ BLOCK_SIZE: tl.constexpr,
40
+ ):
41
+ """
42
+ References:
43
+ https://arxiv.org/abs/1607.06450
44
+ https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
45
+ """
46
+ row_idx = tl.program_id(0)
47
+ col_offsets = tl.arange(0, BLOCK_SIZE)
48
+ mask = col_offsets < n_cols
49
+
50
+ Y_ptr += row_idx * Y_row_stride
51
+ X_ptr += row_idx * X_row_stride
52
+ Mean_ptr += row_idx * Mean_row_stride
53
+ RSTD_ptr += row_idx * RSTD_row_stride
54
+
55
+ X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
56
+ W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
57
+ B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
58
+
59
+ mean = tl.sum(X_row, axis=0) / n_cols
60
+ Xmm = tl.where(mask, X_row - mean, 0)
61
+ var = tl.sum(Xmm * Xmm, axis=0) / n_cols
62
+ rstd = rsqrt(var + eps)
63
+
64
+ tl.store(Mean_ptr, mean)
65
+ tl.store(RSTD_ptr, rstd)
66
+
67
+ Y_row = Xmm * rstd * W_row + B_row
68
+
69
+ tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
70
+
71
+
72
+ @triton.jit
73
+ def _layer_norm_backward_kernel(
74
+ X_ptr, # pointer to input, shape (n_rows, n_cols)
75
+ W_ptr, # pointer to weights, shape (n_cols,)
76
+ Mean_ptr, # pointer to mean, shape (n_rows,)
77
+ RSTD_ptr, # pointer to rstd, shape (n_rows,)
78
+ DX_ptr, # pointer to input grad, shape (n_rows, n_cols)
79
+ DW_ptr, # pointer to weights grad, shape (n_cols,)
80
+ DB_ptr, # pointer to bias grad, shape (n_cols,)
81
+ DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
82
+ stride_x, # stride of each row in input
83
+ stride_dx, # stride of each row in input grad
84
+ stride_dw, # stride of each row in weights grad
85
+ stride_db, # stride of each row in bias grad
86
+ stride_dy, # stride of each row in output grad
87
+ n_rows,
88
+ n_cols,
89
+ rows_per_program: tl.constexpr,
90
+ BLOCK_SIZE: tl.constexpr,
91
+ dtype: tl.constexpr,
92
+ ):
93
+ """
94
+ References:
95
+ https://arxiv.org/abs/1607.06450
96
+ https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
97
+ https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
98
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
99
+ """
100
+ row_block_id = tl.program_id(0)
101
+ row_start = row_block_id * rows_per_program
102
+ row_end = min((row_block_id + 1) * rows_per_program, n_rows)
103
+ cols = tl.arange(0, BLOCK_SIZE)
104
+ mask = cols < n_cols
105
+
106
+ dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
107
+ db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
108
+
109
+ X_ptr += row_start * stride_x
110
+ Mean_ptr += row_start
111
+ RSTD_ptr += row_start
112
+ DX_ptr += row_start * stride_dx
113
+ DY_ptr += row_start * stride_dy
114
+
115
+ for _ in range(row_start, row_end):
116
+ x = tl.load(X_ptr + cols, mask=mask, other=0.0)
117
+ w = tl.load(W_ptr + cols, mask=mask, other=0.0)
118
+ dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)
119
+ mean = tl.load(Mean_ptr)
120
+ rstd = tl.load(RSTD_ptr)
121
+
122
+ x_hat = (x - mean) * rstd
123
+ wdy = w * dy
124
+ c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
125
+ c2 = tl.sum(wdy, axis=0) / n_cols
126
+ dx = (wdy - (x_hat * c1 + c2)) * rstd
127
+ tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)
128
+
129
+ dw_row += dy * x_hat
130
+ db_row += dy
131
+
132
+ X_ptr += stride_x
133
+ Mean_ptr += 1
134
+ RSTD_ptr += 1
135
+ DX_ptr += stride_dx
136
+ DY_ptr += stride_dy
137
+
138
+ tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)
139
+ tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)
140
+
141
+
142
+ def layer_norm_forward(X, W, B, eps):
143
+ shape = X.shape
144
+ dim = shape[-1]
145
+ X = X.view(-1, dim)
146
+ n_rows, n_cols = X.shape
147
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
148
+ Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
149
+ Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
150
+ RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
151
+ if X.shape[1] != W.shape[0]:
152
+ raise ValueError(
153
+ f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
154
+ f"must match weight size (W.shape[0]={W.shape[0]})"
155
+ )
156
+
157
+ # XPU-specific optimization
158
+ kernel_args = {}
159
+ if X.device.type == "xpu":
160
+ kernel_args["grf_mode"] = "large"
161
+
162
+ _layer_norm_forward_kernel[(n_rows,)](
163
+ Y,
164
+ Y.stride(0),
165
+ X,
166
+ X.stride(0),
167
+ W,
168
+ W.stride(0),
169
+ B,
170
+ B.stride(0),
171
+ Mean,
172
+ Mean.stride(0),
173
+ RSTD,
174
+ RSTD.stride(0),
175
+ n_cols,
176
+ eps,
177
+ BLOCK_SIZE=BLOCK_SIZE,
178
+ num_warps=num_warps,
179
+ **kernel_args, # XPU-specific optimization
180
+ )
181
+ return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
182
+
183
+
184
+ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
185
+ shape = dY.shape
186
+ dim = shape[-1]
187
+ dY = dY.view(-1, dim)
188
+ n_rows, n_cols = dY.shape
189
+
190
+ sm_count = 1
191
+ if X.device.type == "cuda":
192
+ sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
193
+ elif X.device.type == "xpu":
194
+ sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
195
+
196
+ DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
197
+ _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
198
+ _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
199
+
200
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
201
+ if n_cols > BLOCK_SIZE:
202
+ raise RuntimeError(
203
+ f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
204
+ )
205
+
206
+ rows_per_program = math.ceil(n_rows / sm_count)
207
+ grid = (sm_count,)
208
+ triton_dtype = (
209
+ tl.float32
210
+ if X.dtype == torch.float32
211
+ else tl.bfloat16
212
+ if X.dtype == torch.bfloat16
213
+ else tl.float16
214
+ if X.dtype == torch.float16
215
+ else tl.float32 # fallback to float32 for other types
216
+ )
217
+
218
+ # XPU-specific optimization
219
+ kernel_args = {}
220
+ if X.device.type == "xpu":
221
+ kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
222
+
223
+ _layer_norm_backward_kernel[grid](
224
+ X,
225
+ W,
226
+ Mean,
227
+ RSTD,
228
+ DX,
229
+ _DW,
230
+ _DB,
231
+ dY,
232
+ X.stride(0),
233
+ DX.stride(0),
234
+ _DW.stride(0),
235
+ _DB.stride(0),
236
+ dY.stride(0),
237
+ n_rows,
238
+ n_cols,
239
+ rows_per_program,
240
+ BLOCK_SIZE=BLOCK_SIZE,
241
+ dtype=triton_dtype,
242
+ **kernel_args, # XPU-specific optimization
243
+ )
244
+
245
+ DW = _DW.sum(dim=0).to(W.dtype)
246
+ DB = _DB.sum(dim=0).to(W.dtype)
247
+
248
+ DX = DX.view(*shape)
249
+ return DX, DW, DB
250
+
251
+
252
+ class LigerLayerNormFunction(torch.autograd.Function):
253
+ @staticmethod
254
+ @ensure_contiguous
255
+ def forward(ctx, X, W, B, eps):
256
+ Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps)
257
+ ctx.save_for_backward(X, W, B, Mean, RSTD)
258
+ return Y
259
+
260
+ @staticmethod
261
+ @ensure_contiguous
262
+ def backward(ctx, dY):
263
+ X, W, B, Mean, RSTD = ctx.saved_tensors
264
+ DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD)
265
+ return DX, DW, DB, None
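LigerLayerNormFunction normalizes over the last dimension with an affine weight and bias. An illustrative call under the same assumptions:
import torch
from liger_kernels import LigerLayerNormFunction

hidden = 512
x = torch.randn(4, 16, hidden, device="cuda", requires_grad=True)
w = torch.ones(hidden, device="cuda", requires_grad=True)
b = torch.zeros(hidden, device="cuda", requires_grad=True)
y = LigerLayerNormFunction.apply(x, w, b, 1e-6)
y.mean().backward()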
build/torch-universal/liger_kernels/qwen2vl_mrope.py ADDED
@@ -0,0 +1,222 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+
6
+ @triton.jit
7
+ def _triton_qwen2vl_mrope(
8
+ q_ptr,
9
+ k_ptr,
10
+ cos,
11
+ sin,
12
+ sl,
13
+ bs: tl.constexpr,
14
+ n_qh: tl.constexpr,
15
+ n_kh: tl.constexpr,
16
+ hd: tl.constexpr,
17
+ pad_n_qh: tl.constexpr,
18
+ pad_n_kh: tl.constexpr,
19
+ pad_hd: tl.constexpr,
20
+ mrope_section_t: tl.constexpr,
21
+ mrope_section_h: tl.constexpr,
22
+ BLOCK_SIZE: tl.constexpr,
23
+ BACKWARD_PASS: tl.constexpr = False,
24
+ ):
25
+ pid = tl.program_id(0)
26
+
27
+ # locate start address
28
+ q_ptr = q_ptr + pid * (n_qh * hd)
29
+ k_ptr = k_ptr + pid * (n_kh * hd)
30
+
31
+ # ####################################################################
32
+ # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
33
+ # m of this program instance
34
+ # ####################################################################
35
+
36
+ # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
37
+ # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
38
+ # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index
39
+ # and pid % sl to get the sequence index.
40
+ # 2. We only need the left half of cos and sin matrix because the right half is just
41
+ # a clone of the left half.
42
+ t_end = mrope_section_t
43
+ h_end = t_end + mrope_section_h
44
+
45
+ t_cos = cos + pid * hd
46
+ h_cos = t_cos + bs * sl * hd
47
+ w_cos = h_cos + bs * sl * hd
48
+ t_sin = sin + pid * hd
49
+ h_sin = t_sin + bs * sl * hd
50
+ w_sin = h_sin + bs * sl * hd
51
+
52
+ cos_offsets = tl.arange(0, pad_hd // 2)
53
+ t_mask = cos_offsets < t_end
54
+ h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
55
+ w_mask = (h_end <= cos_offsets) & (cos_offsets < hd // 2)
56
+ t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
57
+ h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
58
+ w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
59
+ t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
60
+ h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
61
+ w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)
62
+ cos_row = t_cos_row + h_cos_row + w_cos_row
63
+ sin_row = t_sin_row + h_sin_row + w_sin_row
64
+
65
+ # ####################################################################
66
+ # Load the left and right half of q and k for the current
67
+ # program instance (i.e. for the current token) separately
68
+ # ####################################################################
69
+ # left half of the head
70
+ first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
71
+ first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
72
+ first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
73
+ first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
74
+ q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
75
+ k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
76
+
77
+ # right half of the head
78
+ second_half_q_offsets = first_half_q_offsets + (hd // 2)
79
+ second_half_k_offsets = first_half_k_offsets + (hd // 2)
80
+ second_q_mask = first_q_mask
81
+ second_k_mask = first_k_mask
82
+ q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
83
+ k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
84
+
85
+ if not BACKWARD_PASS:
86
+ # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
87
+ new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
88
+ tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
89
+ new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
90
+ tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
91
+
92
+ new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
93
+ tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
94
+ new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
95
+ tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
96
+ else:
97
+ # with some math, we can get:
98
+ # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
99
+ new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
100
+ tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
101
+ new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
102
+ tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
103
+
104
+ new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
105
+ tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
106
+ new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
107
+ tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
108
+
109
+
110
+ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
111
+ # transpose it back to the physical shape because Triton looks at the physical storage
112
+ # note: q and k are non-contiguous before the transformation and become contiguous after the transpose
113
+ q = q.transpose(1, 2)
114
+ k = k.transpose(1, 2)
115
+
116
+ batch_size, seq_len, n_q_head, head_dim = q.shape
117
+ n_kv_head = k.shape[2]
118
+ pad_hd = triton.next_power_of_2(head_dim)
119
+ pad_n_q_head = triton.next_power_of_2(n_q_head)
120
+ pad_n_kv_head = triton.next_power_of_2(n_kv_head)
121
+ BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
122
+
123
+ n_row = batch_size * seq_len
124
+
125
+ # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
126
+ q = q.contiguous()
127
+ k = k.contiguous()
128
+ cos = cos.contiguous()
129
+ sin = sin.contiguous()
130
+
131
+ _triton_qwen2vl_mrope[(n_row,)](
132
+ q,
133
+ k,
134
+ cos,
135
+ sin,
136
+ seq_len,
137
+ batch_size,
138
+ n_q_head,
139
+ n_kv_head,
140
+ head_dim,
141
+ pad_n_q_head,
142
+ pad_n_kv_head,
143
+ pad_hd,
144
+ mrope_section[0],
145
+ mrope_section[1],
146
+ BLOCK_SIZE=BLOCK_SIZE,
147
+ BACKWARD_PASS=False,
148
+ )
149
+ return q.transpose(1, 2), k.transpose(1, 2), cos, sin
150
+
151
+
152
+ def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
153
+ dq = dq.transpose(1, 2)
154
+ dk = dk.transpose(1, 2)
155
+
156
+ batch_size, seq_len, n_q_head, head_dim = dq.shape
157
+ n_kv_head = dk.shape[2]
158
+ pad_hd = triton.next_power_of_2(head_dim)
159
+ pad_n_q_head = triton.next_power_of_2(n_q_head)
160
+ pad_n_kv_head = triton.next_power_of_2(n_kv_head)
161
+ BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
162
+
163
+ n_row = batch_size * seq_len
164
+
165
+ # ensure dq and dk are contiguous
166
+ dq = dq.contiguous()
167
+ dk = dk.contiguous()
168
+
169
+ # backward is similar to forward except swapping few ops
170
+ _triton_qwen2vl_mrope[(n_row,)](
171
+ dq,
172
+ dk,
173
+ cos,
174
+ sin,
175
+ seq_len,
176
+ batch_size,
177
+ n_q_head,
178
+ n_kv_head,
179
+ head_dim,
180
+ pad_n_q_head,
181
+ pad_n_kv_head,
182
+ pad_hd,
183
+ mrope_section[0],
184
+ mrope_section[1],
185
+ BLOCK_SIZE=BLOCK_SIZE,
186
+ BACKWARD_PASS=True,
187
+ )
188
+ return dq.transpose(1, 2), dk.transpose(1, 2)
189
+
190
+
191
+ class LigerQwen2VLMRopeFunction(torch.autograd.Function):
192
+ """
193
+ Triton implementation of the Qwen2VL Multimodal Rotary Positional Embedding (M-RoPE) operation.
194
+
195
+ Please find the corresponding HuggingFace implementation here:
196
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
197
+ """
198
+
199
+ @staticmethod
200
+ def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
201
+ """
202
+ q size: (bsz, n_q_head, seq_len, head_dim)
203
+ k size: (bsz, n_kv_head, seq_len, head_dim)
204
+ cos size: (3, bsz, seq_len, head_dim)
205
+ sin size: (3, bsz, seq_len, head_dim)
206
+ """
207
+ q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
208
+ ctx.save_for_backward(cos, sin)
209
+ ctx.mrope_section = mrope_section
210
+ return q, k
211
+
212
+ def backward(ctx, dq, dk):
213
+ """
214
+ dq size: (bsz, n_q_head, seq_len, head_dim)
215
+ dk size: (bsz, n_kv_head, seq_len, head_dim)
216
+ cos size: (3, bsz, seq_len, head_dim)
217
+ sin size: (3, bsz, seq_len, head_dim)
218
+ """
219
+ cos, sin = ctx.saved_tensors
220
+ mrope_section = ctx.mrope_section
221
+ dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
222
+ return dq, dk, None, None, None, None
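For LigerQwen2VLMRopeFunction, cos/sin carry three sections (temporal, height, width) and the `mrope_section` entries must sum to head_dim // 2. The values below mirror a Qwen2-VL-style configuration but are assumptions, as are the random cos/sin caches:
import torch
from liger_kernels import LigerQwen2VLMRopeFunction

bsz, n_q_head, n_kv_head, seq_len, head_dim = 2, 8, 2, 16, 128
q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda", requires_grad=True)
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda", requires_grad=True)
cos = torch.randn(3, bsz, seq_len, head_dim, device="cuda")  # 3 = temporal / height / width sections
sin = torch.randn(3, bsz, seq_len, head_dim, device="cuda")
mrope_section = [16, 24, 24]  # sums to head_dim // 2
q_rot, k_rot = LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section)
(q_rot.sum() + k_rot.sum()).backward()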
build/torch-universal/liger_kernels/rms_norm.py ADDED
@@ -0,0 +1,365 @@
1
+ """
2
+ This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
3
+ See the original Unsloth repository at https://github.com/unslothai/unsloth.
4
+
5
+ The following line
6
+ https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/rms_norm.py#L30
7
+ is based on code from Unsloth, located at:
8
+ https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
9
+
10
+ Modifications made by Yanning Chen, 2024.
11
+ """
12
+
13
+ import math
14
+ import operator
15
+
16
+ import torch
17
+ import triton
18
+ import triton.language as tl
19
+
20
+ from utils import calculate_settings
21
+ from utils import compare_version
22
+ from utils import ensure_contiguous
23
+ from utils import torch_to_triton_dtype
24
+
25
+ if compare_version("triton", operator.ge, "3.0.0"):
26
+ try:
27
+ # typical import path with dispatch available
28
+ from triton.language.extra.libdevice import rsqrt
29
+ except ModuleNotFoundError:
30
+ # for working with NGC containers
31
+ from triton.language.extra.cuda.libdevice import rsqrt
32
+ else:
33
+ from triton.language.math import rsqrt
34
+
35
+
36
+ _CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1)
37
+ _CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
38
+ _CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)
39
+
40
+
41
+ @triton.jit
42
+ def _rms_norm_forward_kernel(
43
+ Y_ptr,
44
+ Y_row_stride,
45
+ X_ptr,
46
+ X_row_stride,
47
+ W_ptr,
48
+ W_row_stride,
49
+ RSTD_ptr,
50
+ RSTD_row_stride,
51
+ n_cols,
52
+ eps,
53
+ offset,
54
+ casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
55
+ BLOCK_SIZE: tl.constexpr,
56
+ ):
57
+ """
58
+ y_i = (x_i / RMS) * (offset + w_i), RMS = sqrt(sum(x_i^2) / N)
59
+
60
+ Reference:
61
+ 1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
62
+ 2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
63
+ 3. https://arxiv.org/pdf/1910.07467
64
+ """
65
+
66
+ row_idx = tl.program_id(0)
67
+ col_offsets = tl.arange(0, BLOCK_SIZE)
68
+ mask = col_offsets < n_cols
69
+
70
+ Y_ptr += row_idx * Y_row_stride
71
+ X_ptr += row_idx * X_row_stride
72
+ RSTD_ptr += row_idx * RSTD_row_stride
73
+
74
+ X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
75
+ X_row_dtype = X_row.dtype
76
+ W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
77
+
78
+ # On Llama, only rstd is computed on fp32
79
+ if casting_mode == _CASTING_MODE_LLAMA:
80
+ X_row = X_row.to(tl.float32)
81
+
82
+ # Gemma computes everything on fp32, and then casts back the output to the original dtype
83
+ if casting_mode == _CASTING_MODE_GEMMA:
84
+ W_row = W_row.to(tl.float32)
85
+ X_row = X_row.to(tl.float32)
86
+
87
+ if casting_mode == _CASTING_MODE_NONE:
88
+ eps = eps.to(X_row_dtype)
89
+ offset = offset.to(X_row_dtype)
90
+
91
+ mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
92
+ rstd = rsqrt(mean_square + eps)
93
+
94
+ # We can save time by caching rms with minimal memory overhead
95
+ # because rms is much smaller compared to X_row, as rms is for each row.
96
+ # On the computation side, caching it saves 4 operations (*, sum, /, sqrt).
97
+ tl.store(RSTD_ptr, rstd)
98
+
99
+ X_row = X_row * rstd
100
+
101
+ # On Llama, the multiplication with the weight is done on the original dtype
102
+ if casting_mode == _CASTING_MODE_LLAMA:
103
+ X_row = X_row.to(X_row_dtype)
104
+
105
+ Y_row = X_row * (offset + W_row)
106
+
107
+ if casting_mode == _CASTING_MODE_GEMMA:
108
+ Y_row = Y_row.to(X_row_dtype)
109
+
110
+ tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
111
+
112
+
113
+ @triton.jit
114
+ def _rms_norm_backward_kernel(
115
+ dY_ptr,
116
+ dY_row_stride,
117
+ dX_ptr,
118
+ dX_row_stride,
119
+ X_ptr,
120
+ X_row_stride,
121
+ X_dtype: tl.constexpr,
122
+ W_ptr,
123
+ W_row_stride,
124
+ RSTD_ptr,
125
+ RSTD_row_stride,
126
+ dW_ptr,
127
+ dW_row_stride,
128
+ n_rows,
129
+ n_cols,
130
+ offset,
131
+ rows_per_program: tl.constexpr,
132
+ casting_mode: tl.constexpr,
133
+ BLOCK_SIZE: tl.constexpr,
134
+ ):
135
+ """
136
+ dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * denotes element-wise multiplication, whereas dot denotes the dot product
137
+ dw = sum(dy * (x / RMS)). summation over BxT dimension
138
+ """
139
+
140
+ row_block_id = tl.program_id(0)
141
+ row_start = row_block_id * rows_per_program
142
+ row_end = min((row_block_id + 1) * rows_per_program, n_rows)
143
+ col_offsets = tl.arange(0, BLOCK_SIZE)
144
+ mask = col_offsets < n_cols
145
+
146
+ dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
147
+
148
+ dY_ptr += row_start * dY_row_stride
149
+ dX_ptr += row_start * dX_row_stride
150
+
151
+ X_ptr += row_start * X_row_stride
152
+ RSTD_ptr += row_start
153
+
154
+ W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
155
+ W_row = W_row + offset
156
+
157
+ for _ in range(row_start, row_end):
158
+ dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
159
+ X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
160
+
161
+ # Get cached rms
162
+ rstd_row = tl.load(RSTD_ptr)
163
+
164
+ X_row = X_row.to(tl.float32)
165
+
166
+ # Different backward graphs for different casting modes
167
+ if casting_mode == _CASTING_MODE_LLAMA:
168
+ m = (dY_row * W_row).to(tl.float32)
169
+
170
+ elif casting_mode == _CASTING_MODE_GEMMA:
171
+ dY_row = dY_row.to(tl.float32)
172
+ m = dY_row * W_row
173
+ else:
174
+ m = dY_row * W_row
175
+
176
+ dX_row = rstd_row * m
177
+
178
+ dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
179
+
180
+ # calculate the gradient of W
181
+ if casting_mode == _CASTING_MODE_LLAMA:
182
+ dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
183
+ else:
184
+ # here X_row is already in fp32 (see previous if block)
185
+ dW_row += dY_row * (X_row * rstd_row)
186
+
187
+ tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
188
+
189
+ dY_ptr += dY_row_stride
190
+ dX_ptr += dX_row_stride
191
+ X_ptr += X_row_stride
192
+ RSTD_ptr += RSTD_row_stride
193
+
194
+ tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
195
+
196
+
197
+ _str_to_casting_mode = {
198
+ "llama": _CASTING_MODE_LLAMA.value,
199
+ "gemma": _CASTING_MODE_GEMMA.value,
200
+ "none": _CASTING_MODE_NONE.value,
201
+ }
202
+
203
+
204
+ def rms_norm_forward(X, W, eps, offset, casting_mode):
205
+ if not isinstance(casting_mode, int):
206
+ assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
207
+ casting_mode = _str_to_casting_mode[casting_mode]
208
+ else:
209
+ assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"
210
+
211
+ shape = X.shape
212
+ dim = shape[-1]
213
+ X = X.view(-1, dim)
214
+ n_rows, n_cols = X.shape
215
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
216
+
217
+ Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
218
+ # RSTD is to cache rstd for each row
219
+ # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
220
+ rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
221
+ RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
222
+
223
+ # Check constraints.
224
+ assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
225
+
226
+ # XPU-specific optimization
227
+ kernel_args = {}
228
+ if X.device.type == "xpu":
229
+ kernel_args["grf_mode"] = "large"
230
+ _rms_norm_forward_kernel[(n_rows,)](
231
+ Y,
232
+ Y.stride(0),
233
+ X,
234
+ X.stride(0),
235
+ W,
236
+ W.stride(0),
237
+ RSTD,
238
+ RSTD.stride(0),
239
+ n_cols,
240
+ eps,
241
+ offset,
242
+ casting_mode,
243
+ BLOCK_SIZE=BLOCK_SIZE,
244
+ num_warps=num_warps,
245
+ **kernel_args, # XPU-specific optimization
246
+ )
247
+ return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
248
+
249
+
250
+ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
251
+ shape = dY.shape
252
+ dim = shape[-1]
253
+ dY = dY.view(-1, dim)
254
+ n_rows, n_cols = dY.shape
255
+
256
+ sm_count = 1
257
+ if X.device.type == "cuda":
258
+ sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
259
+ elif X.device.type == "xpu":
260
+ sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
261
+
262
+ # accumulate dW in fp32 for numerical stability.
263
+ _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
264
+
265
+ if n_cols > BLOCK_SIZE:
266
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
267
+ rows_per_program = math.ceil(n_rows / sm_count)
268
+ grid = (sm_count,)
269
+
270
+ if in_place is True:
271
+ dX = dY
272
+ else:
273
+ dX = torch.zeros_like(dY)
274
+
275
+ # XPU-specific optimization
276
+ kernel_args = {}
277
+ if X.device.type == "xpu":
278
+ kernel_args["grf_mode"] = "large"
279
+
280
+ _rms_norm_backward_kernel[grid](
281
+ dY,
282
+ dY.stride(0),
283
+ dX,
284
+ dX.stride(0),
285
+ X,
286
+ X.stride(0),
287
+ torch_to_triton_dtype[X.dtype],
288
+ W,
289
+ W.stride(0),
290
+ RSTD,
291
+ RSTD.stride(0),
292
+ _dW,
293
+ _dW.stride(0),
294
+ n_rows,
295
+ n_cols,
296
+ offset,
297
+ rows_per_program,
298
+ casting_mode,
299
+ BLOCK_SIZE=BLOCK_SIZE,
300
+ num_warps=num_warps,
301
+ **kernel_args, # XPU-specific optimization
302
+ )
303
+ dX = dX.view(*shape)
304
+ dW = _dW.sum(dim=0).to(W.dtype)
305
+
306
+ return dX, dW
307
+
308
+
309
+ class LigerRMSNormFunction(torch.autograd.Function):
310
+ """
311
+ Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the
312
+ weight tensor `W`, with an optional offset and casting mode.
313
+
314
+ Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
315
+ uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
316
+ `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.
317
+
318
+ In addition, different models cast their inputs at different places during RMSNorm computation. For
319
+ example, Gemma casts everything to fp32 before starting the computation, while Llama casts only the
320
+ inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
321
+ support the following casting modes (they match HuggingFace Transformers' implementations):
322
+ - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
323
+ - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
324
+ - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
325
+
326
+ The `in_place` option controls whether dY is modified in place to store dX. It defaults to `True` to save memory. However, in certain cases it can produce incorrect gradients.
327
+ For example, Gemma2 applies two RMSNorms sequentially with a residual connection in between. The residual branch still needs dY, so dY cannot be modified in place.
328
+ Therefore, when patching RMSNorm in Gemma2, we set `in_place` to `False`.
329
+ """
330
+
331
+ @staticmethod
332
+ @ensure_contiguous
333
+ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
334
+ """
335
+ X: (B, T, H) or (BxT, H)
336
+ W: (H,)
337
+ """
338
+ Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
339
+ ctx.offset = offset
340
+ ctx.casting_mode = casting_mode
341
+ ctx.in_place = in_place
342
+ ctx.BLOCK_SIZE = BLOCK_SIZE
343
+ ctx.num_warps = num_warps
344
+ ctx.save_for_backward(X, W, RSTD)
345
+ return Y
346
+
347
+ @staticmethod
348
+ @ensure_contiguous
349
+ def backward(ctx, dY):
350
+ """
351
+ Y: (B, T, H) or (BxT, H)
352
+ """
353
+ X, W, RSTD = ctx.saved_tensors
354
+ dX, dW = rms_norm_backward(
355
+ dY,
356
+ X,
357
+ W,
358
+ RSTD,
359
+ ctx.offset,
360
+ ctx.casting_mode,
361
+ ctx.BLOCK_SIZE,
362
+ ctx.num_warps,
363
+ ctx.in_place,
364
+ )
365
+ return dX, dW, None, None, None, None
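An illustrative LigerRMSNormFunction call in bfloat16 with the Llama casting mode and the default in-place backward; dtype, shapes, and the import path are assumptions:
import torch
from liger_kernels import LigerRMSNormFunction

hidden = 1024
x = torch.randn(2, 16, hidden, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.ones(hidden, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = LigerRMSNormFunction.apply(x, w, 1e-6, 0.0, "llama")  # offset=0.0, Llama-style casting, in_place defaults to True
y.sum().backward()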
build/torch-universal/liger_kernels/rope.py ADDED
@@ -0,0 +1,239 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+
6
+ @triton.jit
7
+ def _triton_rope(
8
+ q_ptr,
9
+ q_row_stride,
10
+ k_ptr,
11
+ k_row_stride,
12
+ cos,
13
+ cos_row_stride,
14
+ sin,
15
+ sin_row_stride,
16
+ sl,
17
+ bs: tl.constexpr,
18
+ cos_bs: tl.constexpr,
19
+ n_qh: tl.constexpr,
20
+ n_kh: tl.constexpr,
21
+ hd: tl.constexpr,
22
+ pad_n_qh: tl.constexpr,
23
+ pad_n_kh: tl.constexpr,
24
+ pad_hd: tl.constexpr,
25
+ BLOCK_SIZE: tl.constexpr,
26
+ BACKWARD_PASS: tl.constexpr = False,
27
+ ):
28
+ # q size: (bsz, seq_len, num_q_heads, head_dim)
29
+ # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1)
30
+ # k size: (bsz, seq_len, num_kv_heads, head_dim)
31
+ # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)
32
+
33
+ # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
34
+ # stride: (seq_len * head_dim, head_dim, 1)
35
+ pid = tl.program_id(0)
36
+
37
+ # locate start address
38
+ q_ptr = q_ptr + pid * q_row_stride
39
+ k_ptr = k_ptr + pid * k_row_stride
40
+
41
+ # ####################################################################
42
+ # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
43
+ # m of this program instance
44
+ # ####################################################################
45
+
46
+ # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
47
+ # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
48
+ # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index
49
+ # and pid % sl to get the sequence index.
50
+ # 2. We only need the left half of cos and sin matrix because the right half is just
51
+ # a clone of the left half.
52
+ batch_idx = pid // sl
53
+ cos_row_idx = pid % sl
54
+ cos = cos + tl.where(
55
+ cos_bs == 1,
56
+ cos_row_idx * cos_row_stride,
57
+ batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
58
+ )
59
+ sin = sin + tl.where(
60
+ cos_bs == 1,
61
+ cos_row_idx * sin_row_stride,
62
+ batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
63
+ )
64
+
65
+ cos_offsets = tl.arange(0, pad_hd // 2)
66
+ cos_mask = cos_offsets < hd // 2
67
+ cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
68
+ sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0)
69
+
70
+ # ####################################################################
71
+ # Load the left and right half of q and k for the current
72
+ # program instance (i.e. for the current token) separately
73
+ # ####################################################################
74
+ # left half of the head
75
+ first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
76
+ first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
77
+ first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
78
+ first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
79
+ q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
80
+ k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
81
+
82
+ # right half of the head
83
+ second_half_q_offsets = first_half_q_offsets + (hd // 2)
84
+ second_half_k_offsets = first_half_k_offsets + (hd // 2)
85
+ second_q_mask = first_q_mask
86
+ second_k_mask = first_k_mask
87
+ q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
88
+ k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
89
+
90
+ if not BACKWARD_PASS:
91
+ # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
92
+ new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
93
+ tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
94
+ new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
95
+ tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
96
+
97
+ new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
98
+ tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
99
+ new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
100
+ tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
101
+ else:
102
+ # with some math, we can get:
103
+ # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
104
+ new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
105
+ tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
106
+ new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
107
+ tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
108
+
109
+ new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
110
+ tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
111
+ new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
112
+ tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
113
+
114
+
115
+ def rope_forward(q, k, cos, sin):
116
+ # transpose it back to the physical shape because Triton looks at the physical storage
117
+ # note: q and k are non-contiguous before the transformation and become contiguous after the transpose
118
+ q = q.transpose(1, 2)
119
+ k = k.transpose(1, 2)
120
+
121
+ batch_size, seq_len, n_q_head, head_dim = q.shape
122
+ n_kv_head = k.shape[2]
123
+ pad_hd = triton.next_power_of_2(head_dim)
124
+ pad_n_q_head = triton.next_power_of_2(n_q_head)
125
+ pad_n_kv_head = triton.next_power_of_2(n_kv_head)
126
+ BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
127
+
128
+ n_row = batch_size * seq_len
129
+
130
+ # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
131
+ q = q.contiguous()
132
+ k = k.contiguous()
133
+ cos = cos.contiguous()
134
+ sin = sin.contiguous()
135
+ cos_batch_size = cos.shape[0]
136
+
137
+ _triton_rope[(n_row,)](
138
+ q,
139
+ q.stride(1),
140
+ k,
141
+ k.stride(1),
142
+ cos,
143
+ cos.stride(-2),
144
+ sin,
145
+ sin.stride(-2),
146
+ seq_len,
147
+ batch_size,
148
+ cos_batch_size,
149
+ n_q_head,
150
+ n_kv_head,
151
+ head_dim,
152
+ pad_n_q_head,
153
+ pad_n_kv_head,
154
+ pad_hd,
155
+ BLOCK_SIZE=BLOCK_SIZE,
156
+ BACKWARD_PASS=False,
157
+ )
158
+ return q.transpose(1, 2), k.transpose(1, 2), cos, sin
159
+
160
+
161
+ def rope_backward(dq, dk, cos, sin):
162
+ dq = dq.transpose(1, 2)
163
+ dk = dk.transpose(1, 2)
164
+
165
+ batch_size, seq_len, n_q_head, head_dim = dq.shape
166
+ cos_batch_size = cos.shape[0]
167
+ n_kv_head = dk.shape[2]
168
+ pad_hd = triton.next_power_of_2(head_dim)
169
+ pad_n_q_head = triton.next_power_of_2(n_q_head)
170
+ pad_n_kv_head = triton.next_power_of_2(n_kv_head)
171
+ BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
172
+
173
+ n_row = batch_size * seq_len
174
+
175
+ # ensure dq and dk are contiguous
176
+ dq = dq.contiguous()
177
+ dk = dk.contiguous()
178
+
179
+ # backward is similar to forward except swapping few ops
180
+ _triton_rope[(n_row,)](
181
+ dq,
182
+ dq.stride(1),
183
+ dk,
184
+ dk.stride(1),
185
+ cos,
186
+ cos.stride(-2),
187
+ sin,
188
+ sin.stride(-2),
189
+ seq_len,
190
+ batch_size,
191
+ cos_batch_size,
192
+ n_q_head,
193
+ n_kv_head,
194
+ head_dim,
195
+ pad_n_q_head,
196
+ pad_n_kv_head,
197
+ pad_hd,
198
+ BLOCK_SIZE=BLOCK_SIZE,
199
+ BACKWARD_PASS=True,
200
+ )
201
+ return dq.transpose(1, 2), dk.transpose(1, 2)
202
+
203
+
204
+ class LigerRopeFunction(torch.autograd.Function):
205
+ """
206
+ Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that
207
+ this implements the HuggingFace Llama & Mistral version, whose rotation matrix is slightly different
208
+ from the one used in the original RoPE paper.
209
+
210
+ Please find the corresponding HuggingFace implementation here:
211
+ https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184
212
+
213
+ For more details about the rotation matrix used here, please refer to:
214
+ https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2
215
+ """
216
+
217
+ @staticmethod
218
+ def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
219
+ """
220
+ q size: (bsz, n_q_head, seq_len, head_dim)
221
+ k size: (bsz, n_kv_head, seq_len, head_dim)
222
+ cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
223
+ sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
224
+ """
225
+ q, k, cos, sin = rope_forward(q, k, cos, sin)
226
+ ctx.save_for_backward(cos, sin)
227
+ return q, k
228
+
229
+ @staticmethod
+ def backward(ctx, dq, dk):
230
+ """
231
+ dq size: (bsz, n_q_head, seq_len, head_dim)
232
+ dk size: (bsz, n_kv_head, seq_len, head_dim)
233
+ cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
234
+ sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
235
+ """
236
+
237
+ cos, sin = ctx.saved_tensors
238
+ dq, dk = rope_backward(dq, dk, cos, sin)
239
+ return dq, dk, None, None, None, None
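A minimal usage sketch for LigerRopeFunction, added for illustration and not part of the uploaded file. It assumes a CUDA-capable device for Triton, the flat-import layout used by __init__.py, and random cos/sin tensors standing in for a real rotary-embedding cache.

import torch
from rope import LigerRopeFunction  # flat import, matching this package's __init__.py

# Shapes follow the forward() docstring: q is (bsz, n_q_head, seq_len, head_dim),
# k is (bsz, n_kv_head, seq_len, head_dim), cos/sin are (1, seq_len, head_dim).
bsz, n_q_head, n_kv_head, seq_len, head_dim = 2, 8, 2, 16, 64
q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda", requires_grad=True)
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda", requires_grad=True)
cos = torch.randn(1, seq_len, head_dim, device="cuda")  # placeholder rotary cache
sin = torch.randn(1, seq_len, head_dim, device="cuda")

q_rot, k_rot = LigerRopeFunction.apply(q, k, cos, sin)
(q_rot.sum() + k_rot.sum()).backward()  # drives rope_backward for both q and k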
build/torch-universal/liger_kernels/swiglu.py ADDED
@@ -0,0 +1,116 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from utils import calculate_settings
6
+ from utils import ensure_contiguous
7
+
8
+
9
+ @triton.jit
10
+ def silu(x):
11
+ return x * tl.sigmoid(x)
12
+
13
+
14
+ @triton.jit
15
+ def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
16
+ program_id = tl.program_id(0).to(tl.int64)
17
+
18
+ # locate start index
19
+ a_ptr += program_id * stride
20
+ b_ptr += program_id * stride
21
+ c_ptr += program_id * stride
22
+
23
+ col_offsets = tl.arange(0, BLOCK_SIZE)
24
+ mask = col_offsets < n_cols
25
+
26
+ # sigmoid requires type float32
27
+ a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
28
+ b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
29
+ c_row = silu(a_row) * b_row
30
+ tl.store(c_ptr + col_offsets, c_row, mask=mask)
31
+
32
+
33
+ @triton.jit
34
+ def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
35
+ program_id = tl.program_id(0).to(tl.int64)
36
+
37
+ # locate start index
38
+ dc_ptr += program_id * stride
39
+ a_ptr += program_id * stride
40
+ b_ptr += program_id * stride
41
+
42
+ col_offsets = tl.arange(0, BLOCK_SIZE)
43
+ mask = col_offsets < n_cols
44
+
45
+ dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)
46
+ # sigmoid requires type float32
47
+ a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
48
+ b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
49
+
50
+ # recomputation to save memory
51
+ sig_a = tl.sigmoid(a_row)
52
+ silu_a = a_row * sig_a
53
+ db_row = dc_row * silu_a
54
+ da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row
55
+
56
+ tl.store(a_ptr + col_offsets, da_row, mask=mask)
57
+ tl.store(b_ptr + col_offsets, db_row, mask=mask)
58
+
59
+
60
+ def swiglu_forward(a, b):
61
+ ori_shape = a.shape
62
+
63
+ n_cols = ori_shape[-1]
64
+ a = a.view(-1, n_cols)
65
+ b = b.view(-1, n_cols)
66
+ c = torch.empty_like(a)
67
+ n_rows = a.shape[0]
68
+
69
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
70
+
71
+ _swiglu_forward_kernel[(n_rows,)](
72
+ a,
73
+ b,
74
+ c,
75
+ c.stride(-2),
76
+ n_cols=n_cols,
77
+ BLOCK_SIZE=BLOCK_SIZE,
78
+ num_warps=num_warps,
79
+ )
80
+ return a, b, c.view(*ori_shape)
81
+
82
+
83
+ def swiglu_backward(a, b, dc):
84
+ ori_shape = dc.shape
85
+ n_cols = ori_shape[-1]
86
+ dc = dc.view(-1, n_cols)
87
+ n_rows = dc.shape[0]
88
+
89
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
90
+
91
+ _swiglu_backward_kernel[(n_rows,)](
92
+ dc,
93
+ a,
94
+ b,
95
+ dc.stride(-2),
96
+ n_cols=n_cols,
97
+ BLOCK_SIZE=BLOCK_SIZE,
98
+ num_warps=num_warps,
99
+ )
100
+ return a.view(*ori_shape), b.view(*ori_shape)
101
+
102
+
103
+ class LigerSiLUMulFunction(torch.autograd.Function):
104
+ @staticmethod
105
+ @ensure_contiguous
106
+ def forward(ctx, a, b):
107
+ a, b, c = swiglu_forward(a, b)
108
+ ctx.save_for_backward(a, b)
109
+ return c
110
+
111
+ @staticmethod
112
+ @ensure_contiguous
113
+ def backward(ctx, dc):
114
+ a, b = ctx.saved_tensors
115
+ a, b = swiglu_backward(a, b, dc)
116
+ return a, b
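An illustrative check of LigerSiLUMulFunction against the eager definition silu(a) * b that the two kernels above implement; it is not part of the uploaded file and assumes a CUDA-capable device plus the flat-import layout.

import torch
import torch.nn.functional as F
from swiglu import LigerSiLUMulFunction  # flat import, matching this package's style

a = torch.randn(4, 2048, device="cuda")  # gate projection output
b = torch.randn(4, 2048, device="cuda")  # up projection output

out = LigerSiLUMulFunction.apply(a, b)
ref = F.silu(a) * b  # SwiGLU without the surrounding linear layers
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)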
build/torch-universal/liger_kernels/tvd.py ADDED
@@ -0,0 +1,207 @@
1
+ from typing import Literal
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ from utils import ensure_contiguous
9
+
10
+ MAX_FUSED_SIZE = 65536 // 4
11
+
12
+ REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
13
+
14
+ _REDUCTION_MODE_NONE = tl.constexpr(0)
15
+ _REDUCTION_MODE_SUM = tl.constexpr(1)
16
+ _REDUCTION_MODE_MEAN = tl.constexpr(2)
17
+ _REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
18
+
19
+ _str_to_reduction_mode = {
20
+ "none": _REDUCTION_MODE_NONE.value,
21
+ "sum": _REDUCTION_MODE_SUM.value,
22
+ "mean": _REDUCTION_MODE_MEAN.value,
23
+ "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
24
+ }
25
+
26
+
27
+ def get_num_warps(BLOCK_SIZE):
28
+ num_warps = 4
29
+ if BLOCK_SIZE >= 32768:
30
+ num_warps = 32
31
+ elif BLOCK_SIZE >= 8192:
32
+ num_warps = 16
33
+ elif BLOCK_SIZE >= 2048:
34
+ num_warps = 8
35
+
36
+ return num_warps
37
+
38
+
39
+ @triton.jit
40
+ def _tv_distance_kernel(
41
+ p_ptr,
42
+ p_stride,
43
+ q_ptr,
44
+ q_stride,
45
+ loss_ptr,
46
+ loss_stride,
47
+ grads_ptr,
48
+ grads_stride,
49
+ label_ptr,
50
+ ignore_index: tl.constexpr,
51
+ n_cols,
52
+ BLOCK_SIZE: tl.constexpr,
53
+ HAS_LABEL: tl.constexpr,
54
+ reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
55
+ ):
56
+ pid = tl.program_id(0).to(tl.int64)
57
+ p_ptr += pid * p_stride
58
+ q_ptr += pid * q_stride
59
+ loss_ptr += pid * loss_stride
60
+ grads_ptr += pid * grads_stride
61
+ label_ptr += pid
62
+
63
+ base_offsets = tl.arange(0, BLOCK_SIZE)
64
+
65
+ if HAS_LABEL:
66
+ label = tl.load(label_ptr)
67
+ if label == ignore_index:
68
+ for i in range(0, n_cols, BLOCK_SIZE):
69
+ offsets = i + base_offsets
70
+ mask = offsets < n_cols
71
+ tl.store(grads_ptr + offsets, 0.0, mask=mask)
72
+ if reduction == _REDUCTION_MODE_NONE:
73
+ tl.store(loss_ptr + offsets, 0.0, mask=mask)
74
+ return
75
+
76
+ loss_sum = 0.0
77
+ for i in range(0, n_cols, BLOCK_SIZE):
78
+ offsets = i + base_offsets
79
+ mask = offsets < n_cols
80
+
81
+ p = tl.load(p_ptr + offsets, mask=mask, other=0.0)
82
+ q = tl.load(q_ptr + offsets, mask=mask, other=0.0)
83
+
84
+ # TVD(P || Q) = 0.5 * |P - Q|
85
+ tv_loss = 0.5 * tl.abs(p - q)
86
+
87
+ grad_res = tl.where(p > q, 0.5, -0.5)
88
+
89
+ tl.store(grads_ptr + offsets, grad_res, mask=mask)
90
+
91
+ if reduction == _REDUCTION_MODE_NONE:
92
+ tl.store(loss_ptr + offsets, tv_loss, mask=mask)
93
+ else:
94
+ loss_sum += tl.sum(tv_loss, axis=0)
95
+
96
+ if reduction != _REDUCTION_MODE_NONE:
97
+ tl.store(loss_ptr, loss_sum)
98
+
99
+
100
+ def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
101
+ BT, V = p.shape
102
+
103
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
104
+ num_warps = get_num_warps(BLOCK_SIZE)
105
+
106
+ grid = (BT,)
107
+
108
+ reduction = _str_to_reduction_mode[reduction]
109
+
110
+ out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
111
+ output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
112
+ grads = torch.empty_like(p)
113
+
114
+ n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
115
+
116
+ _tv_distance_kernel[grid](
117
+ p,
118
+ p.stride(0),
119
+ q,
120
+ q.stride(0),
121
+ output_tensor,
122
+ output_tensor.stride(0),
123
+ grads,
124
+ grads.stride(0),
125
+ shift_labels if has_label else torch.empty(1, device=p.device),
126
+ ignore_index,
127
+ V,
128
+ BLOCK_SIZE=BLOCK_SIZE,
129
+ HAS_LABEL=has_label,
130
+ num_warps=num_warps,
131
+ reduction=reduction,
132
+ )
133
+
134
+ if reduction == _REDUCTION_MODE_BATCHMEAN.value:
135
+ return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
136
+ elif reduction == _REDUCTION_MODE_SUM.value:
137
+ return output_tensor.sum(dim=0), grads
138
+ elif reduction == _REDUCTION_MODE_MEAN.value:
139
+ return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
140
+ else:
141
+ return output_tensor, grads
142
+
143
+
144
+ def tvd_backward_triton(grad_output, grads):
145
+ # If the TVD loss is the last layer, grad_output is 1.0. Skip the mul then.
146
+ if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
147
+ return grads
148
+
149
+ return grads * grad_output
150
+
151
+
152
+ class LigerTVDLossFunction(torch.autograd.Function):
153
+ """
154
+ Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
155
+ """
156
+
157
+ @staticmethod
158
+ @ensure_contiguous
159
+ def forward(
160
+ ctx,
161
+ p: torch.Tensor,
162
+ q: torch.Tensor,
163
+ shift_labels: Optional[torch.Tensor] = None,
164
+ reduction: REDUCTION_LITERAL = "batchmean",
165
+ ignore_index: int = -100,
166
+ ) -> torch.Tensor:
167
+ """A forward pass for the Total Variation Distance Loss.
168
+
169
+ Args:
170
+ ctx: Torch autograd context
171
+ p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
172
+ q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
173
+ shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
174
+ reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
175
+ ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.
176
+
177
+ Returns:
178
+ torch.Tensor: The computed Total Variation Distance Loss.
179
+ """
180
+ has_label = False
181
+ if shift_labels is not None:
182
+ assert shift_labels.shape == (p.shape[0],), (
183
+ f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
184
+ )
185
+ shift_labels = shift_labels.contiguous()
186
+ has_label = True
187
+
188
+ loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
189
+ ctx.save_for_backward(grads)
190
+ return loss
191
+
192
+ @staticmethod
193
+ @ensure_contiguous
194
+ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
195
+ """A backward pass for the Total Variation Distance Loss.
196
+
197
+ Args:
198
+ ctx: Torch autograd context
199
+ grad_output (torch.Tensor): The gradient of the loss with respect to the output.
200
+
201
+ Returns:
202
+ tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs.
203
+ """
204
+ (grads,) = ctx.saved_tensors
205
+ grads = tvd_backward_triton(grad_output, grads)
206
+
207
+ return grads, None, None, None, None
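An illustrative comparison of LigerTVDLossFunction with the dense formula 0.5 * |p - q| under the default "batchmean" reduction; not part of the uploaded file, and it assumes a CUDA-capable device plus the flat-import layout.

import torch
from tvd import LigerTVDLossFunction  # flat import, matching this package's style

BT, V = 8, 4096
p = torch.randn(BT, V, device="cuda").softmax(dim=-1)
q = torch.randn(BT, V, device="cuda").softmax(dim=-1)

loss = LigerTVDLossFunction.apply(p, q)  # reduction defaults to "batchmean"
ref = 0.5 * (p - q).abs().sum() / BT     # sum over the vocabulary, mean over the batch
torch.testing.assert_close(loss, ref, rtol=1e-4, atol=1e-5)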
build/torch-universal/liger_kernels/utils.py ADDED
@@ -0,0 +1,135 @@
1
+ """
2
+ This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
3
+ See the original Unsloth repository at https://github.com/unslothai/unsloth.
4
+
5
+ The following line
6
+ https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23
7
+ is based on code from Unsloth, located at:
8
+ https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
9
+
10
+ Modifications made by Yanning Chen, 2024.
11
+ """
12
+
13
+ import functools
14
+ import importlib
15
+ import operator
16
+
17
+ from typing import Callable
18
+
19
+ import torch
20
+ import triton
21
+ import triton.language as tl
22
+
23
+ from packaging.version import Version
24
+
25
+ def infer_device():
26
+ """
27
+ Get current device name based on available devices
28
+ """
29
+ if torch.cuda.is_available(): # Works for both Nvidia and AMD
30
+ return "cuda"
31
+ elif torch.xpu.is_available():
32
+ return "xpu"
33
+ else:
34
+ return "cpu"
35
+
36
+ def is_hip() -> bool:
37
+ return torch.version.hip is not None
38
+
39
+
40
+ def ensure_contiguous(fn):
41
+ @functools.wraps(fn)
42
+ def wrapper(ctx, *args, **kwargs):
43
+ def maybe_to_contiguous(x):
44
+ return x.contiguous() if isinstance(x, torch.Tensor) else x
45
+
46
+ args = [maybe_to_contiguous(arg) for arg in args]
47
+ kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()}
48
+ return fn(ctx, *args, **kwargs)
49
+
50
+ return wrapper
51
+
52
+
53
+ def calculate_settings(n):
54
+ # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
55
+
56
+ MAX_FUSED_SIZE = 65536
57
+ BLOCK_SIZE = triton.next_power_of_2(n)
58
+ if BLOCK_SIZE > MAX_FUSED_SIZE:
59
+ raise RuntimeError(
60
+ f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}."
61
+ )
62
+
63
+ num_warps = 4
64
+ if BLOCK_SIZE >= 32768:
65
+ num_warps = 32 if not is_hip() else 16
66
+ elif BLOCK_SIZE >= 8192:
67
+ num_warps = 16
68
+ elif BLOCK_SIZE >= 2048:
69
+ num_warps = 8
70
+ return BLOCK_SIZE, num_warps
71
+
72
+
73
+ def compare_version(package: str, operator: Callable, target: str):
74
+ try:
75
+ pkg = importlib.import_module(package)
76
+ except ImportError:
77
+ return False
78
+ pkg_version = Version(pkg.__version__)
79
+ return operator(pkg_version, Version(target))
80
+
81
+
82
+ def get_amp_custom_fwd_bwd() -> Callable:
83
+ device = infer_device()
84
+ if compare_version("torch", operator.ge, "2.4.0"):
85
+ return (
86
+ functools.partial(torch.amp.custom_fwd, device_type=device),
87
+ functools.partial(torch.amp.custom_bwd, device_type=device),
88
+ )
89
+ return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
90
+
91
+
92
+ amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()
93
+
94
+
95
+ torch_to_triton_dtype = {
96
+ torch.float32: tl.float32,
97
+ torch.float16: tl.float16,
98
+ torch.bfloat16: tl.bfloat16,
99
+ }
100
+
101
+
102
+ @triton.jit
103
+ def element_mul_kernel(
104
+ X_ptr,
105
+ X_stride,
106
+ grad_output_ptr,
107
+ n_cols,
108
+ BLOCK_SIZE: tl.constexpr,
109
+ ):
110
+ """
111
+ This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
112
+ The multiplication is performed in-place on the tensor pointed by X_ptr.
113
+
114
+ Parameters:
115
+ X_ptr: Pointer to the input tensor.
116
+ X_stride (int): The stride of the input tensor.
117
+ grad_output_ptr: Pointer to the gradient output value.
118
+ n_cols (int): The number of columns in the input tensor.
119
+ BLOCK_SIZE (int): The block size for Triton operations.
120
+ """
121
+
122
+ # Get the program ID and convert it to int64 to avoid overflow
123
+ program_id = tl.program_id(0).to(tl.int64)
124
+
125
+ # Locate the start index
126
+ X_ptr += program_id * X_stride
127
+
128
+ # Load the gradient output value
129
+ grad_output = tl.load(grad_output_ptr)
130
+
131
+ # Perform the element-wise multiplication
132
+ for i in range(0, n_cols, BLOCK_SIZE):
133
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
134
+ X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
135
+ tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
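A small illustration, not part of the uploaded file, of what calculate_settings returns for typical hidden sizes: BLOCK_SIZE is the next power of two at or above n, and num_warps grows with the block size.

from utils import calculate_settings  # flat import, matching this package's style

for n in (1024, 4096, 11008):
    print(n, calculate_settings(n))
# Expected on non-HIP devices: 1024 (1024, 4), 4096 (4096, 8), 11008 (16384, 16)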
flake.lock ADDED
@@ -0,0 +1,117 @@
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1733328505,
6
+ "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-utils": {
19
+ "inputs": {
20
+ "systems": "systems"
21
+ },
22
+ "locked": {
23
+ "lastModified": 1731533236,
24
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
25
+ "owner": "numtide",
26
+ "repo": "flake-utils",
27
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
28
+ "type": "github"
29
+ },
30
+ "original": {
31
+ "owner": "numtide",
32
+ "repo": "flake-utils",
33
+ "type": "github"
34
+ }
35
+ },
36
+ "kernel-builder": {
37
+ "inputs": {
38
+ "flake-compat": "flake-compat",
39
+ "flake-utils": "flake-utils",
40
+ "nixpkgs": "nixpkgs",
41
+ "rocm-nix": "rocm-nix"
42
+ },
43
+ "locked": {
44
+ "lastModified": 1745579622,
45
+ "narHash": "sha256-g8BXijChxDCZNu17M4Jj0GPv/7faVnArbHBOMNMpHjM=",
46
+ "owner": "huggingface",
47
+ "repo": "kernel-builder",
48
+ "rev": "e2f6f338737c6f1c570f9b59e43182633c0879c1",
49
+ "type": "github"
50
+ },
51
+ "original": {
52
+ "owner": "huggingface",
53
+ "repo": "kernel-builder",
54
+ "type": "github"
55
+ }
56
+ },
57
+ "nixpkgs": {
58
+ "locked": {
59
+ "lastModified": 1743559129,
60
+ "narHash": "sha256-7gpAWsENV3tY2HmeHYQ2MoQxGpys+jQWnkS/BHAMXVk=",
61
+ "owner": "nixos",
62
+ "repo": "nixpkgs",
63
+ "rev": "adae22bea8bcc0aa2fd6e8732044660fb7755f5e",
64
+ "type": "github"
65
+ },
66
+ "original": {
67
+ "owner": "nixos",
68
+ "ref": "nixos-unstable-small",
69
+ "repo": "nixpkgs",
70
+ "type": "github"
71
+ }
72
+ },
73
+ "rocm-nix": {
74
+ "inputs": {
75
+ "nixpkgs": [
76
+ "kernel-builder",
77
+ "nixpkgs"
78
+ ]
79
+ },
80
+ "locked": {
81
+ "lastModified": 1745310663,
82
+ "narHash": "sha256-1U3PzCO/jt7HUlEgLOY3RpxadKwTo6GSvb2j4m0UFw0=",
83
+ "owner": "huggingface",
84
+ "repo": "rocm-nix",
85
+ "rev": "e08373a0efa1c297b0c57af070e0a311df47481f",
86
+ "type": "github"
87
+ },
88
+ "original": {
89
+ "owner": "huggingface",
90
+ "repo": "rocm-nix",
91
+ "type": "github"
92
+ }
93
+ },
94
+ "root": {
95
+ "inputs": {
96
+ "kernel-builder": "kernel-builder"
97
+ }
98
+ },
99
+ "systems": {
100
+ "locked": {
101
+ "lastModified": 1681028828,
102
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
103
+ "owner": "nix-systems",
104
+ "repo": "default",
105
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
106
+ "type": "github"
107
+ },
108
+ "original": {
109
+ "owner": "nix-systems",
110
+ "repo": "default",
111
+ "type": "github"
112
+ }
113
+ }
114
+ },
115
+ "root": "root",
116
+ "version": 7
117
+ }
flake.nix ADDED
@@ -0,0 +1,17 @@
1
+ {
2
+ description = "Flake for Liger Kernels";
3
+
4
+ inputs = {
5
+ kernel-builder.url = "github:huggingface/kernel-builder";
6
+ };
7
+
8
+ outputs =
9
+ {
10
+ self,
11
+ kernel-builder,
12
+ }:
13
+ kernel-builder.lib.genFlakeOutputs {
14
+ path = ./.;
15
+ rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
16
+ };
17
+ }
torch-ext/liger_kernels/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ from cross_entropy import LigerCrossEntropyFunction
2
+ from fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
3
+ from dyt import LigerDyTFunction
4
+ from geglu import LigerGELUMulFunction
5
+ from group_norm import LigerGroupNormFunction
6
+ from kl_div import LigerKLDivLossFunction
7
+ from layer_norm import LigerLayerNormFunction
8
+ from qwen2vl_mrope import LigerQwen2VLMRopeFunction
9
+ from rms_norm import LigerRMSNormFunction
10
+ from jsd import LigerJSDFunction
11
+ from rope import LigerRopeFunction
12
+ from swiglu import LigerSiLUMulFunction
13
+ from tvd import LigerTVDLossFunction
14
+
15
+ __all__ = [
16
+ "LigerCrossEntropyFunction",
17
+ "LigerFusedLinearCrossEntropyFunction",
18
+ "LigerDyTFunction",
19
+ "LigerGELUMulFunction",
20
+ "LigerGroupNormFunction",
21
+ "LigerKLDivLossFunction",
22
+ "LigerLayerNormFunction",
23
+ "LigerQwen2VLMRopeFunction",
24
+ "LigerRMSNormFunction",
25
+ "LigerJSDFunction",
26
+ "LigerRopeFunction",
27
+ "LigerSiLUMulFunction",
28
+ "LigerTVDLossFunction",
29
+ ]
torch-ext/liger_kernels/cross_entropy.py ADDED
@@ -0,0 +1,460 @@
1
+ import operator
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import triton
7
+ import triton.language as tl
8
+
9
+ from utils import compare_version
10
+ from utils import element_mul_kernel
11
+ from utils import is_hip
12
+ from utils import infer_device
13
+
14
+ if compare_version("triton", operator.ge, "3.0.0"):
15
+ try:
16
+ # typical import path with dispatch available
17
+ from triton.language.extra.libdevice import tanh
18
+ except ModuleNotFoundError:
19
+ # for working with NGC containers
20
+ from triton.language.extra.cuda.libdevice import tanh
21
+ else:
22
+ from triton.language.math import tanh
23
+
24
+
25
+ @triton.jit
26
+ def liger_cross_entropy_kernel(
27
+ X_ptr,
28
+ X_stride,
29
+ Y_ptr,
30
+ Y_stride,
31
+ weight_ptr,
32
+ loss_ptr,
33
+ z_loss_ptr,
34
+ loss_stride,
35
+ n_cols,
36
+ n_non_ignore,
37
+ sum_non_ignore_weight,
38
+ weight_sum,
39
+ ignore_index,
40
+ lse_square_scale: tl.constexpr,
41
+ label_smoothing: tl.constexpr,
42
+ reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
43
+ softcap,
44
+ RETURN_Z_LOSS: tl.constexpr,
45
+ BLOCK_SIZE: tl.constexpr,
46
+ HAS_WEIGHT: tl.constexpr,
47
+ HAS_SOFTCAPPING: tl.constexpr,
48
+ ):
49
+ """
50
+ This kernel computes both cross entropy loss and the gradient of the input.
51
+ We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.
52
+
53
+ Parameters:
54
+ X_ptr: Pointer to input tensor.
55
+ X_stride (int): The stride of the input tensor.
56
+ Y_ptr: Pointer to target tensor.
57
+ Y_stride (int): The stride of the target tensor.
58
+ weight_ptr: Pointer to weight tensor.
59
+ loss_ptr: Pointer to tensor to store the loss.
60
+ z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
61
+ loss_stride (int): The stride of the loss tensor.
62
+ n_cols (int): The number of columns in the input tensor.
63
+ n_non_ignore (float): The number of non-ignored elements in the batch.
64
+ sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
65
+ weight_sum (float): The sum of weight tensor.
66
+ ignore_index (int): The index to ignore in the target.
67
+ label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
68
+ lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
69
+ reduction (str): The string for the reduction to apply
70
+ softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
71
+ RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
72
+ BLOCK_SIZE (int): The block size for Triton operations.
73
+ HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
74
+ HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
75
+ """
76
+
77
+ # https://github.com/triton-lang/triton/issues/1058
78
+ # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64
79
+ program_id = tl.program_id(0).to(tl.int64)
80
+
81
+ # 1. Load Y_ptr first because if the target is ignore_index, we can return right away
82
+ Y_ptr += program_id * Y_stride
83
+ y = tl.load(Y_ptr)
84
+
85
+ # 2. locate the start index
86
+ X_ptr += program_id * X_stride
87
+
88
+ if y == ignore_index:
89
+ # set all X_ptr as 0
90
+ for i in range(0, n_cols, BLOCK_SIZE):
91
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
92
+ tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
93
+ return
94
+
95
+ loss_ptr += program_id * loss_stride
96
+ if RETURN_Z_LOSS:
97
+ z_loss_ptr += program_id * loss_stride
98
+
99
+ if HAS_WEIGHT:
100
+ weight_y = tl.load(weight_ptr + y).cast(tl.float32)
101
+
102
+ # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
103
+ # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
104
+
105
+ # 3. [Online softmax] first pass: find max + sum
106
+ m = float("-inf") # m is the max value. use the notation from the paper
107
+ d = 0.0 # d is the sum. use the notation from the paper
108
+ ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
109
+ if HAS_SOFTCAPPING:
110
+ ori_X_y = softcap * tanh(ori_X_y / softcap)
111
+
112
+ # Label smoothing is a general case of normal cross entropy
113
+ # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
114
+ scaled_x_sum = 0.0
115
+ eps = label_smoothing / n_cols
116
+
117
+ for i in range(0, n_cols, BLOCK_SIZE):
118
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
119
+ X_block = tl.load(
120
+ X_ptr + X_offsets,
121
+ mask=X_offsets < n_cols,
122
+ other=float("-inf"),
123
+ # Ensure float32 precision for softmax calculation
124
+ ).cast(tl.float32)
125
+ if HAS_SOFTCAPPING:
126
+ X_block = softcap * tanh(X_block / softcap)
127
+ block_max = tl.max(X_block)
128
+ if label_smoothing > 0:
129
+ # scale X beforehand to avoid overflow
130
+ if HAS_WEIGHT:
131
+ weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
132
+ scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0))
133
+ else:
134
+ scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
135
+ m_new = tl.maximum(m, block_max)
136
+ d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
137
+ m = m_new
138
+
139
+ # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X)))))
140
+ # = log (e^(max(X)) * sum(e ^ (X_i - max(X))))
141
+ # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
142
+ lse = m + tl.log(d)
143
+
144
+ # 4. [Online Softmax] Second pass: compute gradients
145
+ # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
146
+ # dx_y = (softmax(x_y) - 1) / N
147
+ # dx_i = softmax(x_i) / N, i != y
148
+ # For label smoothing:
149
+ # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
150
+ # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
151
+ # = dx_i - (1 - label_smoothing) / N
152
+ # With Z loss:
153
+ # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y
154
+ # dx_y = dx_i - (1 - label_smoothing) / N
155
+ # For 'sum' reduction, no normalization is applied:
156
+ # dx_y = softmax(x_y) - 1
157
+ # dx_i = softmax(x_i), for i ≠ y
158
+
159
+ for i in range(0, n_cols, BLOCK_SIZE):
160
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
161
+ X_block = tl.load(
162
+ X_ptr + X_offsets,
163
+ mask=X_offsets < n_cols,
164
+ other=float("-inf"),
165
+ # Ensure float32 precision for softmax calculation
166
+ ).cast(tl.float32)
167
+ if HAS_SOFTCAPPING:
168
+ intermediate = tanh(X_block / softcap)
169
+ X_block = softcap * intermediate
170
+
171
+ if not HAS_WEIGHT:
172
+ # softmax(x_i)
173
+ X_block = tl.exp(X_block - m) / d
174
+ # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
175
+ X_block += 2 * lse_square_scale * lse * X_block
176
+ # smoothing term
177
+ X_block += -eps
178
+ # special handle dx_y
179
+ X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
180
+ # reduction scale
181
+ if reduction == "mean":
182
+ X_block = X_block / n_non_ignore
183
+ else:
184
+ weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
185
+ softmax_X = tl.exp(X_block - m) / d
186
+ # derivative of original_loss
187
+ dloss_ori = (1 - label_smoothing) * softmax_X
188
+ # specially handle dx_y
189
+ dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
190
+ dloss_ori = dloss_ori * weight_y
191
+ # derivative of smooth_loss
192
+ dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
193
+ # derivative of z-loss
194
+ dz_loss = 2 * lse_square_scale * lse * softmax_X
195
+ # reduction scale
196
+ if reduction == "mean":
197
+ dloss_ori = dloss_ori / sum_non_ignore_weight
198
+ dloss_smooth = dloss_smooth / sum_non_ignore_weight
199
+ # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
200
+ dz_loss = dz_loss / n_non_ignore
201
+ # derivative of total_loss
202
+ X_block = dloss_ori + dloss_smooth + dz_loss
203
+
204
+ # chain rule softcapping
205
+ # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
206
+ if HAS_SOFTCAPPING:
207
+ X_block = X_block * (1 - intermediate * intermediate)
208
+
209
+ tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
210
+
211
+ # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
212
+ # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
213
+ tl.debug_barrier()
214
+
215
+ # 5. Calculate the loss
216
+
217
+ # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
218
+ # = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
219
+ # = X_y - m - log d = X_y - lse
220
+ # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
221
+ # So we can safely calculate log (softmax(X_y)) without overflow
222
+ loss = lse - ori_X_y
223
+ if HAS_WEIGHT:
224
+ loss = weight_y * loss
225
+
226
+ # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
227
+ # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
228
+ # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
229
+ # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
230
+ # = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
231
+ # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
232
+ # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
233
+ # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
234
+ if label_smoothing > 0:
235
+ if HAS_WEIGHT:
236
+ smooth_loss = scaled_x_sum + eps * lse * weight_sum
237
+ else:
238
+ smooth_loss = scaled_x_sum + label_smoothing * lse
239
+ loss = loss * (1 - label_smoothing) + smooth_loss
240
+
241
+ # An auxiliary loss, z_loss
242
+ # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
243
+ z_loss = lse_square_scale * lse * lse
244
+ # Normalize the loss by the number of non-ignored elements if reduction is "mean"
245
+ if reduction == "mean":
246
+ if HAS_WEIGHT:
247
+ loss = loss / sum_non_ignore_weight
248
+ else:
249
+ loss = loss / n_non_ignore
250
+ # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
251
+ z_loss = z_loss / n_non_ignore
252
+ loss += z_loss
253
+
254
+ tl.store(loss_ptr, loss)
255
+ if RETURN_Z_LOSS:
256
+ tl.store(z_loss_ptr, z_loss)
257
+
258
+
259
+ # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
260
+ # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
261
+ # The optimal maximum block size depends on your hardware, your kernel, and your dtype
262
+ MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2 # the best size we found by manually tuning
263
+
264
+
265
+ def cross_entropy_forward(
266
+ _input,
267
+ target,
268
+ weight,
269
+ ignore_index,
270
+ lse_square_scale,
271
+ label_smoothing,
272
+ reduction,
273
+ softcap,
274
+ return_z_loss,
275
+ ):
276
+ assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
277
+
278
+ BT, V = _input.shape
279
+ n_rows = BT
280
+
281
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
282
+
283
+ # unreduced loss
284
+ loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
285
+ z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
286
+
287
+ target_mask = target != ignore_index
288
+ n_non_ignore = target_mask.sum().item()
289
+ assert (target * target_mask).max() < _input.shape[-1], (
290
+ f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}"
291
+ )
292
+ assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0"
293
+ sum_non_ignore_weight = n_non_ignore
294
+ weight_sum = 0.0
295
+ if weight is not None:
296
+ assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
297
+ assert torch.is_floating_point(weight), (
298
+ f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
299
+ )
300
+ sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
301
+ weight_sum = weight.sum().item()
302
+ # ensure weight is contiguous
303
+ if weight.stride(-1) != 1:
304
+ weight = weight.contiguous()
305
+
306
+ # ensure _input and target are contiguous in the last dimension
307
+ if _input.stride(-1) != 1:
308
+ _input = _input.contiguous()
309
+ if target.stride(-1) != 1:
310
+ target = target.contiguous()
311
+
312
+ # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
313
+ liger_cross_entropy_kernel[(n_rows,)](
314
+ X_ptr=_input,
315
+ X_stride=_input.stride(-2),
316
+ Y_ptr=target,
317
+ Y_stride=target.stride(-1), # always 1
318
+ weight_ptr=weight, # dummy if None
319
+ loss_ptr=loss_1d,
320
+ z_loss_ptr=z_loss_1d,
321
+ loss_stride=loss_1d.stride(-1), # always 1
322
+ n_cols=V,
323
+ n_non_ignore=n_non_ignore,
324
+ sum_non_ignore_weight=sum_non_ignore_weight,
325
+ ignore_index=ignore_index,
326
+ weight_sum=weight_sum,
327
+ lse_square_scale=lse_square_scale,
328
+ label_smoothing=label_smoothing,
329
+ reduction=reduction,
330
+ softcap=softcap,
331
+ RETURN_Z_LOSS=return_z_loss,
332
+ BLOCK_SIZE=BLOCK_SIZE,
333
+ HAS_WEIGHT=True if weight is not None else False,
334
+ HAS_SOFTCAPPING=True if softcap is not None else False,
335
+ # TODO: 32 seems to give the best performance
336
+ # Performance is quite sensitive to num_warps
337
+ num_warps=32 if not is_hip() else 16,
338
+ )
339
+
340
+ if reduction == "none":
341
+ loss = loss_1d
342
+ z_loss = z_loss_1d if return_z_loss else None
343
+ else:
344
+ loss = torch.sum(loss_1d)
345
+ z_loss = torch.sum(z_loss_1d) if return_z_loss else None
346
+
347
+ return loss, z_loss, _input
348
+
349
+
350
+ def cross_entropy_backward(_input, grad_output):
351
+ # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
352
+ if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
353
+ pass
354
+ # If reduction is 'none'
355
+ elif grad_output.ndim > 0:
356
+ _input = _input * grad_output.unsqueeze(dim=1)
357
+ # If reduction is ['mean', 'sum'], grad_output is just a scalar
358
+ # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
359
+ # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
360
+ else:
361
+ BT, V = _input.shape
362
+ n_rows = BT
363
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
364
+
365
+ element_mul_kernel[(n_rows,)](
366
+ _input,
367
+ _input.stride(-2),
368
+ grad_output,
369
+ V,
370
+ BLOCK_SIZE=BLOCK_SIZE,
371
+ num_warps=32 if not is_hip() else 16,
372
+ )
373
+
374
+ return _input
375
+
376
+
377
+ class LigerCrossEntropyFunction(torch.autograd.Function):
378
+ """
379
+ This class implements a custom autograd function for the Liger Cross Entropy loss.
380
+ It overrides the forward and backward methods of the torch.autograd.Function class.
381
+ """
382
+
383
+ @staticmethod
384
+ def forward(
385
+ ctx,
386
+ _input: torch.Tensor,
387
+ target: torch.Tensor,
388
+ weight: Optional[torch.FloatTensor],
389
+ ignore_index: int = -100,
390
+ lse_square_scale: float = 0.0,
391
+ label_smoothing: float = 0.0,
392
+ reduction: str = "mean",
393
+ softcap: Optional[float] = None,
394
+ return_z_loss: bool = False,
395
+ ):
396
+ """
397
+ The forward pass of the Liger Cross Entropy loss.
398
+
399
+ Parameters:
400
+ ctx : The context object.
401
+ _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
402
+ target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
403
+ weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
404
+ ignore_index (int): The index to ignore in the target.
405
+ lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
406
+ label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
407
+ reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
408
+ softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
409
+ return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
410
+
411
+ Returns:
412
+ tuple: A tuple of the computed loss and z loss. The elements are tensors or None.
413
+ """
414
+ loss, z_loss, _input = cross_entropy_forward(
415
+ _input,
416
+ target,
417
+ weight,
418
+ ignore_index,
419
+ lse_square_scale,
420
+ label_smoothing,
421
+ reduction,
422
+ softcap,
423
+ return_z_loss,
424
+ )
425
+ # TODO: investigation
426
+ # If we don't detach the _input tensor, the memory will double
427
+ # Not sure why but seems that there will be a time both grad and value exist but in different location
428
+ ctx.save_for_backward(_input.detach())
429
+ ctx.return_z_loss = return_z_loss
430
+
431
+ return loss, z_loss
432
+
433
+ @staticmethod
434
+ def backward(ctx, grad_output, grad_output2):
435
+ """
436
+ The backward pass of the Liger Cross Entropy loss.
437
+
438
+ Parameters:
439
+ ctx : The context object with saved tensors.
440
+ grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
441
+ grad_output2 (tensor): Not used; z_loss is only for logging.
442
+ Returns:
443
+ tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
444
+ """
445
+ if ctx.return_z_loss:
446
+ del grad_output2 # z_loss is only for logging
447
+
448
+ (_input,) = ctx.saved_tensors
449
+ _input = cross_entropy_backward(_input, grad_output)
450
+ return (
451
+ _input,
452
+ None,
453
+ None,
454
+ None,
455
+ None,
456
+ None,
457
+ None,
458
+ None,
459
+ None,
460
+ )
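An illustrative smoke test of LigerCrossEntropyFunction against torch.nn.functional.cross_entropy with the default mean reduction; it is not part of the uploaded file and assumes a CUDA-capable device plus the flat-import layout. The reference is computed first because the kernel stores the gradient back into the logits tensor during the forward pass.

import torch
import torch.nn.functional as F
from cross_entropy import LigerCrossEntropyFunction  # flat import, matching this package's style

BT, V = 32, 32000
logits = torch.randn(BT, V, device="cuda")
target = torch.randint(0, V, (BT,), device="cuda")

ref = F.cross_entropy(logits, target)
# weight=None; the remaining arguments keep their defaults (ignore_index=-100, reduction="mean")
loss, _ = LigerCrossEntropyFunction.apply(logits, target, None)
torch.testing.assert_close(loss, ref, rtol=1e-4, atol=1e-4)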
torch-ext/liger_kernels/dyt.py ADDED
@@ -0,0 +1,225 @@
1
+ import operator
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from utils import calculate_settings
8
+ from utils import compare_version
9
+ from utils import ensure_contiguous
10
+ from utils import infer_device
11
+
12
+ if compare_version("triton", operator.ge, "3.0.0"):
13
+ try:
14
+ # typical import path with dispatch available
15
+ from triton.language.extra.libdevice import tanh
16
+ except ModuleNotFoundError:
17
+ # for working with NGC containers
18
+ from triton.language.extra.cuda.libdevice import tanh
19
+ else:
20
+ from triton.language.math import tanh
21
+
22
+
23
+ @triton.jit
24
+ def _dyt_fwd_kernel(
25
+ x_ptr,
26
+ x_row_stride,
27
+ alpha_ptr,
28
+ gamma_ptr,
29
+ beta_ptr,
30
+ y_ptr,
31
+ y_row_stride,
32
+ n_cols,
33
+ BLOCK_SIZE: tl.constexpr,
34
+ ):
35
+ """
36
+ Reference:
37
+ https://arxiv.org/abs/2503.10622
38
+
39
+ Shapes:
40
+ - x: (BT, C)
41
+ - alpha: (1)
42
+ - gamma: (C)
43
+ - beta: (C)
44
+ """
45
+ row_idx = tl.program_id(0)
46
+ offsets = tl.arange(0, BLOCK_SIZE)
47
+ mask = offsets < n_cols
48
+
49
+ x_ptr += row_idx * x_row_stride
50
+ y_ptr += row_idx * y_row_stride
51
+
52
+ alpha = tl.load(alpha_ptr)
53
+ gamma = tl.load(gamma_ptr + offsets, mask=mask)
54
+ beta = tl.load(beta_ptr + offsets, mask=mask)
55
+ x = tl.load(x_ptr + offsets, mask=mask)
56
+ y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
57
+ tl.store(y_ptr + offsets, y, mask=mask)
58
+
59
+
60
+ @triton.jit
61
+ def _dyt_bwd_kernel(
62
+ x_ptr,
63
+ x_row_stride,
64
+ dy_ptr,
65
+ dy_row_stride,
66
+ dx_ptr,
67
+ dx_row_stride,
68
+ alpha_ptr,
69
+ dalpha_ptr,
70
+ gamma_ptr,
71
+ dgamma_ptr,
72
+ dgamma_row_stride,
73
+ n_cols,
74
+ n_rows,
75
+ ROWS_PER_PROGRAM: tl.constexpr,
76
+ BLOCK_SIZE: tl.constexpr,
77
+ ):
78
+ """
79
+ Reference:
80
+ https://arxiv.org/abs/2503.10622
81
+
82
+ Shapes:
83
+ - x: (BT, C)
84
+ - alpha: (1)
85
+ - gamma: (C)
86
+ - dx: (BT, C)
87
+ - dy: (BT, C)
88
+ - dgamma: (sm_count, C)
89
+ - dalpha: (sm_count,)
90
+ """
91
+ # d(gamma * tanh(alpha * x) + beta) / dx
92
+ # = gamma * (1 - tanh^2(alpha * x)) * alpha
93
+ # d(gamma * tanh(alpha * x) + beta) / dalpha
94
+ # = gamma * (1 - tanh^2(alpha * x)) * x
95
+ # d(gamma * tanh(alpha * x) + beta) / dgamma
96
+ # = tanh(alpha * x)
97
+ # d(gamma * tanh(alpha * x)) / dbeta = 1
98
+ pid = tl.program_id(0)
99
+
100
+ row_start = pid * ROWS_PER_PROGRAM
101
+ row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
102
+ offsets = tl.arange(0, BLOCK_SIZE)
103
+ mask = offsets < n_cols
104
+
105
+ dalpha = 0.0
106
+ dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
107
+
108
+ x_ptr += row_start * x_row_stride
109
+ dx_ptr += row_start * dx_row_stride
110
+ dy_ptr += row_start * dy_row_stride
111
+ alpha = tl.load(alpha_ptr)
112
+ gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
113
+
114
+ for _ in tl.range(row_start, row_end):
115
+ dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
116
+ x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
117
+ tanh_ax = tanh((alpha * x).cast(tl.float32))
118
+ sech2_ax = 1 - tanh_ax * tanh_ax
119
+
120
+ dx = dy * gamma * sech2_ax * alpha
121
+ dalpha += tl.sum(dy * gamma * sech2_ax * x)
122
+ dgamma += dy * tanh_ax
123
+ tl.store(dx_ptr + offsets, dx, mask=mask)
124
+
125
+ dy_ptr += dy_row_stride
126
+ x_ptr += x_row_stride
127
+ dx_ptr += dx_row_stride
128
+
129
+ tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
130
+ tl.store(dalpha_ptr + pid, dalpha)
131
+
132
+ pass
133
+
134
+
135
+ def liger_dyt_fwd(x, alpha, gamma, beta):
136
+ shape = x.shape
137
+ dim = shape[-1]
138
+ x = x.view(-1, dim)
139
+ n_rows, n_cols = x.shape
140
+ y = torch.empty_like(x)
141
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
142
+ _dyt_fwd_kernel[(n_rows,)](
143
+ x_ptr=x,
144
+ alpha_ptr=alpha,
145
+ gamma_ptr=gamma,
146
+ beta_ptr=beta,
147
+ y_ptr=y,
148
+ x_row_stride=x.stride(0),
149
+ y_row_stride=y.stride(0),
150
+ n_cols=n_cols,
151
+ BLOCK_SIZE=BLOCK_SIZE,
152
+ num_warps=num_warps,
153
+ )
154
+ return y.view(*shape)
155
+
156
+
157
+ def liger_dyt_bwd(dy, x, alpha, gamma):
158
+ shape = dy.shape
159
+ dtype = x.dtype
160
+ dim = shape[-1]
161
+ dy = dy.view(-1, dim)
162
+ x = x.view(-1, dim)
163
+ n_rows, n_cols = dy.shape
164
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
165
+ sm_count = 1
166
+ device = infer_device()
167
+ if device == "cuda":
168
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
169
+ elif device == "xpu":
170
+ sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
171
+ if n_cols > BLOCK_SIZE:
172
+ raise RuntimeError(
173
+ f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
174
+ )
175
+
176
+ dx = torch.empty_like(x, dtype=torch.float32)
177
+ _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
178
+ _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)
179
+
180
+ grid = (sm_count,)
181
+ rows_per_program = triton.cdiv(n_rows, sm_count)
182
+ _dyt_bwd_kernel[grid](
183
+ x_ptr=x,
184
+ x_row_stride=x.stride(0),
185
+ dy_ptr=dy,
186
+ dy_row_stride=dy.stride(0),
187
+ dx_ptr=dx,
188
+ dx_row_stride=dx.stride(0),
189
+ alpha_ptr=alpha,
190
+ dalpha_ptr=_dalpha,
191
+ gamma_ptr=gamma,
192
+ dgamma_ptr=_dgamma,
193
+ dgamma_row_stride=_dgamma.stride(0),
194
+ n_cols=n_cols,
195
+ n_rows=n_rows,
196
+ ROWS_PER_PROGRAM=rows_per_program,
197
+ BLOCK_SIZE=BLOCK_SIZE,
198
+ num_warps=num_warps,
199
+ )
200
+ dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
201
+ dgamma = _dgamma.sum(dim=0).to(dtype)
202
+ dbeta = dy.sum(dim=0).to(dtype)
203
+ return dx.view(*shape), dalpha, dgamma, dbeta
204
+
205
+
206
+ class LigerDyTFunction(torch.autograd.Function):
207
+ @staticmethod
208
+ @ensure_contiguous
209
+ def forward(ctx, x, alpha, gamma, beta):
210
+ y = liger_dyt_fwd(x, alpha, gamma, beta)
211
+ ctx.save_for_backward(x, alpha, gamma)
212
+ return y
213
+
214
+ @staticmethod
215
+ @ensure_contiguous
216
+ def backward(ctx, grad_output):
217
+ x, alpha, gamma = ctx.saved_tensors
218
+ dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
219
+ grad_output,
220
+ x,
221
+ alpha,
222
+ gamma,
223
+ )
224
+
225
+ return (dx, dalpha, dgamma, dbeta)
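An illustrative check of LigerDyTFunction against the eager DyT definition gamma * tanh(alpha * x) + beta from the paper referenced in the kernel docstrings; not part of the uploaded file, and it assumes a CUDA-capable device plus the flat-import layout.

import torch
from dyt import LigerDyTFunction  # flat import, matching this package's style

B, T, C = 2, 128, 1024
x = torch.randn(B, T, C, device="cuda", requires_grad=True)
alpha = torch.full((1,), 0.5, device="cuda", requires_grad=True)
gamma = torch.ones(C, device="cuda", requires_grad=True)
beta = torch.zeros(C, device="cuda", requires_grad=True)

y = LigerDyTFunction.apply(x, alpha, gamma, beta)
ref = gamma * torch.tanh(alpha * x) + beta
torch.testing.assert_close(y, ref, rtol=1e-4, atol=1e-4)
y.sum().backward()  # exercises liger_dyt_bwd and fills x.grad, alpha.grad, gamma.grad, beta.grad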
torch-ext/liger_kernels/fused_linear_cross_entropy.py ADDED
@@ -0,0 +1,283 @@
1
+ import torch
2
+ import triton
3
+
4
+ from cross_entropy import liger_cross_entropy_kernel
5
+ from utils import amp_custom_bwd
6
+ from utils import amp_custom_fwd
7
+ from utils import element_mul_kernel
8
+ from utils import is_hip
9
+
10
+ # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
11
+ # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
12
+ # The optimal maximum block size depends on your hardware, your kernel, and your dtype
13
+ MAX_FUSED_SIZE = 65536 // 2
14
+
15
+
16
+ def fused_linear_cross_entropy_forward(
17
+ _input,
18
+ weight,
19
+ target,
20
+ ce_weight=None,
21
+ bias=None,
22
+ ignore_index=-100,
23
+ lse_square_scale=0.0,
24
+ label_smoothing=0.0,
25
+ reduction="mean",
26
+ softcap=None,
27
+ return_z_loss=False,
28
+ ):
29
+ assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
30
+ device = _input.device
31
+
32
+ # inputs have shape: BT x H
33
+ # materialized activations will have shape: BT x V
34
+ # the increase in memory = BT x V
35
+ # reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
36
+ # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
37
+ # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
38
+ # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
39
+ BT, H = _input.shape
40
+ V = weight.shape[0]
41
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
42
+
43
+ inc_factor = triton.cdiv(V, H) # (V + H - 1) // H
44
+ chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
45
+ num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size
46
+
47
+ grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
48
+ grad_input = torch.zeros_like(_input, device=device)
49
+ grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
50
+ # we use fp32 for loss accumulator
51
+ loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
52
+ z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
53
+
54
+ # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
55
+ target_mask = target != ignore_index
56
+ total_n_non_ignore = target_mask.sum().item()
57
+ total_sum_non_ignore_ce_weight = total_n_non_ignore
58
+ ce_weight_sum = 0.0
59
+ if ce_weight is not None:
60
+ assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
61
+ assert torch.is_floating_point(ce_weight), (
62
+ f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
63
+ )
64
+ total_sum_non_ignore_ce_weight = (
65
+ torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
66
+ )
67
+ ce_weight_sum = ce_weight.sum().item()
68
+ if ce_weight.stride(-1) != 1:
69
+ ce_weight = ce_weight.contiguous()
70
+
71
+ for chunk_id in range(num_chunks):
72
+ start_idx = chunk_id * chunk_size
73
+ end_idx = min((chunk_id + 1) * chunk_size, BT)
74
+ _input_chunk = _input[start_idx:end_idx] # chunk_size x H
75
+
76
+ # when doing matmul, use the original precision
77
+ logits_chunk = _input_chunk @ weight.t() # chunk_size x V
78
+ if bias is not None:
79
+ logits_chunk = logits_chunk + bias
80
+
81
+ target_chunk = target[start_idx:end_idx] # chunk_size,
82
+
83
+ n_rows = logits_chunk.shape[0]
84
+
85
+ # unreduced loss
86
+ loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
87
+ z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
88
+
89
+ # ensure _input and target are contiguous
90
+ logits_chunk = logits_chunk.contiguous()
91
+ target_chunk = target_chunk.contiguous()
92
+
93
+ # Here we calculate the gradient of logits_chunk in place so we can save memory.
94
+ liger_cross_entropy_kernel[(n_rows,)](
95
+ X_ptr=logits_chunk,
96
+ X_stride=logits_chunk.stride(-2),
97
+ Y_ptr=target_chunk,
98
+ Y_stride=target_chunk.stride(-1), # always 1
99
+ weight_ptr=ce_weight,
100
+ loss_ptr=loss_1d_slice,
101
+ z_loss_ptr=z_loss_1d_slice,
102
+ loss_stride=loss_1d_slice.stride(-1), # always 1
103
+ n_cols=V,
104
+ n_non_ignore=total_n_non_ignore,
105
+ sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
106
+ weight_sum=ce_weight_sum,
107
+ ignore_index=ignore_index,
108
+ lse_square_scale=lse_square_scale,
109
+ label_smoothing=label_smoothing,
110
+ reduction=reduction,
111
+ softcap=softcap,
112
+ RETURN_Z_LOSS=return_z_loss,
113
+ HAS_WEIGHT=True if ce_weight is not None else False,
114
+ HAS_SOFTCAPPING=True if softcap is not None else False,
115
+ BLOCK_SIZE=BLOCK_SIZE,
116
+ num_warps=32 if not is_hip() else 16,
117
+ )
118
+
119
+ loss_1d[start_idx:end_idx] = loss_1d_slice
120
+ if return_z_loss:
121
+ z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
122
+ grad_logits_chunk = logits_chunk # chunk_size x V
123
+
124
+ grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
125
+
126
+ if grad_weight is not None:
127
+ torch.addmm(
128
+ input=grad_weight,
129
+ mat1=logits_chunk.t().to(
130
+ _input_chunk.dtype
131
+ ), # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
132
+ mat2=_input_chunk,
133
+ out=grad_weight,
134
+ alpha=1.0,
135
+ beta=1.0,
136
+ )
137
+
138
+ if bias is not None:
139
+ torch.add(
140
+ input=grad_bias,
141
+ other=logits_chunk.sum(dim=0),
142
+ out=grad_bias,
143
+ alpha=1.0,
144
+ )
145
+
146
+ # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now.
147
+ # if reduction == "none":
148
+ # loss = loss_1d
149
+ # z_loss = z_loss_1d if return_z_loss else None
150
+
151
+ # reduction is 'mean' or 'sum' here; reduction='none' is not supported, so the per-token losses are always summed
152
+ loss = torch.sum(loss_1d)
153
+ z_loss = torch.sum(z_loss_1d) if return_z_loss else None
154
+ return loss, z_loss, grad_input, grad_weight, grad_bias
155
+
156
+
157
+ def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
158
+ # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
159
+ if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
160
+ # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
161
+ # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
162
+ BT, H = grad_input.shape
163
+ n_rows = BT
164
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
165
+
166
+ element_mul_kernel[(n_rows,)](
167
+ grad_input,
168
+ grad_input.stride(-2),
169
+ grad_output,
170
+ H,
171
+ BLOCK_SIZE=BLOCK_SIZE,
172
+ num_warps=32 if not is_hip() else 16,
173
+ )
174
+
175
+ # handle grad_weight
176
+ if grad_weight is not None:
177
+ V, H = grad_weight.shape
178
+ n_rows = V
179
+
180
+ element_mul_kernel[(n_rows,)](
181
+ grad_weight,
182
+ grad_weight.stride(-2),
183
+ grad_output,
184
+ H,
185
+ BLOCK_SIZE=BLOCK_SIZE,
186
+ num_warps=32 if not is_hip() else 16,
187
+ )
188
+
189
+ if grad_bias is not None:
190
+ V = grad_bias.shape[0]
191
+ n_rows = V
192
+
193
+ element_mul_kernel[(n_rows,)](
194
+ grad_bias,
195
+ grad_bias.stride(-1),
196
+ grad_output,
197
+ 1,
198
+ BLOCK_SIZE=BLOCK_SIZE,
199
+ num_warps=32 if not is_hip() else 16,
200
+ )
201
+ return grad_input, grad_weight, grad_bias
202
+
203
+
204
+ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
205
+ @staticmethod
206
+ @amp_custom_fwd
207
+ def forward(
208
+ ctx,
209
+ _input,
210
+ weight,
211
+ target,
212
+ bias=None,
213
+ ce_weight=None,
214
+ ignore_index=-100,
215
+ lse_square_scale=0.0,
216
+ label_smoothing=0.0,
217
+ reduction="mean",
218
+ softcap=None,
219
+ return_z_loss: bool = False,
220
+ ):
221
+ """
222
+ Fusing the last linear layer with cross-entropy loss
223
+ Reference: https://github.com/mgmalek/efficient_cross_entropy
224
+
225
+ Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
226
+ the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
227
+ compute the gradient in the forward pass. By doing so, we don't have to store the _input and target
228
+ for the backward pass.
229
+
230
+ _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension.
231
+ target: (B*T) where each value is in [0, V-1]
232
+ weight: (V, H) where V is the number of classes
233
+ bias: (V) where V is the number of classes
234
+ ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
235
+ ignore_index: the index to ignore in the target
236
+ label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
237
+ reduction: reduction to apply
238
+ """
239
+
240
+ loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
241
+ _input=_input,
242
+ weight=weight,
243
+ target=target,
244
+ bias=bias,
245
+ ce_weight=ce_weight,
246
+ ignore_index=ignore_index,
247
+ lse_square_scale=lse_square_scale,
248
+ label_smoothing=label_smoothing,
249
+ reduction=reduction,
250
+ softcap=softcap,
251
+ return_z_loss=return_z_loss,
252
+ )
253
+ # downcast to dtype and store for backward
254
+ ctx.save_for_backward(
255
+ grad_input.detach(),
256
+ grad_weight.detach() if grad_weight is not None else None,
257
+ grad_bias.detach() if bias is not None else None,
258
+ )
259
+ ctx.return_z_loss = return_z_loss
260
+ return loss, z_loss
261
+
262
+ @staticmethod
263
+ @amp_custom_bwd
264
+ def backward(ctx, grad_output, grad_output2):
265
+ if ctx.return_z_loss:
266
+ del grad_output2 # z_loss is only for logging
267
+ (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
268
+ grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
269
+ grad_output, grad_input, grad_weight, grad_bias
270
+ )
271
+ return (
272
+ grad_input,
273
+ grad_weight,
274
+ None,
275
+ grad_bias,
276
+ None,
277
+ None,
278
+ None,
279
+ None,
280
+ None,
281
+ None,
282
+ None,
283
+ )
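
For reference, a minimal usage sketch of `LigerFusedLinearCrossEntropyFunction` defined above (not part of the upload). It assumes a CUDA device with Triton available; the import path and the tensor shapes are illustrative and depend on how the package is installed.

```python
import torch

# Assumed import path; adjust to however liger_kernels is packaged in your environment.
from liger_kernels import LigerFusedLinearCrossEntropyFunction

B, T, H, V = 2, 16, 512, 32000
_input = torch.randn(B * T, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
weight = torch.randn(V, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, V, (B * T,), device="cuda")

# apply() returns (loss, z_loss); z_loss is None unless return_z_loss=True.
loss, _ = LigerFusedLinearCrossEntropyFunction.apply(_input, weight, target)
loss.backward()  # backward only rescales the gradients already computed in forward
```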
torch-ext/liger_kernels/geglu.py ADDED
@@ -0,0 +1,141 @@
1
+ import operator
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from utils import calculate_settings
8
+ from utils import compare_version
9
+ from utils import ensure_contiguous
10
+
11
+ if compare_version("triton", operator.ge, "3.0.0"):
12
+ try:
13
+ # typical import path with dispatch available
14
+ from triton.language.extra.libdevice import tanh
15
+ except ModuleNotFoundError:
16
+ # for working with NGC containers
17
+ from triton.language.extra.cuda.libdevice import tanh
18
+ else:
19
+ from triton.language.math import tanh
20
+
21
+
22
+ @triton.jit
23
+ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
24
+ program_id = tl.program_id(0).to(tl.int64)
25
+
26
+ # locate start index
27
+ a += program_id * stride
28
+ b += program_id * stride
29
+ c += program_id * stride
30
+
31
+ col_offsets = tl.arange(0, BLOCK_SIZE)
32
+ mask = col_offsets < n_cols
33
+ a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
34
+ b_row = tl.load(b + col_offsets, mask=mask, other=0)
35
+
36
+ # tanh approximation form of GELU is computed with:
37
+ # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
38
+ sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
39
+ a_cubed = a_row * a_row * a_row
40
+ tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
41
+ tanh_result = tanh(tanh_arg)
42
+ geglu_a = 0.5 * a_row * (1 + tanh_result)
43
+ c_row = geglu_a * b_row
44
+ tl.store(c + col_offsets, c_row, mask=mask)
45
+
46
+
47
+ @triton.jit
48
+ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
49
+ program_id = tl.program_id(0).to(tl.int64)
50
+
51
+ # locate start index
52
+ dc += program_id * stride
53
+ a += program_id * stride
54
+ b += program_id * stride
55
+
56
+ col_offsets = tl.arange(0, BLOCK_SIZE)
57
+ mask = col_offsets < n_cols
58
+
59
+ dc_row = tl.load(dc + col_offsets, mask=mask, other=0)
60
+ a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
61
+ b_row = tl.load(b + col_offsets, mask=mask, other=0)
62
+
63
+ # recomputation to save memory
64
+ sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
65
+ a_cubed = a_row * a_row * a_row
66
+ tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
67
+ tanh_result = tanh(tanh_arg)
68
+ geglu_a = 0.5 * a_row * (1 + tanh_result)
69
+
70
+ db_row = dc_row * geglu_a
71
+
72
+ # Gradient w.r.t. a can be computed with:
73
+ # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
74
+ # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
75
+ term1 = 0.5 * (1 + tanh_result)
76
+ tanh_sq = tanh_result * tanh_result
77
+ term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
78
+ da_row = dc_row * b_row * (term1 + term2)
79
+
80
+ tl.store(a + col_offsets, da_row, mask=mask)
81
+ tl.store(b + col_offsets, db_row, mask=mask)
82
+
83
+
84
+ def geglu_forward(a, b):
85
+ ori_shape = a.shape
86
+
87
+ n_cols = ori_shape[-1]
88
+ a = a.view(-1, n_cols)
89
+ b = b.view(-1, n_cols)
90
+ c = torch.empty_like(a)
91
+ n_rows = a.shape[0]
92
+
93
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
94
+
95
+ _geglu_tanh_forward_kernel[(n_rows,)](
96
+ a,
97
+ b,
98
+ c,
99
+ c.stride(-2),
100
+ n_cols=n_cols,
101
+ BLOCK_SIZE=BLOCK_SIZE,
102
+ num_warps=num_warps,
103
+ )
104
+ return a, b, c.view(*ori_shape)
105
+
106
+
107
+ def geglu_backward(a, b, dc):
108
+ ori_shape = dc.shape
109
+ n_cols = ori_shape[-1]
110
+ dc = dc.view(-1, n_cols)
111
+ n_rows = dc.shape[0]
112
+
113
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
114
+
115
+ _geglu_tanh_backward_kernel[(n_rows,)](
116
+ dc,
117
+ a,
118
+ b,
119
+ dc.stride(-2),
120
+ n_cols=n_cols,
121
+ BLOCK_SIZE=BLOCK_SIZE,
122
+ num_warps=num_warps,
123
+ )
124
+
125
+ return a.view(*ori_shape), b.view(*ori_shape)
126
+
127
+
128
+ class LigerGELUMulFunction(torch.autograd.Function):
129
+ @staticmethod
130
+ @ensure_contiguous
131
+ def forward(ctx, a, b):
132
+ a, b, c = geglu_forward(a, b)
133
+ ctx.save_for_backward(a, b)
134
+ return c
135
+
136
+ @staticmethod
137
+ @ensure_contiguous
138
+ def backward(ctx, dc):
139
+ a, b = ctx.saved_tensors
140
+ a, b = geglu_backward(a, b, dc)
141
+ return a, b
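
A minimal usage sketch of `LigerGELUMulFunction` (GELU(a) * b with the tanh approximation), not part of the upload. A CUDA device with Triton is assumed, and the import path and shapes are illustrative.

```python
import torch

# Assumed import path; adjust to your packaging of liger_kernels.
from liger_kernels import LigerGELUMulFunction

gate = torch.randn(4, 128, 2048, device="cuda", dtype=torch.bfloat16, requires_grad=True)
up = torch.randn(4, 128, 2048, device="cuda", dtype=torch.bfloat16, requires_grad=True)

out = LigerGELUMulFunction.apply(gate, up)  # same shape as the inputs
out.sum().backward()                        # gradients flow to both gate and up
```

Note that the backward kernel recomputes the GELU activation and writes the gradients in place of the saved inputs, which is what keeps the memory footprint low.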
torch-ext/liger_kernels/group_norm.py ADDED
@@ -0,0 +1,305 @@
1
+ import operator
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from utils import compare_version
8
+ from utils import ensure_contiguous
9
+
10
+ if compare_version("triton", operator.ge, "3.0.0"):
11
+ try:
12
+ # typical import path with dispatch available
13
+ from triton.language.extra.libdevice import rsqrt
14
+ except ModuleNotFoundError:
15
+ # for working with NGC containers
16
+ from triton.language.extra.cuda.libdevice import rsqrt
17
+ else:
18
+ from triton.language.math import rsqrt
19
+
20
+ MAX_FUSED_SIZE = 65536
21
+
22
+
23
+ @triton.jit
24
+ def _group_norm_forward_kernel(
25
+ Y_ptr, # pointer to output, shape (n_rows, n_groups, hidden_size)
26
+ Y_row_stride, # stride of each row in output
27
+ Y_col_stride, # stride of each column in output
28
+ X_ptr, # pointer to input, shape (n_rows, n_groups, hidden_size)
29
+ X_row_stride, # stride of each row in input
30
+ X_col_stride, # stride of each column in input
31
+ Mean_ptr, # pointer to mean, shape (n_rows, n_groups)
32
+ Mean_row_stride, # stride of each row in mean
33
+ Mean_col_stride, # stride of each column in mean
34
+ RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups)
35
+ RSTD_row_stride, # stride of each row in rstd
36
+ RSTD_col_stride, # stride of each column in rstd
37
+ W_ptr, # pointer to W
38
+ B_ptr, # pointer to B
39
+ hidden_size, # hidden size of X
40
+ channels_per_group, # the number of channels per group
41
+ eps,
42
+ BLOCK_SIZE: tl.constexpr,
43
+ ):
44
+ """
45
+ References:
46
+ https://nn.labml.ai/normalization/group_norm/index.html
47
+ """
48
+ batch_idx = tl.program_id(0)
49
+ group_idx = tl.program_id(1)
50
+
51
+ X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride
52
+ Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride
53
+
54
+ block_range = tl.arange(0, BLOCK_SIZE)
55
+
56
+ # Compute mean and variance using the online algorithm
57
+ s = 0.0
58
+ squared_sum = 0.0
59
+ for i in tl.range(0, hidden_size, BLOCK_SIZE):
60
+ hidden_size_offsets = i + block_range
61
+ mask = hidden_size_offsets < hidden_size
62
+ X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0)
63
+ s += tl.sum(X)
64
+ # X**2
65
+ squared_sum += tl.sum(X * X)
66
+
67
+ m = s / hidden_size
68
+
69
+ # variance = E[X**2] - E[X]**2
70
+ variance = (squared_sum / hidden_size) - (m * m)
71
+
72
+ # 1/std
73
+ rstd = rsqrt(variance + eps)
74
+
75
+ # Normalize
76
+ hidden_size_per_channel = hidden_size // channels_per_group
77
+ for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
78
+ W = tl.load(W_ptr + channel_idx)
79
+ B = tl.load(B_ptr + channel_idx)
80
+ for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
81
+ hidden_size_offsets = i + block_range
82
+ mask = hidden_size_offsets < hidden_size_per_channel
83
+ X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
84
+ Y = (X - m) * rstd * W + B
85
+ tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
86
+
87
+ X_ptr += hidden_size_per_channel
88
+ Y_ptr += hidden_size_per_channel
89
+
90
+ tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
91
+ tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
92
+
93
+
94
+ @triton.jit
95
+ def _group_norm_backward_kernel(
96
+ X_ptr, # pointer to input, shape (n_rows, n_channels, hidden_size)
97
+ X_row_stride, # stride of each row in input
98
+ X_col_stride, # stride of each column in input
99
+ W_ptr, # pointer to weights, shape (n_channels)
100
+ Mean_ptr, # pointer to mean, shape (n_rows, n_groups)
101
+ Mean_ptr_row_stride, # stride of each row in mean
102
+ Mean_ptr_col_stride, # stride of each column in mean
103
+ RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups)
104
+ DX_ptr, # pointer to input grad, shape (n_rows, n_groups, hidden_size)
105
+ DW_ptr, # pointer to weights grad, shape (n_channels)
106
+ DB_ptr, # pointer to bias grad, shape (n_channels)
107
+ UPSTREAM_ptr, # pointer to output grad, shape (n_rows, n_channels, hidden_size)
108
+ hidden_size: tl.constexpr, # hidden size
109
+ channels_per_group: tl.constexpr, # number of channels per group
110
+ BLOCK_SIZE: tl.constexpr,
111
+ dtype: tl.constexpr,
112
+ ):
113
+ """
114
+ References:
115
+ https://nn.labml.ai/normalization/group_norm/index.html
116
+ https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
117
+
118
+ The backprop equations are the same for group_norm and layer_norm;
119
+ the only difference here is that we load the Mean and Rstd corresponding to the
120
+ group we're computing gradients for. The mean and rstd are computed over all the channels in the group,
121
+ so the total number of elements we compute the mean over is num_channels_per_group * hidden_size.
122
+
123
+ We also need to load the Weights corresponding to the current channel to compute the gradients.
124
+ """
125
+ batch_idx = tl.program_id(0)
126
+ group_idx = tl.program_id(1)
127
+
128
+ # Move the pointers to the correct batch
129
+ X_ptr += batch_idx * X_row_stride
130
+ DX_ptr += batch_idx * X_row_stride
131
+ UPSTREAM_ptr += batch_idx * X_row_stride
132
+
133
+ # Mean and rstd are the same shape so have the same strides
134
+ mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
135
+ rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
136
+
137
+ c1 = 0.0
138
+ c2 = 0.0
139
+ block_range = tl.arange(0, BLOCK_SIZE)
140
+
141
+ # We need to compute the sum terms of the backprop equations across all channels in the group
142
+ for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
143
+ dW = 0.0
144
+ dB = 0.0
145
+ # Move the pointers to the correct channel
146
+ W = tl.load(W_ptr + channel_idx)
147
+ for i in tl.range(0, hidden_size, BLOCK_SIZE):
148
+ hidden_size_offsets = i + block_range
149
+ mask = hidden_size_offsets < hidden_size
150
+ X = tl.load(
151
+ X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
152
+ mask=mask,
153
+ other=0.0,
154
+ )
155
+ UPSTREAM_grad = tl.load(
156
+ UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
157
+ mask=mask,
158
+ other=0.0,
159
+ )
160
+
161
+ x_hat = (X - mean) * rstd
162
+ dW += tl.sum(UPSTREAM_grad * x_hat)
163
+ dB += tl.sum(UPSTREAM_grad)
164
+
165
+ wdy = W * UPSTREAM_grad
166
+ c1 += tl.sum(x_hat * wdy)
167
+ c2 += tl.sum(wdy)
168
+
169
+ # Need to ensure additions to the same channel are atomic
170
+ tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype))
171
+ tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype))
172
+
173
+ N = hidden_size * channels_per_group
174
+ c1 = c1 / N
175
+ c2 = c2 / N
176
+
177
+ for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
178
+ # Move the pointers to the correct channel
179
+ W = tl.load(W_ptr + channel_idx)
180
+ for i in range(0, hidden_size, BLOCK_SIZE):
181
+ hidden_size_offsets = i + block_range
182
+ mask = hidden_size_offsets < hidden_size
183
+ X = tl.load(
184
+ X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
185
+ mask=mask,
186
+ other=0.0,
187
+ )
188
+ UPSTREAM_grad = tl.load(
189
+ UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
190
+ mask=mask,
191
+ other=0.0,
192
+ )
193
+
194
+ x_hat = (X - mean) * rstd
195
+ wdy = W * UPSTREAM_grad
196
+ dx = (wdy - (x_hat * c1 + c2)) * rstd
197
+ tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)
198
+
199
+
200
+ def group_norm_forward(X, num_channels, num_groups, W, B, eps):
201
+ shape = X.shape
202
+ batch_size = shape[0]
203
+ channels_per_group = num_channels // num_groups
204
+ # Reshape X so that the mean and std are computed across the groups
205
+ X = X.view(batch_size, num_groups, -1).contiguous()
206
+ hidden_size = X.shape[-1]
207
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
208
+ Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
209
+ Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
210
+ RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
211
+
212
+ _group_norm_forward_kernel[(batch_size, num_groups)](
213
+ Y,
214
+ Y.stride(0),
215
+ Y.stride(1),
216
+ X,
217
+ X.stride(0),
218
+ X.stride(1),
219
+ Mean,
220
+ Mean.stride(0),
221
+ Mean.stride(1),
222
+ RSTD,
223
+ RSTD.stride(0),
224
+ RSTD.stride(1),
225
+ W,
226
+ B,
227
+ hidden_size,
228
+ channels_per_group,
229
+ eps,
230
+ BLOCK_SIZE=BLOCK_SIZE,
231
+ )
232
+ # Return tensors in the original shape
233
+ return Y.view(*shape), X.view(*shape), Mean, RSTD, BLOCK_SIZE
234
+
235
+
236
+ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
237
+ shape = dY.shape
238
+ batch_size = shape[0]
239
+ hidden_size = dY.shape[-1]
240
+ channels_per_group = num_channels // num_groups
241
+ dY = dY.view(batch_size, num_groups, -1)
242
+ DX = torch.empty(
243
+ (batch_size, num_groups, hidden_size * channels_per_group),
244
+ dtype=X.dtype,
245
+ device=X.device,
246
+ )
247
+ DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device)
248
+ DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device)
249
+ triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
250
+
251
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
252
+ _group_norm_backward_kernel[(batch_size, num_groups)](
253
+ X,
254
+ X.stride(0),
255
+ X.stride(1),
256
+ W,
257
+ Mean,
258
+ Mean.stride(0),
259
+ Mean.stride(1),
260
+ RSTD,
261
+ DX,
262
+ DW,
263
+ DB,
264
+ dY,
265
+ hidden_size,
266
+ channels_per_group,
267
+ BLOCK_SIZE=BLOCK_SIZE,
268
+ dtype=triton_dtype,
269
+ )
270
+
271
+ # Return tensors in the original shape
272
+ return DX.view(*shape), DW, DB
273
+
274
+
275
+ class LigerGroupNormFunction(torch.autograd.Function):
276
+ @staticmethod
277
+ @ensure_contiguous
278
+ def forward(
279
+ ctx,
280
+ X,
281
+ affine_scaling_weight,
282
+ affine_shifting_bias,
283
+ num_channels,
284
+ num_groups,
285
+ eps,
286
+ ):
287
+ Y, X, Mean, RSTD, BLOCK_SIZE = group_norm_forward(
288
+ X,
289
+ num_channels,
290
+ num_groups,
291
+ affine_scaling_weight,
292
+ affine_shifting_bias,
293
+ eps,
294
+ )
295
+ ctx.num_channels = num_channels
296
+ ctx.num_groups = num_groups
297
+ ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
298
+ return Y
299
+
300
+ @staticmethod
301
+ @ensure_contiguous
302
+ def backward(ctx, dY):
303
+ X, W, B, Mean, RSTD = ctx.saved_tensors
304
+ DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
305
+ return DX, DW, DB, None, None, None
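
A minimal usage sketch of `LigerGroupNormFunction`, not part of the upload; it assumes a CUDA device with Triton, and the shapes, group count, and import path are illustrative.

```python
import torch

# Assumed import path; adjust to your packaging of liger_kernels.
from liger_kernels import LigerGroupNormFunction

B, C, HW, groups = 8, 64, 256, 16
x = torch.randn(B, C, HW, device="cuda", requires_grad=True)
weight = torch.ones(C, device="cuda", requires_grad=True)  # per-channel scale
bias = torch.zeros(C, device="cuda", requires_grad=True)   # per-channel shift

y = LigerGroupNormFunction.apply(x, weight, bias, C, groups, 1e-6)
y.mean().backward()
```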
torch-ext/liger_kernels/jsd.py ADDED
@@ -0,0 +1,201 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from utils import ensure_contiguous
8
+ from utils import infer_device
9
+
10
+
11
+ @triton.jit
12
+ def _jsd_kernel(
13
+ X_ptr, # input in logspace, X = log Q
14
+ X_stride,
15
+ Y_ptr, # ground truth in logspace, Y = log P
16
+ Y_stride,
17
+ loss_ptr,
18
+ loss_stride,
19
+ dX_ptr,
20
+ dX_stride,
21
+ label_ptr,
22
+ beta: tl.constexpr,
23
+ n_non_ignore: int,
24
+ ignore_index: tl.constexpr,
25
+ n_cols,
26
+ BLOCK_SIZE: tl.constexpr,
27
+ HAS_LABEL: tl.constexpr,
28
+ ):
29
+ # JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X)
30
+ # = sum(P * log P + Q * log Q - 2 * M * log M) / 2
31
+ # = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2
32
+ # grad_x_i = 0.5 * Q * (X - log_M)
33
+ pid = tl.program_id(0).to(tl.int64)
34
+ X_ptr += pid * X_stride
35
+ dX_ptr += pid * dX_stride
36
+ Y_ptr += pid * Y_stride
37
+ loss_ptr += pid * loss_stride
38
+ label_ptr += pid
39
+
40
+ if HAS_LABEL:
41
+ label = tl.load(label_ptr)
42
+ if label == ignore_index:
43
+ for i in range(0, n_cols, BLOCK_SIZE):
44
+ offsets = i + tl.arange(0, BLOCK_SIZE)
45
+ tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols)
46
+ return
47
+
48
+ for i in range(0, n_cols, BLOCK_SIZE):
49
+ offsets = i + tl.arange(0, BLOCK_SIZE)
50
+ mask = offsets < n_cols
51
+ X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
52
+ Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
53
+
54
+ if beta == 0.0: # forward KL
55
+ Y_max = tl.max(Y, axis=0)
56
+ Y_shifted = Y - Y_max
57
+ Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max) # Compensate for the shift
58
+ loss = Y_prob * (Y - X)
59
+ dX = -Y_prob
60
+ elif beta == 1.0: # reverse KL
61
+ X_max = tl.max(X, axis=0)
62
+ X_shifted = X - X_max
63
+ X_prob = tl.exp(X_shifted) * tl.exp(X_max) # Compensate for the shift
64
+ loss = X_prob * (X - Y)
65
+ dX = loss + X_prob
66
+ else:
67
+ max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
68
+ X_shifted = X - max_val
69
+ Y_shifted = Y - max_val
70
+
71
+ # Pre-compute exp(max_val) since it's used twice
72
+ exp_max = tl.exp(max_val)
73
+
74
+ # Compute exp terms with compensation
75
+ Q = tl.exp(X_shifted) * exp_max # = exp(X)
76
+ P = tl.exp(Y_shifted) * exp_max # = exp(Y)
77
+
78
+ # Pre-compute common terms
79
+ beta_P = beta * P
80
+ one_minus_beta_Q = (1 - beta) * Q
81
+ M = beta_P + one_minus_beta_Q
82
+ log_M = tl.log(M) # No need to compensate as M is already in original scale
83
+
84
+ loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
85
+ dX = one_minus_beta_Q * (X - log_M)
86
+
87
+ # Pre-compute scaling factor
88
+ scale = 1.0 / n_non_ignore
89
+ loss = loss * scale
90
+ dX = dX * scale
91
+
92
+ tl.store(loss_ptr + offsets, loss, mask=mask)
93
+ tl.store(dX_ptr + offsets, dX, mask=mask)
94
+
95
+
96
+ MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536
97
+
98
+
99
+ def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
100
+ BT, V = _input.shape
101
+ n_rows = BT
102
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
103
+ # non reduction loss
104
+ loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
105
+ dX = torch.empty_like(_input)
106
+
107
+ if has_label:
108
+ n_non_ignore = (shift_labels != ignore_index).sum().item()
109
+ else:
110
+ n_non_ignore = BT
111
+
112
+ _jsd_kernel[(n_rows,)](
113
+ X_ptr=_input, # input in logspace, X = log Q
114
+ X_stride=_input.stride(-2),
115
+ Y_ptr=target, # ground truth in logspace, Y = log P
116
+ Y_stride=target.stride(-2),
117
+ loss_ptr=loss,
118
+ loss_stride=loss.stride(-2),
119
+ dX_ptr=dX,
120
+ dX_stride=dX.stride(-2),
121
+ label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)), # dummy ptr if no label
122
+ beta=beta,
123
+ n_non_ignore=n_non_ignore,
124
+ ignore_index=ignore_index,
125
+ n_cols=V,
126
+ BLOCK_SIZE=BLOCK_SIZE,
127
+ HAS_LABEL=has_label,
128
+ )
129
+
130
+ loss = torch.sum(loss)
131
+ return loss.to(_input.dtype), dX
132
+
133
+
134
+ def jsd_backward(dX, grad_output):
135
+ # If jsd is the last layer, grad_output is 1.0. Skip the mul to save time
136
+ if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
137
+ return dX
138
+ else:
139
+ return grad_output * dX
140
+
141
+
142
+ class LigerJSDFunction(torch.autograd.Function):
143
+ r"""
144
+ This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.
145
+ .. math::
146
+ JSD(\beta)(P || Q)
147
+ = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))
148
+
149
+ .. note::
150
+ As all the other losses in PyTorch, this function expects the first argument,
151
+ :attr:`_input`, to be the predictions, the output of the student model, in log-space
152
+ and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space.
153
+ This differs from the standard mathematical notation :math:`JSD(P || Q)` where
154
+ :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
155
+ """
156
+
157
+ @staticmethod
158
+ @ensure_contiguous
159
+ def forward(
160
+ ctx,
161
+ _input: torch.Tensor,
162
+ target: torch.Tensor,
163
+ shift_labels: Optional[torch.Tensor] = None,
164
+ beta: float = 0.5,
165
+ ignore_index: int = -100,
166
+ ) -> torch.Tensor:
167
+ """
168
+ Args:
169
+ _input (torch.Tensor): predict values with shape (BT, V) in logspace
170
+ target (torch.Tensor): ground truth values with shape (BT, V) in logspace
171
+ shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
172
+ beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
173
+ ignore_index (int): the index to ignore. Default: -100
174
+
175
+ Returns:
176
+ loss (torch.Tensor): generalized JSD
177
+ """
178
+ has_label = False
179
+ if shift_labels is not None:
180
+ assert shift_labels.shape == (_input.shape[0],), (
181
+ f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
182
+ )
183
+ shift_labels = shift_labels.contiguous()
184
+ has_label = True
185
+
186
+ loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
187
+ ctx.save_for_backward(dX)
188
+ return loss
189
+
190
+ @staticmethod
191
+ @ensure_contiguous
192
+ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
193
+ (dX,) = ctx.saved_tensors
194
+ dX = jsd_backward(dX, grad_output)
195
+ return (
196
+ dX,
197
+ None,
198
+ None,
199
+ None,
200
+ None,
201
+ )
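
A minimal usage sketch of `LigerJSDFunction` (not part of the upload): both arguments are log-probabilities, the first from the student and the second from the teacher. A CUDA device with Triton is assumed; shapes are illustrative.

```python
import torch

# Assumed import path; adjust to your packaging of liger_kernels.
from liger_kernels import LigerJSDFunction

BT, V = 64, 32000
student_logits = torch.randn(BT, V, device="cuda", requires_grad=True)
teacher_logits = torch.randn(BT, V, device="cuda")

student_logp = torch.log_softmax(student_logits, dim=-1)
teacher_logp = torch.log_softmax(teacher_logits, dim=-1)

loss = LigerJSDFunction.apply(student_logp, teacher_logp)  # beta defaults to 0.5
loss.backward()
```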
torch-ext/liger_kernels/kl_div.py ADDED
@@ -0,0 +1,262 @@
1
+ from typing import Literal
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from utils import ensure_contiguous
8
+ from utils import is_hip
9
+ from utils import infer_device
10
+
11
+
12
+ def get_num_warps(BLOCK_SIZE):
13
+ num_warps = 4
14
+ if BLOCK_SIZE >= 32768:
15
+ num_warps = 32 if not is_hip() else 16
16
+ elif BLOCK_SIZE >= 8192:
17
+ num_warps = 16
18
+ elif BLOCK_SIZE >= 2048:
19
+ num_warps = 8
20
+
21
+ return num_warps
22
+
23
+
24
+ MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
25
+
26
+ REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
27
+
28
+ _REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0)
29
+ _REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1)
30
+ _REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2)
31
+ _REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3)
32
+
33
+ _str_to_reduction_mode = {
34
+ "none": _REDUCTION_MODE_NONE.value,
35
+ "sum": _REDUCTION_MODE_SUM.value,
36
+ "mean": _REDUCTION_MODE_MEAN.value,
37
+ "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
38
+ }
39
+
40
+
41
+ @triton.jit
42
+ def _kldiv_kernel_forward(
43
+ y_ptr, # [B, S], prediction ptr, the kernel expects the prediction in log-space
44
+ y_stride, # int, prediction stride
45
+ gt_ptr, # [B, S], ground truth ptr
46
+ gt_stride, # int, ground truth stride
47
+ loss_ptr, # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr
48
+ loss_stride, # int, output stride
49
+ n_cols, # int, number of columns in the input tensor
50
+ eps,
51
+ BLOCK_SIZE: tl.constexpr,
52
+ log_target: tl.constexpr = False,
53
+ reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
54
+ ):
55
+ pid = tl.program_id(0).to(tl.int64)
56
+ y_ptr += pid * y_stride
57
+ gt_ptr += pid * gt_stride
58
+ loss_ptr += pid * loss_stride
59
+
60
+ base_offsets = tl.arange(0, BLOCK_SIZE)
61
+
62
+ loss_sum = 0.0
63
+ for i in range(0, n_cols, BLOCK_SIZE):
64
+ offsets = i + base_offsets
65
+ mask = offsets < n_cols
66
+ y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
67
+ y_true = tl.load(gt_ptr + offsets, mask=mask, other=0.0)
68
+
69
+ # KL(y_true || y) = y_true * (log(y_true) - log(y))
70
+ # We compute KL(y_true || y) with y in the log-space
71
+ if not log_target:
72
+ loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y)
73
+ else:
74
+ loss = tl.exp(y_true) * (y_true - y)
75
+
76
+ if reduction == _REDUCTION_MODE_NONE:
77
+ tl.store(loss_ptr + offsets, loss, mask=mask)
78
+ else:
79
+ loss_sum += tl.sum(loss, axis=0)
80
+
81
+ if reduction != _REDUCTION_MODE_NONE:
82
+ tl.store(loss_ptr, loss_sum)
83
+
84
+
85
+ @triton.jit
86
+ def _kldiv_kernel_backward(
87
+ target_ptr,
88
+ target_stride,
89
+ new_grads_ptr,
90
+ new_grads_stride,
91
+ n_cols,
92
+ BLOCK_SIZE: tl.constexpr,
93
+ log_target: tl.constexpr = False,
94
+ ):
95
+ pid = tl.program_id(0).to(tl.int64)
96
+
97
+ target_ptr += pid * target_stride
98
+ new_grads_ptr += pid * new_grads_stride
99
+
100
+ offsets = tl.arange(0, BLOCK_SIZE)
101
+ mask = offsets < n_cols
102
+
103
+ for i in range(0, n_cols, BLOCK_SIZE):
104
+ offsets = i + tl.arange(0, BLOCK_SIZE)
105
+ mask = offsets < n_cols
106
+
107
+ target = tl.load(target_ptr + offsets, mask=mask, other=0.0)
108
+
109
+ if not log_target:
110
+ res = target * -1
111
+ else:
112
+ res = -tl.exp(target)
113
+
114
+ tl.store(new_grads_ptr + offsets, res, mask=mask)
115
+
116
+
117
+ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
118
+ BT, V = y_pred.shape
119
+ BLOCK_SIZE = (
120
+ min(8192, triton.next_power_of_2(V))
121
+ if infer_device() == "xpu"
122
+ else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
123
+ )
124
+ num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
125
+
126
+ grid = (BT,)
127
+ reduction = _str_to_reduction_mode[reduction]
128
+
129
+ out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
130
+ output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32)
131
+
132
+ _kldiv_kernel_forward[grid](
133
+ y_pred,
134
+ y_pred.stride(0),
135
+ y_true,
136
+ y_true.stride(0),
137
+ output_tensor,
138
+ output_tensor.stride(0),
139
+ V,
140
+ eps=eps,
141
+ BLOCK_SIZE=BLOCK_SIZE,
142
+ num_warps=num_warps,
143
+ log_target=log_target,
144
+ reduction=reduction,
145
+ )
146
+
147
+ # Reduce according to the reduction mode, matching PyTorch. In later PyTorch versions, `mean` will behave the same as `batchmean`
148
+ # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
149
+ # https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372
150
+ if reduction == _REDUCTION_MODE_BATCHMEAN.value:
151
+ return output_tensor.sum() / BT
152
+ elif reduction == _REDUCTION_MODE_SUM.value:
153
+ return output_tensor.sum(dim=0)
154
+ elif reduction == _REDUCTION_MODE_MEAN.value:
155
+ return output_tensor.sum() / (BT * V)
156
+ else:
157
+ return output_tensor
158
+
159
+
160
+ def kldiv_backward_triton(target, grad_output, new_grads, log_target):
161
+ BT, V = target.shape
162
+ BLOCK_SIZE = (
163
+ min(8192, triton.next_power_of_2(V))
164
+ if infer_device() == "xpu"
165
+ else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
166
+ )
167
+ num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
168
+
169
+ grid = (BT,)
170
+
171
+ # We store the gradients in-place in the input tensor
172
+ _kldiv_kernel_backward[grid](
173
+ target,
174
+ target.stride(0),
175
+ new_grads,
176
+ new_grads.stride(0),
177
+ V,
178
+ BLOCK_SIZE=BLOCK_SIZE,
179
+ num_warps=num_warps,
180
+ log_target=log_target,
181
+ )
182
+
183
+ # If the KL divergence loss is the last layer, grad_output is 1.0. Skip the mul then.
184
+ if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
185
+ return new_grads
186
+
187
+ return new_grads * grad_output
188
+
189
+
190
+ class LigerKLDivLossFunction(torch.autograd.Function):
191
+ """
192
+ Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:
193
+ ```python
194
+ if log_target:
195
+ loss = target.exp() * (target - input)
196
+ else:
197
+ loss = target * (target.log() - input)
198
+ ```,
199
+ then the loss is reduced according to the `reduction` parameter.
200
+ as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
201
+ """
202
+
203
+ @staticmethod
204
+ @ensure_contiguous
205
+ def forward(
206
+ ctx,
207
+ y_pred: torch.Tensor,
208
+ y_true: torch.Tensor,
209
+ reduction: REDUCTION_LITERAL = "batchmean",
210
+ log_target: bool = False,
211
+ eps: float = 1e-10,
212
+ ) -> torch.Tensor:
213
+ """A forward pass for the KL Divergence Loss.
214
+
215
+ Args:
216
+ ctx: Torch autograd context
217
+ y_pred (torch.Tensor): A tensor of shape (BT, V) containing the predicted values, expected to be log-probabilities.
218
+ y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`.
219
+ reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean".
220
+ log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False.
221
+ eps: (float, optional): A small value to avoid division by zero. Defaults to 1e-10.
222
+
223
+ Returns:
224
+ torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar.
225
+ """
226
+ ctx.save_for_backward(y_true)
227
+ ctx.reduction = reduction
228
+ ctx.log_target = log_target
229
+ return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)
230
+
231
+ @staticmethod
232
+ @ensure_contiguous
233
+ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
234
+ """A backward pass for the KL Divergence Loss.
235
+
236
+ Args:
237
+ ctx: Torch autograd context
238
+ grad_output (torch.Tensor): The gradient of the loss with respect to the output.
239
+
240
+ Returns:
241
+ tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
242
+ """
243
+ (y_true,) = ctx.saved_tensors
244
+
245
+ new_grads = torch.empty_like(y_true)
246
+
247
+ derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target)
248
+
249
+ if ctx.reduction == "batchmean":
250
+ derivative = derivative / y_true.shape[0]
251
+ elif ctx.reduction == "sum" or ctx.reduction == "none":
252
+ pass
253
+ elif ctx.reduction == "mean":
254
+ derivative = derivative / (y_true.shape[0] * y_true.shape[1])
255
+
256
+ return (
257
+ derivative,
258
+ None,
259
+ None,
260
+ None,
261
+ None,
262
+ )
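
A minimal usage sketch of `LigerKLDivLossFunction` (not part of the upload): the prediction is given in log-space and, with `log_target=False`, the target as probabilities. A CUDA device with Triton is assumed; shapes are illustrative.

```python
import torch

# Assumed import path; adjust to your packaging of liger_kernels.
from liger_kernels import LigerKLDivLossFunction

BT, V = 64, 32000
y_pred = torch.log_softmax(torch.randn(BT, V, device="cuda", requires_grad=True), dim=-1)
y_true = torch.softmax(torch.randn(BT, V, device="cuda"), dim=-1)

# positional args: (y_pred, y_true, reduction, log_target, eps)
loss = LigerKLDivLossFunction.apply(y_pred, y_true, "batchmean", False, 1e-10)
loss.backward()
```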
torch-ext/liger_kernels/layer_norm.py ADDED
@@ -0,0 +1,265 @@
1
+ import math
2
+ import operator
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ from utils import calculate_settings
9
+ from utils import compare_version
10
+ from utils import ensure_contiguous
11
+
12
+ if compare_version("triton", operator.ge, "3.0.0"):
13
+ try:
14
+ # typical import path with dispatch available
15
+ from triton.language.extra.libdevice import rsqrt
16
+ except ModuleNotFoundError:
17
+ # for working with NGC containers
18
+ from triton.language.extra.cuda.libdevice import rsqrt
19
+ else:
20
+ from triton.language.math import rsqrt
21
+
22
+
23
+ @triton.jit
24
+ def _layer_norm_forward_kernel(
25
+ Y_ptr, # pointer to output, shape (n_rows, n_cols)
26
+ Y_row_stride, # stride of each row in output
27
+ X_ptr, # pointer to input, shape (n_rows, n_cols)
28
+ X_row_stride, # stride of each row in input
29
+ W_ptr, # pointer to weights, shape (n_cols,)
30
+ W_row_stride, # stride of each row in weights
31
+ B_ptr, # pointer to bias, shape (n_cols,)
32
+ B_row_stride, # stride of each row in bias
33
+ Mean_ptr, # pointer to mean, shape (n_rows,)
34
+ Mean_row_stride, # stride of each row in mean
35
+ RSTD_ptr, # pointer to rstd, shape (n_rows,)
36
+ RSTD_row_stride, # stride of each row in rstd
37
+ n_cols,
38
+ eps,
39
+ BLOCK_SIZE: tl.constexpr,
40
+ ):
41
+ """
42
+ References:
43
+ https://arxiv.org/abs/1607.06450
44
+ https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
45
+ """
46
+ row_idx = tl.program_id(0)
47
+ col_offsets = tl.arange(0, BLOCK_SIZE)
48
+ mask = col_offsets < n_cols
49
+
50
+ Y_ptr += row_idx * Y_row_stride
51
+ X_ptr += row_idx * X_row_stride
52
+ Mean_ptr += row_idx * Mean_row_stride
53
+ RSTD_ptr += row_idx * RSTD_row_stride
54
+
55
+ X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
56
+ W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
57
+ B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
58
+
59
+ mean = tl.sum(X_row, axis=0) / n_cols
60
+ Xmm = tl.where(mask, X_row - mean, 0)
61
+ var = tl.sum(Xmm * Xmm, axis=0) / n_cols
62
+ rstd = rsqrt(var + eps)
63
+
64
+ tl.store(Mean_ptr, mean)
65
+ tl.store(RSTD_ptr, rstd)
66
+
67
+ Y_row = Xmm * rstd * W_row + B_row
68
+
69
+ tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
70
+
71
+
72
+ @triton.jit
73
+ def _layer_norm_backward_kernel(
74
+ X_ptr, # pointer to input, shape (n_rows, n_cols)
75
+ W_ptr, # pointer to weights, shape (n_cols,)
76
+ Mean_ptr, # pointer to mean, shape (n_rows,)
77
+ RSTD_ptr, # pointer to rstd, shape (n_rows,)
78
+ DX_ptr, # pointer to input grad, shape (n_rows, n_cols)
79
+ DW_ptr, # pointer to weights grad, shape (n_cols,)
80
+ DB_ptr, # pointer to bias grad, shape (n_cols,)
81
+ DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
82
+ stride_x, # stride of each row in input
83
+ stride_dx, # stride of each row in input grad
84
+ stride_dw, # stride of each row in weights grad
85
+ stride_db, # stride of each row in bias grad
86
+ stride_dy, # stride of each row in output grad
87
+ n_rows,
88
+ n_cols,
89
+ rows_per_program: tl.constexpr,
90
+ BLOCK_SIZE: tl.constexpr,
91
+ dtype: tl.constexpr,
92
+ ):
93
+ """
94
+ References:
95
+ https://arxiv.org/abs/1607.06450
96
+ https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
97
+ https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
98
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
99
+ """
100
+ row_block_id = tl.program_id(0)
101
+ row_start = row_block_id * rows_per_program
102
+ row_end = min((row_block_id + 1) * rows_per_program, n_rows)
103
+ cols = tl.arange(0, BLOCK_SIZE)
104
+ mask = cols < n_cols
105
+
106
+ dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
107
+ db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
108
+
109
+ X_ptr += row_start * stride_x
110
+ Mean_ptr += row_start
111
+ RSTD_ptr += row_start
112
+ DX_ptr += row_start * stride_dx
113
+ DY_ptr += row_start * stride_dy
114
+
115
+ for _ in range(row_start, row_end):
116
+ x = tl.load(X_ptr + cols, mask=mask, other=0.0)
117
+ w = tl.load(W_ptr + cols, mask=mask, other=0.0)
118
+ dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)
119
+ mean = tl.load(Mean_ptr)
120
+ rstd = tl.load(RSTD_ptr)
121
+
122
+ x_hat = (x - mean) * rstd
123
+ wdy = w * dy
124
+ c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
125
+ c2 = tl.sum(wdy, axis=0) / n_cols
126
+ dx = (wdy - (x_hat * c1 + c2)) * rstd
127
+ tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)
128
+
129
+ dw_row += dy * x_hat
130
+ db_row += dy
131
+
132
+ X_ptr += stride_x
133
+ Mean_ptr += 1
134
+ RSTD_ptr += 1
135
+ DX_ptr += stride_dx
136
+ DY_ptr += stride_dy
137
+
138
+ tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)
139
+ tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)
140
+
141
+
142
+ def layer_norm_forward(X, W, B, eps):
143
+ shape = X.shape
144
+ dim = shape[-1]
145
+ X = X.view(-1, dim)
146
+ n_rows, n_cols = X.shape
147
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
148
+ Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
149
+ Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
150
+ RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
151
+ if X.shape[1] != W.shape[0]:
152
+ raise ValueError(
153
+ f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
154
+ f"must match weight size (W.shape[0]={W.shape[0]})"
155
+ )
156
+
157
+ # XPU-specific optimization
158
+ kernel_args = {}
159
+ if X.device.type == "xpu":
160
+ kernel_args["grf_mode"] = "large"
161
+
162
+ _layer_norm_forward_kernel[(n_rows,)](
163
+ Y,
164
+ Y.stride(0),
165
+ X,
166
+ X.stride(0),
167
+ W,
168
+ W.stride(0),
169
+ B,
170
+ B.stride(0),
171
+ Mean,
172
+ Mean.stride(0),
173
+ RSTD,
174
+ RSTD.stride(0),
175
+ n_cols,
176
+ eps,
177
+ BLOCK_SIZE=BLOCK_SIZE,
178
+ num_warps=num_warps,
179
+ **kernel_args, # XPU-specific optimization
180
+ )
181
+ return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
182
+
183
+
184
+ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
185
+ shape = dY.shape
186
+ dim = shape[-1]
187
+ dY = dY.view(-1, dim)
188
+ n_rows, n_cols = dY.shape
189
+
190
+ sm_count = 1
191
+ if X.device.type == "cuda":
192
+ sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
193
+ elif X.device.type == "xpu":
194
+ sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
195
+
196
+ DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
197
+ _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
198
+ _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
199
+
200
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
201
+ if n_cols > BLOCK_SIZE:
202
+ raise RuntimeError(
203
+ f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
204
+ )
205
+
206
+ rows_per_program = math.ceil(n_rows / sm_count)
207
+ grid = (sm_count,)
208
+ triton_dtype = (
209
+ tl.float32
210
+ if X.dtype == torch.float32
211
+ else tl.bfloat16
212
+ if X.dtype == torch.bfloat16
213
+ else tl.float16
214
+ if X.dtype == torch.float16
215
+ else tl.float32 # fallback to float32 for other types
216
+ )
217
+
218
+ # XPU-specific optimization
219
+ kernel_args = {}
220
+ if X.device.type == "xpu":
221
+ kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
222
+
223
+ _layer_norm_backward_kernel[grid](
224
+ X,
225
+ W,
226
+ Mean,
227
+ RSTD,
228
+ DX,
229
+ _DW,
230
+ _DB,
231
+ dY,
232
+ X.stride(0),
233
+ DX.stride(0),
234
+ _DW.stride(0),
235
+ _DB.stride(0),
236
+ dY.stride(0),
237
+ n_rows,
238
+ n_cols,
239
+ rows_per_program,
240
+ BLOCK_SIZE=BLOCK_SIZE,
241
+ dtype=triton_dtype,
242
+ **kernel_args, # XPU-specific optimization
243
+ )
244
+
245
+ DW = _DW.sum(dim=0).to(W.dtype)
246
+ DB = _DB.sum(dim=0).to(W.dtype)
247
+
248
+ DX = DX.view(*shape)
249
+ return DX, DW, DB
250
+
251
+
252
+ class LigerLayerNormFunction(torch.autograd.Function):
253
+ @staticmethod
254
+ @ensure_contiguous
255
+ def forward(ctx, X, W, B, eps):
256
+ Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps)
257
+ ctx.save_for_backward(X, W, B, Mean, RSTD)
258
+ return Y
259
+
260
+ @staticmethod
261
+ @ensure_contiguous
262
+ def backward(ctx, dY):
263
+ X, W, B, Mean, RSTD = ctx.saved_tensors
264
+ DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD)
265
+ return DX, DW, DB, None
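
A minimal usage sketch of `LigerLayerNormFunction` over the last dimension (not part of the upload). A CUDA device with Triton is assumed; the hidden size and eps are illustrative.

```python
import torch

# Assumed import path; adjust to your packaging of liger_kernels.
from liger_kernels import LigerLayerNormFunction

hidden = 4096
x = torch.randn(2, 128, hidden, device="cuda", requires_grad=True)
weight = torch.ones(hidden, device="cuda", requires_grad=True)
bias = torch.zeros(hidden, device="cuda", requires_grad=True)

y = LigerLayerNormFunction.apply(x, weight, bias, 1e-5)
y.sum().backward()
```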
torch-ext/liger_kernels/qwen2vl_mrope.py ADDED
@@ -0,0 +1,222 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+
6
+ @triton.jit
7
+ def _triton_qwen2vl_mrope(
8
+ q_ptr,
9
+ k_ptr,
10
+ cos,
11
+ sin,
12
+ sl,
13
+ bs: tl.constexpr,
14
+ n_qh: tl.constexpr,
15
+ n_kh: tl.constexpr,
16
+ hd: tl.constexpr,
17
+ pad_n_qh: tl.constexpr,
18
+ pad_n_kh: tl.constexpr,
19
+ pad_hd: tl.constexpr,
20
+ mrope_section_t: tl.constexpr,
21
+ mrope_section_h: tl.constexpr,
22
+ BLOCK_SIZE: tl.constexpr,
23
+ BACKWARD_PASS: tl.constexpr = False,
24
+ ):
25
+ pid = tl.program_id(0)
26
+
27
+ # locate start address
28
+ q_ptr = q_ptr + pid * (n_qh * hd)
29
+ k_ptr = k_ptr + pid * (n_kh * hd)
30
+
31
+ # ####################################################################
32
+ # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
33
+ # m of this program instance
34
+ # ####################################################################
35
+
36
+ # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
37
+ # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
38
+ # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index
39
+ # and pid % sl to get the sequence index.
40
+ # 2. We only need the left half of cos and sin matrix because the right half is just
41
+ # a clone of the left half.
42
+ t_end = mrope_section_t
43
+ h_end = t_end + mrope_section_h
44
+
45
+ t_cos = cos + pid * hd
46
+ h_cos = t_cos + bs * sl * hd
47
+ w_cos = h_cos + bs * sl * hd
48
+ t_sin = sin + pid * hd
49
+ h_sin = t_sin + bs * sl * hd
50
+ w_sin = h_sin + bs * sl * hd
51
+
52
+ cos_offsets = tl.arange(0, pad_hd // 2)
53
+ t_mask = cos_offsets < t_end
54
+ h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
55
+ w_mask = (h_end <= cos_offsets) & (cos_offsets < hd // 2)
56
+ t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
57
+ h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
58
+ w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
59
+ t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
60
+ h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
61
+ w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)
62
+ cos_row = t_cos_row + h_cos_row + w_cos_row
63
+ sin_row = t_sin_row + h_sin_row + w_sin_row
64
+
65
+ # ####################################################################
66
+ # Load the left and right half of q and k for the current
67
+ # program instance (i.e. for the current token) separately
68
+ # ####################################################################
69
+ # left half of the head
70
+ first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
71
+ first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
72
+ first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
73
+ first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
74
+ q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
75
+ k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
76
+
77
+ # right half of the head
78
+ second_half_q_offsets = first_half_q_offsets + (hd // 2)
79
+ second_half_k_offsets = first_half_k_offsets + (hd // 2)
80
+ second_q_mask = first_q_mask
81
+ second_k_mask = first_k_mask
82
+ q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
83
+ k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
84
+
85
+ if not BACKWARD_PASS:
86
+ # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
87
+ new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
88
+ tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
89
+ new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
90
+ tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
91
+
92
+ new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
93
+ tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
94
+ new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
95
+ tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
96
+ else:
97
+ # with some math, we can get:
98
+ # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
99
+ new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
100
+ tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
101
+ new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
102
+ tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
103
+
104
+ new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
105
+ tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
106
+ new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
107
+ tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
108
+
109
+
110
+ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
111
+ # transpose it back to the physical shape because Triton looks at the physical storage
112
+ # note: q and k are non-contiguous before the transformation and will become contiguous after transpose
113
+ q = q.transpose(1, 2)
114
+ k = k.transpose(1, 2)
115
+
116
+ batch_size, seq_len, n_q_head, head_dim = q.shape
117
+ n_kv_head = k.shape[2]
118
+ pad_hd = triton.next_power_of_2(head_dim)
119
+ pad_n_q_head = triton.next_power_of_2(n_q_head)
120
+ pad_n_kv_head = triton.next_power_of_2(n_kv_head)
121
+ BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
122
+
123
+ n_row = batch_size * seq_len
124
+
125
+ # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
126
+ q = q.contiguous()
127
+ k = k.contiguous()
128
+ cos = cos.contiguous()
129
+ sin = sin.contiguous()
130
+
131
+ _triton_qwen2vl_mrope[(n_row,)](
132
+ q,
133
+ k,
134
+ cos,
135
+ sin,
136
+ seq_len,
137
+ batch_size,
138
+ n_q_head,
139
+ n_kv_head,
140
+ head_dim,
141
+ pad_n_q_head,
142
+ pad_n_kv_head,
143
+ pad_hd,
144
+ mrope_section[0],
145
+ mrope_section[1],
146
+ BLOCK_SIZE=BLOCK_SIZE,
147
+ BACKWARD_PASS=False,
148
+ )
149
+ return q.transpose(1, 2), k.transpose(1, 2), cos, sin
150
+
151
+
152
+ def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
153
+ dq = dq.transpose(1, 2)
154
+ dk = dk.transpose(1, 2)
155
+
156
+ batch_size, seq_len, n_q_head, head_dim = dq.shape
157
+ n_kv_head = dk.shape[2]
158
+ pad_hd = triton.next_power_of_2(head_dim)
159
+ pad_n_q_head = triton.next_power_of_2(n_q_head)
160
+ pad_n_kv_head = triton.next_power_of_2(n_kv_head)
161
+ BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
162
+
163
+ n_row = batch_size * seq_len
164
+
165
+ # ensure dq and dk are contiguous
166
+ dq = dq.contiguous()
167
+ dk = dk.contiguous()
168
+
169
+ # backward is similar to forward except swapping few ops
170
+ _triton_qwen2vl_mrope[(n_row,)](
171
+ dq,
172
+ dk,
173
+ cos,
174
+ sin,
175
+ seq_len,
176
+ batch_size,
177
+ n_q_head,
178
+ n_kv_head,
179
+ head_dim,
180
+ pad_n_q_head,
181
+ pad_n_kv_head,
182
+ pad_hd,
183
+ mrope_section[0],
184
+ mrope_section[1],
185
+ BLOCK_SIZE=BLOCK_SIZE,
186
+ BACKWARD_PASS=True,
187
+ )
188
+ return dq.transpose(1, 2), dk.transpose(1, 2)
189
+
190
+
191
+ class LigerQwen2VLMRopeFunction(torch.autograd.Function):
192
+ """
193
+ Triton implementation of the Qwen2VL Multimodal Rotary Positional Embedding (M-RoPE) operation.
194
+
195
+ Please find the corresponding HuggingFace implementation here:
196
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
197
+ """
198
+
199
+ @staticmethod
200
+ def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
201
+ """
202
+ q size: (bsz, n_q_head, seq_len, head_dim)
203
+ k size: (bsz, n_kv_head, seq_len, head_dim)
204
+ cos size: (3, bsz, seq_len, head_dim)
205
+ sin size: (3, bsz, seq_len, head_dim)
206
+ """
207
+ q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
208
+ ctx.save_for_backward(cos, sin)
209
+ ctx.mrope_section = mrope_section
210
+ return q, k
211
+
212
+ @staticmethod
+ def backward(ctx, dq, dk):
213
+ """
214
+ dq size: (bsz, n_q_head, seq_len, head_dim)
215
+ dk size: (bsz, n_kv_head, seq_len, head_dim)
216
+ cos size: (3, bsz, seq_len, head_dim)
217
+ sin size: (3, bsz, seq_len, head_dim)
218
+ """
219
+ cos, sin = ctx.saved_tensors
220
+ mrope_section = ctx.mrope_section
221
+ dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
222
+ return dq, dk, None, None, None, None
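
A minimal usage sketch of `LigerQwen2VLMRopeFunction` (not part of the upload). A CUDA device with Triton is assumed; the head counts, head_dim, and the `mrope_section` split are illustrative, with the section sizes summing to `head_dim // 2` as the kernel expects.

```python
import torch

# Assumed import path; adjust to your packaging of liger_kernels.
from liger_kernels import LigerQwen2VLMRopeFunction

bsz, n_q_head, n_kv_head, seq_len, head_dim = 1, 8, 2, 32, 128
q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda", requires_grad=True)
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda", requires_grad=True)

# dim 0 of cos/sin carries the temporal / height / width components
cos = torch.randn(3, bsz, seq_len, head_dim, device="cuda")
sin = torch.randn(3, bsz, seq_len, head_dim, device="cuda")
mrope_section = [16, 24, 24]  # sums to head_dim // 2

q_rot, k_rot = LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section)
(q_rot.sum() + k_rot.sum()).backward()
```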
torch-ext/liger_kernels/rms_norm.py ADDED
@@ -0,0 +1,365 @@
1
+ """
2
+ This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
3
+ See the original Unsloth repository at https://github.com/unslothai/unsloth.
4
+
5
+ The following line
6
+ https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/rms_norm.py#L30
7
+ is based on code from Unsloth, located at:
8
+ https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
9
+
10
+ Modifications made by Yanning Chen, 2024.
11
+ """
12
+
13
+ import math
14
+ import operator
15
+
16
+ import torch
17
+ import triton
18
+ import triton.language as tl
19
+
20
+ from utils import calculate_settings
21
+ from utils import compare_version
22
+ from utils import ensure_contiguous
23
+ from utils import torch_to_triton_dtype
24
+
25
+ if compare_version("triton", operator.ge, "3.0.0"):
26
+ try:
27
+ # typical import path with dispatch available
28
+ from triton.language.extra.libdevice import rsqrt
29
+ except ModuleNotFoundError:
30
+ # for working with NGC containers
31
+ from triton.language.extra.cuda.libdevice import rsqrt
32
+ else:
33
+ from triton.language.math import rsqrt
34
+
35
+
36
+ _CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1)
37
+ _CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
38
+ _CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)
39
+
40
+
41
+ @triton.jit
42
+ def _rms_norm_forward_kernel(
43
+ Y_ptr,
44
+ Y_row_stride,
45
+ X_ptr,
46
+ X_row_stride,
47
+ W_ptr,
48
+ W_row_stride,
49
+ RSTD_ptr,
50
+ RSTD_row_stride,
51
+ n_cols,
52
+ eps,
53
+ offset,
54
+ casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
55
+ BLOCK_SIZE: tl.constexpr,
56
+ ):
57
+ """
58
+ y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)
59
+
60
+ Reference:
61
+ 1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
62
+ 2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
63
+ 3. https://arxiv.org/pdf/1910.07467
64
+ """
65
+
66
+ row_idx = tl.program_id(0)
67
+ col_offsets = tl.arange(0, BLOCK_SIZE)
68
+ mask = col_offsets < n_cols
69
+
70
+ Y_ptr += row_idx * Y_row_stride
71
+ X_ptr += row_idx * X_row_stride
72
+ RSTD_ptr += row_idx * RSTD_row_stride
73
+
74
+ X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
75
+ X_row_dtype = X_row.dtype
76
+ W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
77
+
78
+ # On Llama, only rstd is computed on fp32
79
+ if casting_mode == _CASTING_MODE_LLAMA:
80
+ X_row = X_row.to(tl.float32)
81
+
82
+ # Gemma computes everything on fp32, and then casts back the output to the original dtype
83
+ if casting_mode == _CASTING_MODE_GEMMA:
84
+ W_row = W_row.to(tl.float32)
85
+ X_row = X_row.to(tl.float32)
86
+
87
+ if casting_mode == _CASTING_MODE_NONE:
88
+ eps = eps.to(X_row_dtype)
89
+ offset = offset.to(X_row_dtype)
90
+
91
+ mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
92
+ rstd = rsqrt(mean_square + eps)
93
+
94
+ # Caching rstd adds minimal memory overhead because rstd is a single scalar per row,
95
+ # which is tiny compared to X_row. Reusing it in the backward pass saves
96
+ # 4 operations per row (*, sum, /, sqrt).
97
+ tl.store(RSTD_ptr, rstd)
98
+
99
+ X_row = X_row * rstd
100
+
101
+ # On Llama, the multiplication with the weight is done on the original dtype
102
+ if casting_mode == _CASTING_MODE_LLAMA:
103
+ X_row = X_row.to(X_row_dtype)
104
+
105
+ Y_row = X_row * (offset + W_row)
106
+
107
+ if casting_mode == _CASTING_MODE_GEMMA:
108
+ Y_row = Y_row.to(X_row_dtype)
109
+
110
+ tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
111
+
112
+
113
+ @triton.jit
114
+ def _rms_norm_backward_kernel(
115
+ dY_ptr,
116
+ dY_row_stride,
117
+ dX_ptr,
118
+ dX_row_stride,
119
+ X_ptr,
120
+ X_row_stride,
121
+ X_dtype: tl.constexpr,
122
+ W_ptr,
123
+ W_row_stride,
124
+ RSTD_ptr,
125
+ RSTD_row_stride,
126
+ dW_ptr,
127
+ dW_row_stride,
128
+ n_rows,
129
+ n_cols,
130
+ offset,
131
+ rows_per_program: tl.constexpr,
132
+ casting_mode: tl.constexpr,
133
+ BLOCK_SIZE: tl.constexpr,
134
+ ):
135
+ """
136
+ dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
137
+ dw = sum(dy * (x / RMS)). summation over BxT dimension
138
+ """
139
+
140
+ row_block_id = tl.program_id(0)
141
+ row_start = row_block_id * rows_per_program
142
+ row_end = min((row_block_id + 1) * rows_per_program, n_rows)
143
+ col_offsets = tl.arange(0, BLOCK_SIZE)
144
+ mask = col_offsets < n_cols
145
+
146
+ dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
147
+
148
+ dY_ptr += row_start * dY_row_stride
149
+ dX_ptr += row_start * dX_row_stride
150
+
151
+ X_ptr += row_start * X_row_stride
152
+ RSTD_ptr += row_start
153
+
154
+ W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
155
+ W_row = W_row + offset
156
+
157
+ for _ in range(row_start, row_end):
158
+ dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
159
+ X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
160
+
161
+ # Get cached rms
162
+ rstd_row = tl.load(RSTD_ptr)
163
+
164
+ X_row = X_row.to(tl.float32)
165
+
166
+ # Different backward graphs for different casting modes
167
+ if casting_mode == _CASTING_MODE_LLAMA:
168
+ m = (dY_row * W_row).to(tl.float32)
169
+
170
+ elif casting_mode == _CASTING_MODE_GEMMA:
171
+ dY_row = dY_row.to(tl.float32)
172
+ m = dY_row * W_row
173
+ else:
174
+ m = dY_row * W_row
175
+
176
+ dX_row = rstd_row * m
177
+
178
+ dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
179
+
180
+ # calculate the gradient of W
181
+ if casting_mode == _CASTING_MODE_LLAMA:
182
+ dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
183
+ else:
184
+ # here X_row is already in fp32 (see previous if block)
185
+ dW_row += dY_row * (X_row * rstd_row)
186
+
187
+ tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
188
+
189
+ dY_ptr += dY_row_stride
190
+ dX_ptr += dX_row_stride
191
+ X_ptr += X_row_stride
192
+ RSTD_ptr += RSTD_row_stride
193
+
194
+ tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
195
+
196
+
197
+ _str_to_casting_mode = {
198
+ "llama": _CASTING_MODE_LLAMA.value,
199
+ "gemma": _CASTING_MODE_GEMMA.value,
200
+ "none": _CASTING_MODE_NONE.value,
201
+ }
202
+
203
+
204
+ def rms_norm_forward(X, W, eps, offset, casting_mode):
205
+ if not isinstance(casting_mode, int):
206
+ assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
207
+ casting_mode = _str_to_casting_mode[casting_mode]
208
+ else:
209
+ assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"
210
+
211
+ shape = X.shape
212
+ dim = shape[-1]
213
+ X = X.view(-1, dim)
214
+ n_rows, n_cols = X.shape
215
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
216
+
217
+ Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
218
+ # RSTD is to cache rstd for each row
219
+ # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
220
+ rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
221
+ RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
222
+
223
+ # Check constraints.
224
+ assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
225
+
226
+ # XPU-specific optimization
227
+ kernel_args = {}
228
+ if X.device.type == "xpu":
229
+ kernel_args["grf_mode"] = "large"
230
+ _rms_norm_forward_kernel[(n_rows,)](
231
+ Y,
232
+ Y.stride(0),
233
+ X,
234
+ X.stride(0),
235
+ W,
236
+ W.stride(0),
237
+ RSTD,
238
+ RSTD.stride(0),
239
+ n_cols,
240
+ eps,
241
+ offset,
242
+ casting_mode,
243
+ BLOCK_SIZE=BLOCK_SIZE,
244
+ num_warps=num_warps,
245
+ **kernel_args, # XPU-specific optimization
246
+ )
247
+ return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
248
+
249
+
250
+ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
251
+ shape = dY.shape
252
+ dim = shape[-1]
253
+ dY = dY.view(-1, dim)
254
+ n_rows, n_cols = dY.shape
255
+
256
+ sm_count = 1
257
+ if X.device.type == "cuda":
258
+ sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
259
+ elif X.device.type == "xpu":
260
+ sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
261
+
262
+ # accumulate dW in fp32 for numerical stability
263
+ _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
264
+
265
+ if n_cols > BLOCK_SIZE:
266
+ raise RuntimeError("This RMS norm doesn't support feature dim >= 64KB.")
267
+ rows_per_program = math.ceil(n_rows / sm_count)
268
+ grid = (sm_count,)
269
+
270
+ if in_place is True:
271
+ dX = dY
272
+ else:
273
+ dX = torch.zeros_like(dY)
274
+
275
+ # XPU-specific optimization
276
+ kernel_args = {}
277
+ if X.device.type == "xpu":
278
+ kernel_args["grf_mode"] = "large"
279
+
280
+ _rms_norm_backward_kernel[grid](
281
+ dY,
282
+ dY.stride(0),
283
+ dX,
284
+ dX.stride(0),
285
+ X,
286
+ X.stride(0),
287
+ torch_to_triton_dtype[X.dtype],
288
+ W,
289
+ W.stride(0),
290
+ RSTD,
291
+ RSTD.stride(0),
292
+ _dW,
293
+ _dW.stride(0),
294
+ n_rows,
295
+ n_cols,
296
+ offset,
297
+ rows_per_program,
298
+ casting_mode,
299
+ BLOCK_SIZE=BLOCK_SIZE,
300
+ num_warps=num_warps,
301
+ **kernel_args, # XPU-specific optimization
302
+ )
303
+ dX = dX.view(*shape)
304
+ dW = _dW.sum(dim=0).to(W.dtype)
305
+
306
+ return dX, dW
307
+
308
+
309
+ class LigerRMSNormFunction(torch.autograd.Function):
310
+ """
311
+ Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the
312
+ weight tensor `W`, with an optional offset and casting mode.
313
+
314
+ Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
315
+ uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
316
+ `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.
317
+
318
+ In addition, different models cast their inputs at different places during RMSNorm computation. For
319
+ example, Gemma casts everything to fp32 before starting the computation, while Llama casts only the
320
+ inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
321
+ support the following casting modes (they match HuggingFace Transformers' implementations):
322
+ - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
323
+ - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
324
+ - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
325
+
326
+ The `in_place` option controls whether dY is modified in place to store dX. It defaults to `True` to save memory, but in certain cases this produces incorrect results.
327
+ For example, Gemma2 applies two RMSNorms sequentially with a residual connection in between. The residual branch still needs dY, so dY cannot be modified in place.
328
+ Therefore, when patching RMSNorm in Gemma2, we set `in_place` to `False`.
329
+ """
330
+
331
+ @staticmethod
332
+ @ensure_contiguous
333
+ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
334
+ """
335
+ X: (B, T, H) or (BxT, H)
336
+ W: (H,)
337
+ """
338
+ Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
339
+ ctx.offset = offset
340
+ ctx.casting_mode = casting_mode
341
+ ctx.in_place = in_place
342
+ ctx.BLOCK_SIZE = BLOCK_SIZE
343
+ ctx.num_warps = num_warps
344
+ ctx.save_for_backward(X, W, RSTD)
345
+ return Y
346
+
347
+ @staticmethod
348
+ @ensure_contiguous
349
+ def backward(ctx, dY):
350
+ """
351
+ Y: (B, T, H) or (BxT, H)
352
+ """
353
+ X, W, RSTD = ctx.saved_tensors
354
+ dX, dW = rms_norm_backward(
355
+ dY,
356
+ X,
357
+ W,
358
+ RSTD,
359
+ ctx.offset,
360
+ ctx.casting_mode,
361
+ ctx.BLOCK_SIZE,
362
+ ctx.num_warps,
363
+ ctx.in_place,
364
+ )
365
+ return dX, dW, None, None, None, None
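Below is a minimal usage sketch of the RMSNorm autograd function defined above, assuming a CUDA (or XPU) device with Triton installed and the flat module layout used in this build (modules importable directly, as in the imports above). The shapes and hyperparameters are illustrative only.

import torch
from rms_norm import LigerRMSNormFunction  # flat import path, matching this build's layout

hidden_size = 4096
x = torch.randn(2, 128, hidden_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.ones(hidden_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Gemma-style RMSNorm: offset=1.0, computation in fp32; in_place=False keeps dY intact,
# which matters when the surrounding graph still needs dY (e.g. Gemma2's residual pattern).
y = LigerRMSNormFunction.apply(x, w, 1e-6, 1.0, "gemma", False)
y.sum().backward()
print(x.grad.shape, w.grad.shape)  # torch.Size([2, 128, 4096]) torch.Size([4096])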
torch-ext/liger_kernels/rope.py ADDED
@@ -0,0 +1,239 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+
6
+ @triton.jit
7
+ def _triton_rope(
8
+ q_ptr,
9
+ q_row_stride,
10
+ k_ptr,
11
+ k_row_stride,
12
+ cos,
13
+ cos_row_stride,
14
+ sin,
15
+ sin_row_stride,
16
+ sl,
17
+ bs: tl.constexpr,
18
+ cos_bs: tl.constexpr,
19
+ n_qh: tl.constexpr,
20
+ n_kh: tl.constexpr,
21
+ hd: tl.constexpr,
22
+ pad_n_qh: tl.constexpr,
23
+ pad_n_kh: tl.constexpr,
24
+ pad_hd: tl.constexpr,
25
+ BLOCK_SIZE: tl.constexpr,
26
+ BACKWARD_PASS: tl.constexpr = False,
27
+ ):
28
+ # q size: (bsz, seq_len, num_q_heads, head_dim)
29
+ # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1)
30
+ # k size: (bsz, seq_len, num_kv_heads, head_dim)
31
+ # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)
32
+
33
+ # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
34
+ # stride: (seq_len * head_dim, head_dim, 1)
35
+ pid = tl.program_id(0)
36
+
37
+ # locate start address
38
+ q_ptr = q_ptr + pid * q_row_stride
39
+ k_ptr = k_ptr + pid * k_row_stride
40
+
41
+ # ####################################################################
42
+ # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
43
+ # m of this program instance
44
+ # ####################################################################
45
+
46
+ # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
47
+ # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
48
+ # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index
49
+ # and pid % sl to get the sequence index.
50
+ # 2. We only need the left half of cos and sin matrix because the right half is just
51
+ # a clone of the left half.
52
+ batch_idx = pid // sl
53
+ cos_row_idx = pid % sl
54
+ cos = cos + tl.where(
55
+ cos_bs == 1,
56
+ cos_row_idx * cos_row_stride,
57
+ batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
58
+ )
59
+ sin = sin + tl.where(
60
+ cos_bs == 1,
61
+ cos_row_idx * sin_row_stride,
62
+ batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
63
+ )
64
+
65
+ cos_offsets = tl.arange(0, pad_hd // 2)
66
+ cos_mask = cos_offsets < hd // 2
67
+ cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
68
+ sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0)
69
+
70
+ # ####################################################################
71
+ # Load the left and right half of q and k for the current
72
+ # program instance (i.e. for the current token) separately
73
+ # ####################################################################
74
+ # left half of the head
75
+ first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
76
+ first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
77
+ first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
78
+ first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
79
+ q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
80
+ k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
81
+
82
+ # right half of the head
83
+ second_half_q_offsets = first_half_q_offsets + (hd // 2)
84
+ second_half_k_offsets = first_half_k_offsets + (hd // 2)
85
+ second_q_mask = first_q_mask
86
+ second_k_mask = first_k_mask
87
+ q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
88
+ k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
89
+
90
+ if not BACKWARD_PASS:
91
+ # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
92
+ new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
93
+ tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
94
+ new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
95
+ tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
96
+
97
+ new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
98
+ tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
99
+ new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
100
+ tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
101
+ else:
102
+ # with some math, we can get:
103
+ # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
104
+ new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
105
+ tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
106
+ new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
107
+ tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
108
+
109
+ new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
110
+ tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
111
+ new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
112
+ tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
113
+
114
+
115
+ def rope_forward(q, k, cos, sin):
116
+ # transpose it back to the physical shape because Triton looks at the physical storage
117
+ # note: q and k are non-contiguous before the transformation and become contiguous after the transpose
118
+ q = q.transpose(1, 2)
119
+ k = k.transpose(1, 2)
120
+
121
+ batch_size, seq_len, n_q_head, head_dim = q.shape
122
+ n_kv_head = k.shape[2]
123
+ pad_hd = triton.next_power_of_2(head_dim)
124
+ pad_n_q_head = triton.next_power_of_2(n_q_head)
125
+ pad_n_kv_head = triton.next_power_of_2(n_kv_head)
126
+ BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
127
+
128
+ n_row = batch_size * seq_len
129
+
130
+ # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
131
+ q = q.contiguous()
132
+ k = k.contiguous()
133
+ cos = cos.contiguous()
134
+ sin = sin.contiguous()
135
+ cos_batch_size = cos.shape[0]
136
+
137
+ _triton_rope[(n_row,)](
138
+ q,
139
+ q.stride(1),
140
+ k,
141
+ k.stride(1),
142
+ cos,
143
+ cos.stride(-2),
144
+ sin,
145
+ sin.stride(-2),
146
+ seq_len,
147
+ batch_size,
148
+ cos_batch_size,
149
+ n_q_head,
150
+ n_kv_head,
151
+ head_dim,
152
+ pad_n_q_head,
153
+ pad_n_kv_head,
154
+ pad_hd,
155
+ BLOCK_SIZE=BLOCK_SIZE,
156
+ BACKWARD_PASS=False,
157
+ )
158
+ return q.transpose(1, 2), k.transpose(1, 2), cos, sin
159
+
160
+
161
+ def rope_backward(dq, dk, cos, sin):
162
+ dq = dq.transpose(1, 2)
163
+ dk = dk.transpose(1, 2)
164
+
165
+ batch_size, seq_len, n_q_head, head_dim = dq.shape
166
+ cos_batch_size = cos.shape[0]
167
+ n_kv_head = dk.shape[2]
168
+ pad_hd = triton.next_power_of_2(head_dim)
169
+ pad_n_q_head = triton.next_power_of_2(n_q_head)
170
+ pad_n_kv_head = triton.next_power_of_2(n_kv_head)
171
+ BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
172
+
173
+ n_row = batch_size * seq_len
174
+
175
+ # ensure dq and dk are contiguous
176
+ dq = dq.contiguous()
177
+ dk = dk.contiguous()
178
+
179
+ # backward is similar to forward except for swapping a few ops
180
+ _triton_rope[(n_row,)](
181
+ dq,
182
+ dq.stride(1),
183
+ dk,
184
+ dk.stride(1),
185
+ cos,
186
+ cos.stride(-2),
187
+ sin,
188
+ sin.stride(-2),
189
+ seq_len,
190
+ batch_size,
191
+ cos_batch_size,
192
+ n_q_head,
193
+ n_kv_head,
194
+ head_dim,
195
+ pad_n_q_head,
196
+ pad_n_kv_head,
197
+ pad_hd,
198
+ BLOCK_SIZE=BLOCK_SIZE,
199
+ BACKWARD_PASS=True,
200
+ )
201
+ return dq.transpose(1, 2), dk.transpose(1, 2)
202
+
203
+
204
+ class LigerRopeFunction(torch.autograd.Function):
205
+ """
206
+ Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that
207
+ this implements the HuggingFace Llama & Mistral version, whose rotation matrix is slightly different
208
+ from the one in the original RoPE paper.
209
+
210
+ Please find the corresponding HuggingFace implementation here:
211
+ https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184
212
+
213
+ For more details about the rotation matrix used here, please refer to:
214
+ https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2
215
+ """
216
+
217
+ @staticmethod
218
+ def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
219
+ """
220
+ q size: (bsz, n_q_head, seq_len, head_dim)
221
+ k size: (bsz, n_kv_head, seq_len, head_dim)
222
+ cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
223
+ sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
224
+ """
225
+ q, k, cos, sin = rope_forward(q, k, cos, sin)
226
+ ctx.save_for_backward(cos, sin)
227
+ return q, k
228
+
229
+ @staticmethod
+ def backward(ctx, dq, dk):
230
+ """
231
+ dq size: (bsz, n_q_head, seq_len, head_dim)
232
+ dk size: (bsz, n_kv_head, seq_len, head_dim)
233
+ cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
234
+ sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
235
+ """
236
+
237
+ cos, sin = ctx.saved_tensors
238
+ dq, dk = rope_backward(dq, dk, cos, sin)
239
+ return dq, dk, None, None, None, None
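A minimal sketch of calling the RoPE function above, assuming a CUDA device and the same flat import layout. The cos/sin tensors are built the way HuggingFace rotary embeddings do, with the left half duplicated along the last dimension (the kernel only reads the left half); shapes are illustrative.

import torch
from rope import LigerRopeFunction  # flat import path, matching this build's layout

bsz, seq_len, n_q_head, n_kv_head, head_dim = 2, 16, 8, 2, 64
q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda")
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda")

# (1, seq_len, head_dim) cos/sin with duplicated halves, as in HF rotary embeddings
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device="cuda").float() / head_dim))
freqs = torch.outer(torch.arange(seq_len, device="cuda").float(), inv_freq)  # (seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                                      # (seq_len, head_dim)
cos, sin = emb.cos()[None], emb.sin()[None]

q_rot, k_rot = LigerRopeFunction.apply(q, k, cos, sin)  # same shapes as q and k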
torch-ext/liger_kernels/swiglu.py ADDED
@@ -0,0 +1,116 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from utils import calculate_settings
6
+ from utils import ensure_contiguous
7
+
8
+
9
+ @triton.jit
10
+ def silu(x):
11
+ return x * tl.sigmoid(x)
12
+
13
+
14
+ @triton.jit
15
+ def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
16
+ program_id = tl.program_id(0).to(tl.int64)
17
+
18
+ # locate start index
19
+ a_ptr += program_id * stride
20
+ b_ptr += program_id * stride
21
+ c_ptr += program_id * stride
22
+
23
+ col_offsets = tl.arange(0, BLOCK_SIZE)
24
+ mask = col_offsets < n_cols
25
+
26
+ # sigmoid requires type float32
27
+ a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
28
+ b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
29
+ c_row = silu(a_row) * b_row
30
+ tl.store(c_ptr + col_offsets, c_row, mask=mask)
31
+
32
+
33
+ @triton.jit
34
+ def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
35
+ program_id = tl.program_id(0).to(tl.int64)
36
+
37
+ # locate start index
38
+ dc_ptr += program_id * stride
39
+ a_ptr += program_id * stride
40
+ b_ptr += program_id * stride
41
+
42
+ col_offsets = tl.arange(0, BLOCK_SIZE)
43
+ mask = col_offsets < n_cols
44
+
45
+ dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)
46
+ # sigmoid requires type float32
47
+ a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
48
+ b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
49
+
50
+ # recomputation to save memory
51
+ sig_a = tl.sigmoid(a_row)
52
+ silu_a = a_row * sig_a
53
+ db_row = dc_row * silu_a
54
+ da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row
55
+
56
+ tl.store(a_ptr + col_offsets, da_row, mask=mask)
57
+ tl.store(b_ptr + col_offsets, db_row, mask=mask)
58
+
59
+
60
+ def swiglu_forward(a, b):
61
+ ori_shape = a.shape
62
+
63
+ n_cols = ori_shape[-1]
64
+ a = a.view(-1, n_cols)
65
+ b = b.view(-1, n_cols)
66
+ c = torch.empty_like(a)
67
+ n_rows = a.shape[0]
68
+
69
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
70
+
71
+ _swiglu_forward_kernel[(n_rows,)](
72
+ a,
73
+ b,
74
+ c,
75
+ c.stride(-2),
76
+ n_cols=n_cols,
77
+ BLOCK_SIZE=BLOCK_SIZE,
78
+ num_warps=num_warps,
79
+ )
80
+ return a, b, c.view(*ori_shape)
81
+
82
+
83
+ def swiglu_backward(a, b, dc):
84
+ ori_shape = dc.shape
85
+ n_cols = ori_shape[-1]
86
+ dc = dc.view(-1, n_cols)
87
+ n_rows = dc.shape[0]
88
+
89
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
90
+
91
+ _swiglu_backward_kernel[(n_rows,)](
92
+ dc,
93
+ a,
94
+ b,
95
+ dc.stride(-2),
96
+ n_cols=n_cols,
97
+ BLOCK_SIZE=BLOCK_SIZE,
98
+ num_warps=num_warps,
99
+ )
100
+ return a.view(*ori_shape), b.view(*ori_shape)
101
+
102
+
103
+ class LigerSiLUMulFunction(torch.autograd.Function):
104
+ @staticmethod
105
+ @ensure_contiguous
106
+ def forward(ctx, a, b):
107
+ a, b, c = swiglu_forward(a, b)
108
+ ctx.save_for_backward(a, b)
109
+ return c
110
+
111
+ @staticmethod
112
+ @ensure_contiguous
113
+ def backward(ctx, dc):
114
+ a, b = ctx.saved_tensors
115
+ a, b = swiglu_backward(a, b, dc)
116
+ return a, b
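A minimal sketch of wiring the fused SiLU-mul kernel above into a SwiGLU MLP block, assuming a CUDA device. The projection names (gate_proj/up_proj/down_proj) follow the common Llama-style layout and are illustrative, not part of this file.

import torch
import torch.nn as nn
from swiglu import LigerSiLUMulFunction  # flat import path, matching this build's layout

class SwiGLUMLP(nn.Module):
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        # silu(gate_proj(x)) * up_proj(x), fused into a single Triton kernel launch
        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))

mlp = SwiGLUMLP(hidden_size=1024, intermediate_size=2752).cuda()
out = mlp(torch.randn(2, 16, 1024, device="cuda"))  # (2, 16, 1024)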
torch-ext/liger_kernels/tvd.py ADDED
@@ -0,0 +1,207 @@
1
+ from typing import Literal
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ from utils import ensure_contiguous
9
+
10
+ MAX_FUSED_SIZE = 65536 // 4
11
+
12
+ REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
13
+
14
+ _REDUCTION_MODE_NONE = tl.constexpr(0)
15
+ _REDUCTION_MODE_SUM = tl.constexpr(1)
16
+ _REDUCTION_MODE_MEAN = tl.constexpr(2)
17
+ _REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
18
+
19
+ _str_to_reduction_mode = {
20
+ "none": _REDUCTION_MODE_NONE.value,
21
+ "sum": _REDUCTION_MODE_SUM.value,
22
+ "mean": _REDUCTION_MODE_MEAN.value,
23
+ "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
24
+ }
25
+
26
+
27
+ def get_num_warps(BLOCK_SIZE):
28
+ num_warps = 4
29
+ if BLOCK_SIZE >= 32768:
30
+ num_warps = 32
31
+ elif BLOCK_SIZE >= 8192:
32
+ num_warps = 16
33
+ elif BLOCK_SIZE >= 2048:
34
+ num_warps = 8
35
+
36
+ return num_warps
37
+
38
+
39
+ @triton.jit
40
+ def _tv_distance_kernel(
41
+ p_ptr,
42
+ p_stride,
43
+ q_ptr,
44
+ q_stride,
45
+ loss_ptr,
46
+ loss_stride,
47
+ grads_ptr,
48
+ grads_stride,
49
+ label_ptr,
50
+ ignore_index: tl.constexpr,
51
+ n_cols,
52
+ BLOCK_SIZE: tl.constexpr,
53
+ HAS_LABEL: tl.constexpr,
54
+ reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
55
+ ):
56
+ pid = tl.program_id(0).to(tl.int64)
57
+ p_ptr += pid * p_stride
58
+ q_ptr += pid * q_stride
59
+ loss_ptr += pid * loss_stride
60
+ grads_ptr += pid * grads_stride
61
+ label_ptr += pid
62
+
63
+ base_offsets = tl.arange(0, BLOCK_SIZE)
64
+
65
+ if HAS_LABEL:
66
+ label = tl.load(label_ptr)
67
+ if label == ignore_index:
68
+ for i in range(0, n_cols, BLOCK_SIZE):
69
+ offsets = i + base_offsets
70
+ mask = offsets < n_cols
71
+ tl.store(grads_ptr + offsets, 0.0, mask=mask)
72
+ if reduction == _REDUCTION_MODE_NONE:
73
+ tl.store(loss_ptr + offsets, 0.0, mask=mask)
74
+ return
75
+
76
+ loss_sum = 0.0
77
+ for i in range(0, n_cols, BLOCK_SIZE):
78
+ offsets = i + base_offsets
79
+ mask = offsets < n_cols
80
+
81
+ p = tl.load(p_ptr + offsets, mask=mask, other=0.0)
82
+ q = tl.load(q_ptr + offsets, mask=mask, other=0.0)
83
+
84
+ # TVD(P || Q) = 0.5 * |P - Q|
85
+ tv_loss = 0.5 * tl.abs(p - q)
86
+
87
+ grad_res = tl.where(p > q, 0.5, -0.5)
88
+
89
+ tl.store(grads_ptr + offsets, grad_res, mask=mask)
90
+
91
+ if reduction == _REDUCTION_MODE_NONE:
92
+ tl.store(loss_ptr + offsets, tv_loss, mask=mask)
93
+ else:
94
+ loss_sum += tl.sum(tv_loss, axis=0)
95
+
96
+ if reduction != _REDUCTION_MODE_NONE:
97
+ tl.store(loss_ptr, loss_sum)
98
+
99
+
100
+ def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
101
+ BT, V = p.shape
102
+
103
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
104
+ num_warps = get_num_warps(BLOCK_SIZE)
105
+
106
+ grid = (BT,)
107
+
108
+ reduction = _str_to_reduction_mode[reduction]
109
+
110
+ out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
111
+ output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
112
+ grads = torch.empty_like(p)
113
+
114
+ n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
115
+
116
+ _tv_distance_kernel[grid](
117
+ p,
118
+ p.stride(0),
119
+ q,
120
+ q.stride(0),
121
+ output_tensor,
122
+ output_tensor.stride(0),
123
+ grads,
124
+ grads.stride(0),
125
+ shift_labels if has_label else torch.empty(1, device=p.device),
126
+ ignore_index,
127
+ V,
128
+ BLOCK_SIZE=BLOCK_SIZE,
129
+ HAS_LABEL=has_label,
130
+ num_warps=num_warps,
131
+ reduction=reduction,
132
+ )
133
+
134
+ if reduction == _REDUCTION_MODE_BATCHMEAN.value:
135
+ return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
136
+ elif reduction == _REDUCTION_MODE_SUM.value:
137
+ return output_tensor.sum(dim=0), grads
138
+ elif reduction == _REDUCTION_MODE_MEAN.value:
139
+ return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
140
+ else:
141
+ return output_tensor, grads
142
+
143
+
144
+ def tvd_backward_triton(grad_output, grads):
145
+ # If the TVD loss is the last layer, grad_output is 1.0. Skip the mul then.
146
+ if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
147
+ return grads
148
+
149
+ return grads * grad_output
150
+
151
+
152
+ class LigerTVDLossFunction(torch.autograd.Function):
153
+ """
154
+ Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
155
+ """
156
+
157
+ @staticmethod
158
+ @ensure_contiguous
159
+ def forward(
160
+ ctx,
161
+ p: torch.Tensor,
162
+ q: torch.Tensor,
163
+ shift_labels: Optional[torch.Tensor] = None,
164
+ reduction: REDUCTION_LITERAL = "batchmean",
165
+ ignore_index: int = -100,
166
+ ) -> torch.Tensor:
167
+ """A forward pass for the Total Variation Distance Loss.
168
+
169
+ Args:
170
+ ctx: Torch autograd context
171
+ p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
172
+ q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
173
+ shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
174
+ reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
175
+ ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.
176
+
177
+ Returns:
178
+ torch.Tensor: The computed Total Variation Distance Loss.
179
+ """
180
+ has_label = False
181
+ if shift_labels is not None:
182
+ assert shift_labels.shape == (p.shape[0],), (
183
+ f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
184
+ )
185
+ shift_labels = shift_labels.contiguous()
186
+ has_label = True
187
+
188
+ loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
189
+ ctx.save_for_backward(grads)
190
+ return loss
191
+
192
+ @staticmethod
193
+ @ensure_contiguous
194
+ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
195
+ """A backward pass for the Total Variation Distance Loss.
196
+
197
+ Args:
198
+ ctx: Torch autograd context
199
+ grad_output (torch.Tensor): The gradient of the loss with respect to the output.
200
+
201
+ Returns:
202
+ tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs.
203
+ """
204
+ (grads,) = ctx.saved_tensors
205
+ grads = tvd_backward_triton(grad_output, grads)
206
+
207
+ return grads, None, None, None, None
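A minimal usage sketch of the TVD loss above, e.g. for pulling a student distribution toward a teacher distribution; it assumes a CUDA device, and the shapes and labels are illustrative.

import torch
from tvd import LigerTVDLossFunction  # flat import path, matching this build's layout

BT, V = 8, 32000
student = torch.softmax(torch.randn(BT, V, device="cuda"), dim=-1).requires_grad_(True)
teacher = torch.softmax(torch.randn(BT, V, device="cuda"), dim=-1)

labels = torch.randint(0, V, (BT,), device="cuda")
labels[0] = -100  # rows whose label equals ignore_index contribute nothing

loss = LigerTVDLossFunction.apply(student, teacher, labels, "batchmean", -100)
loss.backward()  # student.grad has shape (BT, V)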
torch-ext/liger_kernels/utils.py ADDED
@@ -0,0 +1,135 @@
1
+ """
2
+ This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
3
+ See the original Unsloth repository at https://github.com/unslothai/unsloth.
4
+
5
+ The following line
6
+ https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23
7
+ is based on code from Unsloth, located at:
8
+ https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
9
+
10
+ Modifications made by Yanning Chen, 2024.
11
+ """
12
+
13
+ import functools
14
+ import importlib
15
+ import operator
16
+
17
+ from typing import Callable
18
+
19
+ import torch
20
+ import triton
21
+ import triton.language as tl
22
+
23
+ from packaging.version import Version
24
+
25
+ def infer_device():
26
+ """
27
+ Get current device name based on available devices
28
+ """
29
+ if torch.cuda.is_available(): # Works for both Nvidia and AMD
30
+ return "cuda"
31
+ elif torch.xpu.is_available():
32
+ return "xpu"
33
+ else:
34
+ return "cpu"
35
+
36
+ def is_hip() -> bool:
37
+ return torch.version.hip is not None
38
+
39
+
40
+ def ensure_contiguous(fn):
41
+ @functools.wraps(fn)
42
+ def wrapper(ctx, *args, **kwargs):
43
+ def maybe_to_contiguous(x):
44
+ return x.contiguous() if isinstance(x, torch.Tensor) else x
45
+
46
+ args = [maybe_to_contiguous(arg) for arg in args]
47
+ kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()}
48
+ return fn(ctx, *args, **kwargs)
49
+
50
+ return wrapper
51
+
52
+
53
+ def calculate_settings(n):
54
+ # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
55
+
56
+ MAX_FUSED_SIZE = 65536
57
+ BLOCK_SIZE = triton.next_power_of_2(n)
58
+ if BLOCK_SIZE > MAX_FUSED_SIZE:
59
+ raise RuntimeError(
60
+ f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}."
61
+ )
62
+
63
+ num_warps = 4
64
+ if BLOCK_SIZE >= 32768:
65
+ num_warps = 32 if not is_hip() else 16
66
+ elif BLOCK_SIZE >= 8192:
67
+ num_warps = 16
68
+ elif BLOCK_SIZE >= 2048:
69
+ num_warps = 8
70
+ return BLOCK_SIZE, num_warps
71
+
72
+
73
+ def compare_version(package: str, operator: Callable, target: str):
74
+ try:
75
+ pkg = importlib.import_module(package)
76
+ except ImportError:
77
+ return False
78
+ pkg_version = Version(pkg.__version__)
79
+ return operator(pkg_version, Version(target))
80
+
81
+
82
+ def get_amp_custom_fwd_bwd() -> Callable:
83
+ device = infer_device()
84
+ if compare_version("torch", operator.ge, "2.4.0"):
85
+ return (
86
+ functools.partial(torch.amp.custom_fwd, device_type=device),
87
+ functools.partial(torch.amp.custom_bwd, device_type=device),
88
+ )
89
+ return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
90
+
91
+
92
+ amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()
93
+
94
+
95
+ torch_to_triton_dtype = {
96
+ torch.float32: tl.float32,
97
+ torch.float16: tl.float16,
98
+ torch.bfloat16: tl.bfloat16,
99
+ }
100
+
101
+
102
+ @triton.jit
103
+ def element_mul_kernel(
104
+ X_ptr,
105
+ X_stride,
106
+ grad_output_ptr,
107
+ n_cols,
108
+ BLOCK_SIZE: tl.constexpr,
109
+ ):
110
+ """
111
+ This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
112
+ The multiplication is performed in-place on the tensor pointed by X_ptr.
113
+
114
+ Parameters:
115
+ X_ptr: Pointer to the input tensor.
116
+ X_stride (int): The stride of the input tensor.
117
+ grad_output_ptr: Pointer to the gradient output value.
118
+ n_cols (int): The number of columns in the input tensor.
119
+ BLOCK_SIZE (int): The block size for Triton operations.
120
+ """
121
+
122
+ # Get the program ID and convert it to int64 to avoid overflow
123
+ program_id = tl.program_id(0).to(tl.int64)
124
+
125
+ # Locate the start index
126
+ X_ptr += program_id * X_stride
127
+
128
+ # Load the gradient output value
129
+ grad_output = tl.load(grad_output_ptr)
130
+
131
+ # Perform the element-wise multiplication
132
+ for i in range(0, n_cols, BLOCK_SIZE):
133
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
134
+ X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
135
+ tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
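A small worked example of the launch-parameter helper above; the values follow directly from the thresholds in calculate_settings (non-HIP values shown).

from utils import calculate_settings  # flat import path, matching this build's layout

BLOCK_SIZE, num_warps = calculate_settings(4096)   # -> (4096, 8): already a power of 2, >= 2048
BLOCK_SIZE, num_warps = calculate_settings(11008)  # -> (16384, 16): rounded up to the next power of 2, >= 8192
# calculate_settings(100000) raises RuntimeError: 131072 exceeds MAX_FUSED_SIZE (65536)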