medmekk (HF Staff) committed
Commit a210373 · verified · 1 Parent(s): 295a3b4

Upload folder using huggingface_hub

README.md ADDED
@@ -0,0 +1,6 @@
+ # Unsloth Kernels
+
+ Unsloth Kernels is a collection of kernels for the Unsloth project.
+
+ ## Installation
+
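Below is a minimal usage sketch (an editor's addition, not part of the commit). It assumes the wheel built from this repo is importable as `unsloth_kernels` (the name in build.toml), that the functions listed in `torch-ext/unsloth_kernels/__init__.py` are exported, and that a CUDA device is available for the Triton kernels:

```python
# Hypothetical usage sketch -- the import path and device assumptions may differ.
import torch
from unsloth_kernels import fast_cross_entropy_loss

batch, seq_len, vocab = 2, 8, 32000
logits = torch.randn(batch, seq_len, vocab, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab, (batch, seq_len), device="cuda")
labels[:, 0] = -100  # -100 marks ignored (padding) positions

loss = fast_cross_entropy_loss(logits, labels)  # mean loss over non-ignored tokens
loss.backward()                                 # gradient is written back into `logits`
```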
build.toml ADDED
@@ -0,0 +1,5 @@
+ [general]
+ name = "unsloth_kernels"
+
+ [torch]
+ universal = true
flake.nix ADDED
@@ -0,0 +1,17 @@
+ {
+   description = "Flake for unsloth_kernels";
+
+   inputs = {
+     kernel-builder.url = "path:../..";
+   };
+
+   outputs =
+     {
+       self,
+       kernel-builder,
+     }:
+     kernel-builder.lib.genFlakeOutputs {
+       path = ./.;
+       rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+     };
+ }
torch-ext/unsloth_kernels/__init__.py ADDED
@@ -0,0 +1,23 @@
+ from .cross_entropy_loss import fast_cross_entropy_loss
+ from .fast_lora import fast_lora_forward
+ from .flex_attention import slow_inference_attention_softcapping
+ from .layernorm import fast_layernorm
+ from .rope_embedding import inplace_rope_embedding, fast_rope_embedding
+ from .rms_layernorm import fast_rms_layernorm
+ from .swiglu import swiglu_fg_kernel
+ from .geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel, geglu_exact_forward_kernel, geglu_exact_backward_kernel
+
+ __all__ = [
+     "fast_cross_entropy_loss",
+     "fast_lora_forward",
+     "slow_inference_attention_softcapping",
+     "fast_layernorm",
+     "inplace_rope_embedding",
+     "fast_rms_layernorm",
+     "swiglu_fg_kernel",
+     "geglu_approx_forward_kernel",
+     "geglu_approx_backward_kernel",
+     "geglu_exact_forward_kernel",
+     "geglu_exact_backward_kernel",
+     "fast_rope_embedding",
+ ]
torch-ext/unsloth_kernels/cross_entropy_loss.py ADDED
@@ -0,0 +1,420 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import triton
16
+ import triton.language as tl
17
+ import torch
18
+ from .utils import (
19
+ calculate_settings,
20
+ MAX_FUSED_SIZE,
21
+ triton_tanh,
22
+ triton_cast,
23
+ torch_cuda_device,
24
+ )
25
+ from transformers.models.llama.modeling_llama import logger
26
+ from packaging.version import Version
27
+
28
+ from unsloth_zoo.loss_utils import (
29
+ patch_loss_functions as _patch_loss_functions,
30
+ post_patch_loss_function,
31
+ )
32
+
33
+
34
+ def _cross_entropy_forward(
35
+ logits_ptr ,
36
+ logits_row_stride ,
37
+ loss_ptr ,
38
+ logsumexp_ptr ,
39
+ labels_ptr ,
40
+ VOCAB_SIZE : tl.constexpr,
41
+ BLOCK_SIZE : tl.constexpr,
42
+ DO_SOFTCAPPING : tl.constexpr,
43
+ SOFTCAP : tl.constexpr,
44
+ DO_LOGIT_SCALING : tl.constexpr,
45
+ LOGIT_SCALE : tl.constexpr,
46
+ ):
47
+ """
48
+ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
49
+ Pi = exp(xi) / sum(exp(xi))
50
+ CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
51
+ = -y [ x - log[sum(exp(x))] ]
52
+ = y * (log[sum(exp(x))] - x)
53
+ If y == 0: CE_i = 0
54
+ If y == 1: CE_i = logsumexp - x
55
+
56
+ logsumexp is also stable
57
+ Take y = log[sum(exp(x))]
58
+ exp(y) = sum(exp(x))
59
+ exp(y) = sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x
60
+ exp(y) = exp(c)*sum(exp(x - c))
61
+ y = log(exp(c)*sum(exp(x - c)))
62
+ y = c + log[sum(exp(x - c))]
63
+ This means we can set c = max(x) to make sure
64
+ exp(x - c) always is exp(x - max(x)).
65
+ This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.
66
+ """
67
+ row_idx = tl.program_id(0)
68
+ logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
69
+ loss_ptr += row_idx
70
+ logsumexp_ptr += row_idx
71
+ labels_ptr += row_idx
72
+
73
+ col_offsets = tl.arange(0, BLOCK_SIZE)
74
+ mask = col_offsets < VOCAB_SIZE
75
+
76
+ label_idx = tl.load(labels_ptr).to(tl.int32)
77
+ logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
78
+
79
+ # Go logit scaling for Cohere: t * x
80
+ if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
81
+ # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
82
+ if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
83
+
84
+ c = tl.max(logits, 0)
85
+ logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
86
+
87
+ if label_idx != -100:
88
+ x = tl.load(logits_ptr + label_idx).to(tl.float32)
89
+ # Go logit scaling for Cohere: t * x
90
+ if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
91
+ # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
92
+ if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
93
+ loss = logsumexp - x
94
+ else:
95
+ loss = 0.0
96
+ tl.store(logsumexp_ptr, logsumexp)
97
+ tl.store(loss_ptr, loss)
98
+ pass
99
+ _cross_entropy_forward = triton.jit(_cross_entropy_forward)
100
+ _cross_entropy_forward = triton.heuristics(
101
+ {
102
+ "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
103
+ "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
104
+ }
105
+ )(_cross_entropy_forward)
106
+
107
+
108
+ def _chunked_cross_entropy_forward(
109
+ logits_ptr ,
110
+ logits_row_stride ,
111
+ loss_ptr ,
112
+ logsumexp_ptr ,
113
+ labels_ptr ,
114
+ VOCAB_SIZE : tl.constexpr,
115
+ N_CHUNKS : tl.constexpr,
116
+ BLOCK_SIZE : tl.constexpr,
117
+ DO_SOFTCAPPING : tl.constexpr,
118
+ SOFTCAP : tl.constexpr,
119
+ DO_LOGIT_SCALING : tl.constexpr,
120
+ LOGIT_SCALE : tl.constexpr,
121
+ ):
122
+ """
123
+ 256K vocab divided in 4 chunks
124
+
125
+ |-65536-| |-65536-| |-65536-| |-65536-|
126
+ |-------| |-------| |-------| |-------|
127
+ |-------| |-------| |-------| |-------|
128
+
129
+ If y == 0: CE_i = 0
130
+ If y == 1: CE_i = logsumexp - x
131
+
132
+ Notice we can do logsumexp for each chunk and then
133
+ logsumexp[chunk_sum(logsumexp)] == logsumexp
134
+
135
+ chunk_sum = log[chunk_sum(logsumexp)]
136
+ = log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]
137
+ = log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]
138
+ = log[sum(exp(a)) + ... + sum(exp(z))]
139
+ = logsumexp(x)
140
+
141
+ This means we can perform a logsumexp for each chunk, then do a
142
+ final logsumexp reduction!
143
+
144
+ Ie do: logsumexp(chunked_logsumexp) - x
145
+ """
146
+ row_idx = tl.program_id(0)
147
+ chunk_idx = tl.program_id(1)
148
+ logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
149
+ loss_ptr += row_idx
150
+ logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
151
+ labels_ptr += row_idx
152
+
153
+ col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
154
+ mask = col_offsets < VOCAB_SIZE
155
+
156
+ label_idx = tl.load(labels_ptr).to(tl.int32)
157
+ logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
158
+
159
+ # Go logit scaling for Cohere: t * x
160
+ if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
161
+ # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
162
+ if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
163
+
164
+ c = tl.max(logits, 0)
165
+ logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
166
+
167
+ if chunk_idx == 0:
168
+ # logsumexp(chunked_logsumexp) - x
169
+ # Do the -x separately
170
+ if label_idx != -100:
171
+ x = tl.load(logits_ptr + label_idx).to(tl.float32)
172
+ # Go logit scaling for Cohere: t * x
173
+ if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
174
+ # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
175
+ if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
176
+ loss = -1.0 * x
177
+ else:
178
+ loss = 0.0
179
+ tl.store(loss_ptr, loss)
180
+ pass
181
+ tl.store(logsumexp_ptr, logsumexp)
182
+ pass
183
+ _chunked_cross_entropy_forward = triton.jit(_chunked_cross_entropy_forward)
184
+ _chunked_cross_entropy_forward = triton.heuristics(
185
+ {
186
+ "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
187
+ "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
188
+ }
189
+ )(_chunked_cross_entropy_forward)
190
+
191
+
192
+ def _cross_entropy_backward(
193
+ logits_ptr ,
194
+ logits_row_stride ,
195
+ dloss_ptr ,
196
+ dloss_row_stride ,
197
+ logsumexp_ptr ,
198
+ labels_ptr ,
199
+ VOCAB_SIZE : tl.constexpr,
200
+ BLOCK_SIZE : tl.constexpr,
201
+ DO_SOFTCAPPING : tl.constexpr,
202
+ SOFTCAP : tl.constexpr,
203
+ DO_LOGIT_SCALING : tl.constexpr,
204
+ LOGIT_SCALE : tl.constexpr,
205
+ ):
206
+ """
207
+ CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
208
+ dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
209
+
210
+ From https://en.wikipedia.org/wiki/LogSumExp
211
+ d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
212
+
213
+ dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
214
+ dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
215
+ dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
216
+
217
+ If y == 0: dC/dx = 0
218
+ If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
219
+ If y == 1 and x != label: dC/dx = exp[x - logsumexp]
220
+ """
221
+ row_idx = tl.program_id(0)
222
+ block_idx = tl.program_id(1)
223
+
224
+ logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
225
+ dloss_ptr += row_idx * dloss_row_stride
226
+ col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
227
+ mask = col_offsets < VOCAB_SIZE
228
+ label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
229
+
230
+ if label_idx != -100:
231
+ dloss = tl.load(dloss_ptr)
232
+ else:
233
+ dloss = 0.0
234
+
235
+ x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
236
+
237
+ # Do logit scaling for Cohere
238
+ if DO_LOGIT_SCALING:
239
+ # d/dx [s * x] = s
240
+ x = x * LOGIT_SCALE
241
+ pass
242
+
243
+ # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
244
+ partial = x
245
+ if DO_SOFTCAPPING:
246
+ # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
247
+ partial = triton_tanh(x / SOFTCAP)
248
+ x = SOFTCAP * partial
249
+ pass
250
+
251
+ logsumexp = tl.load(logsumexp_ptr + row_idx)
252
+ y = tl.exp(x - logsumexp)
253
+ y = tl.where(
254
+ col_offsets == label_idx,
255
+ y - 1.0, # exp(x - logsumexp) - 1
256
+ y, # exp(x - logsumexp)
257
+ )
258
+
259
+ if DO_LOGIT_SCALING:
260
+ # d/dx [s * x] = s
261
+ y = y * LOGIT_SCALE
262
+ pass
263
+
264
+ if DO_SOFTCAPPING:
265
+ # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
266
+ y = y * (1.0 - partial*partial)
267
+ pass
268
+
269
+ # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
270
+ tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
271
+ pass
272
+ _cross_entropy_backward = triton.jit(_cross_entropy_backward)
273
+ _cross_entropy_backward = triton.heuristics(
274
+ {
275
+ "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
276
+ "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
277
+ }
278
+ )(_cross_entropy_backward)
279
+
280
+
281
+ MAX_FUSED_SIZE = 65536 # 2**16
282
+ class Fast_CrossEntropyLoss(torch.autograd.Function):
283
+ @staticmethod
284
+ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : float = 0):
285
+ n_rows : int
286
+ vocab_size : int
287
+ n_rows, vocab_size = logits.shape
288
+ device = logits.device
289
+
290
+ div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
291
+ n_chunks : int = div + (mod != 0)
292
+ losses = torch.empty(n_rows, dtype = torch.float32, device = device)
293
+
294
+ DO_SOFTCAPPING : bool = bool(logit_softcapping != 0)
295
+ DO_LOGIT_SCALING : bool = bool(logit_scaling != 0)
296
+
297
+ BLOCK_SIZE : int
298
+ num_warps : int
299
+ if n_chunks == 1:
300
+ # For small vocabs <= 65536 like Llama, Mistral
301
+ BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
302
+ logsumexp = torch.empty(n_rows, dtype = torch.float32, device = device)
303
+
304
+ with torch_cuda_device(device):
305
+ _cross_entropy_forward[(n_rows,)](
306
+ logits, logits.stride(0),
307
+ losses,
308
+ logsumexp,
309
+ labels,
310
+ VOCAB_SIZE = vocab_size,
311
+ BLOCK_SIZE = BLOCK_SIZE,
312
+ DO_SOFTCAPPING = DO_SOFTCAPPING,
313
+ SOFTCAP = logit_softcapping,
314
+ DO_LOGIT_SCALING = DO_LOGIT_SCALING,
315
+ LOGIT_SCALE = logit_scaling,
316
+ num_warps = num_warps,
317
+ )
318
+ else:
319
+ # For large vocabs > 65536 like Gemma 256K
320
+ logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = device)
321
+
322
+ with torch_cuda_device(device):
323
+ _chunked_cross_entropy_forward[(n_rows, n_chunks,)](
324
+ logits, logits.stride(0),
325
+ losses,
326
+ logsumexp,
327
+ labels,
328
+ VOCAB_SIZE = vocab_size,
329
+ N_CHUNKS = n_chunks,
330
+ BLOCK_SIZE = MAX_FUSED_SIZE,
331
+ DO_SOFTCAPPING = DO_SOFTCAPPING,
332
+ SOFTCAP = logit_softcapping,
333
+ DO_LOGIT_SCALING = DO_LOGIT_SCALING,
334
+ LOGIT_SCALE = logit_scaling,
335
+ num_warps = 32,
336
+ )
337
+ # logsumexp(chunked_logsumexp) - x
338
+ # Do the -x separately
339
+ logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum
340
+ losses += logsumexp
341
+ losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out!
342
+ pass
343
+
344
+ ctx.save_for_backward(logits, logsumexp, labels)
345
+ ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
346
+ ctx.logit_softcapping = logit_softcapping
347
+ ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING
348
+ ctx.logit_scaling = logit_scaling
349
+ return losses
350
+ pass
351
+
352
+
353
+ @staticmethod
354
+ def backward(ctx, dlosses):
355
+ logits, logsumexp, labels = ctx.saved_tensors
356
+ n_rows : int
357
+ vocab_size : int
358
+ n_rows, vocab_size = logits.shape
359
+
360
+ BLOCK_SIZE : int = 4096
361
+ div : int
362
+ mod : int
363
+ div, mod = divmod(vocab_size, BLOCK_SIZE)
364
+ n_blocks : int = div + (mod != 0)
365
+
366
+ with torch_cuda_device(dlosses.device):
367
+ _cross_entropy_backward[(n_rows, n_blocks,)](
368
+ logits, logits.stride(0),
369
+ dlosses, dlosses.stride(0),
370
+ logsumexp,
371
+ labels,
372
+ VOCAB_SIZE = vocab_size,
373
+ BLOCK_SIZE = BLOCK_SIZE,
374
+ DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
375
+ SOFTCAP = ctx.logit_softcapping,
376
+ DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING,
377
+ LOGIT_SCALE = ctx.logit_scaling,
378
+ num_warps = 8,
379
+ )
380
+ return logits, None, None, None,
381
+ pass
382
+ pass
383
+
384
+
385
+ def fast_cross_entropy_loss(
386
+ logits,
387
+ labels,
388
+ logit_softcapping = 0,
389
+ logit_scaling = 0,
390
+ n_items = None,
391
+ ):
392
+ """
393
+ Arguments:
394
+ logits: (batch, seq_len, vocab_size)
395
+ labels: (batch, seq_len,)
396
+ Returns:
397
+ losses: float
398
+ """
399
+ batch, seq_len, d = logits.shape
400
+ assert(labels.shape == (batch, seq_len))
401
+
402
+ loss = Fast_CrossEntropyLoss.apply(
403
+ logits.view(batch*seq_len, d),
404
+ labels.view(-1),
405
+ logit_softcapping,
406
+ logit_scaling,
407
+ )
408
+ if n_items is None:
409
+ n_items = torch.count_nonzero(labels != -100)
410
+ return loss.sum() / n_items
411
+ pass
412
+ if (Version(torch.__version__) < Version("2.4.0")) and \
413
+ not hasattr(fast_cross_entropy_loss, "__wrapped__"):
414
+ fast_cross_entropy_loss = torch._disable_dynamo(fast_cross_entropy_loss)
415
+ pass
416
+
417
+ # Patch CE Losses in transformers
418
+ def patch_loss_functions(torch_compile = True):
419
+ _patch_loss_functions(fast_cross_entropy_loss, torch_compile = torch_compile)
420
+ pass
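The chunked kernel above leans on the fact that a logsumexp of per-chunk logsumexps equals the logsumexp of the whole row, so each 65536-wide chunk can be reduced independently and combined with `torch.logsumexp(logsumexp, dim = 1)`. A quick CPU check of that identity (an editor's sketch, not part of the committed file):

```python
import torch

x = torch.randn(200_000)                  # one logit row, larger than MAX_FUSED_SIZE
chunks = x.split(65536)                   # chunk like the kernel does; last chunk is ragged
chunked = torch.stack([torch.logsumexp(c, dim=0) for c in chunks])

# logsumexp of the chunk-wise logsumexps == logsumexp of the full row
assert torch.allclose(torch.logsumexp(chunked, dim=0), torch.logsumexp(x, dim=0), atol=1e-5)
```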
torch-ext/unsloth_kernels/fast_lora.py ADDED
@@ -0,0 +1,537 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from .utils import (
17
+ fast_dequantize,
18
+ QUANT_STATE,
19
+ get_lora_parameters,
20
+ get_lora_parameters_bias,
21
+ matmul_lora,
22
+ torch_amp_custom_fwd,
23
+ torch_amp_custom_bwd,
24
+ )
25
+
26
+
27
+ class LoRA_MLP(torch.autograd.Function):
28
+ """
29
+ ### LoRA weights
30
+ G = G + Ag @ Bg
31
+ U = U + Au @ Bu
32
+ W = W + Aw @ Bw
33
+
34
+ ### SwiGLU(X)
35
+ e = X @ G
36
+ f = e * sigmoid(e)
37
+ g = X @ U
38
+ h = f * g
39
+ i = h @ W
40
+
41
+ ### Backpropagation chain rule
42
+ See our blog post for more details
43
+
44
+ df = sigmoid(e) * (1 - f) + f
45
+ dC/dW = h.T @ dY
46
+ dC/dU = X.T @ (D @ W.T * f)
47
+ dC/dG = X.T @ (D @ W.T * df * g)
48
+
49
+ ### Down projection LoRA weights
50
+ dC/dAw = dC/dW @ B.T
51
+ dC/dBw = A.T @ dC/dW
52
+ dC/dAw = h.T @ dY @ B.T
53
+ dC/dBw = A.T @ h.T @ dY
54
+
55
+ ### Up projection LoRA weights
56
+ dC/dAu = X.T @ (D @ W.T * f) @ B.T
57
+ dC/dBu = A.T @ X.T @ (D @ W.T * f)
58
+
59
+ ### Gate projection LoRA weights
60
+ dC/dAg = X.T @ (D @ W.T * df * g) @ B.T
61
+ dC/dBg = A.T @ X.T @ (D @ W.T * df * g)
62
+
63
+ Don't forget to see our blog post for more details!
64
+ """
65
+ @staticmethod
66
+ @torch_amp_custom_fwd
67
+ def forward(ctx, X : torch.Tensor,
68
+ gateW, gateW_quant, gateA, gateB, gateS,
69
+ upW, upW_quant, upA, upB, upS,
70
+ downW, downW_quant, downA, downB, downS,
71
+ _forward_function, _backward_function,
72
+ inplace = True,):
73
+ dtype = X.dtype
74
+
75
+ e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
76
+ g = matmul_lora(X, upW, upW_quant, upA, upB, upS)
77
+ h = _forward_function(e, g)
78
+ i = matmul_lora(h, downW, downW_quant, downA, downB, downS)
79
+
80
+ ctx.custom_saved_tensors = (
81
+ gateW, gateW_quant, gateS,
82
+ upW, upW_quant, upS,
83
+ downW, downW_quant, downS,
84
+ _backward_function,
85
+ )
86
+ ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB,
87
+ X, e, g)
88
+ ctx.inplace = inplace
89
+ return i
90
+ pass
91
+
92
+
93
+ @staticmethod
94
+ @torch_amp_custom_bwd
95
+ def backward(ctx, dY : torch.Tensor):
96
+ gateW, gateW_quant, gateS, upW, upW_quant, upS, downW, downW_quant, downS, \
97
+ _backward_function = ctx.custom_saved_tensors
98
+ gateA, gateB, upA, upB, downA, downB, \
99
+ X, e, g = ctx.saved_tensors
100
+
101
+ batch, seq_len, hd = X.shape
102
+ dY = dY.view(-1, dY.shape[-1])
103
+ X = X .view(-1, X .shape[-1])
104
+ e = e .view(-1, e .shape[-1])
105
+ g = g .view(-1, g .shape[-1])
106
+ dtype = X.dtype
107
+
108
+ gateA, gateB, upA, upB, downA, downB = \
109
+ gateA.to(dtype), gateB.to(dtype), upA.to(dtype), upB.to(dtype), downA.to(dtype), downB.to(dtype)
110
+
111
+ gateA, gateB, upA, upB, downA, downB = \
112
+ gateA.t(), gateB.t(), upA.t(), upB.t(), downA.t(), downB.t()
113
+
114
+ DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
115
+ DW, e, g = _backward_function(DW, e, g)
116
+ h, df, de = DW, e, g
117
+
118
+ d_downA = torch.empty_like(downA)
119
+ d_downB = torch.empty_like(downB)
120
+ d_gateA = torch.empty_like(gateA)
121
+ d_gateB = torch.empty_like(gateB)
122
+ d_upA = torch.empty_like(upA)
123
+ d_upB = torch.empty_like(upB)
124
+
125
+ # Down projection LoRA weights
126
+ # d_downA = h.t() @ (dY @ downB.t())
127
+ # d_downB = (downA.t() @ h.t()) @ dY
128
+ # d_downA *= downS
129
+ # d_downB *= downS
130
+ d_downA.addmm_(h.t(), dY @ downB.t(), alpha = downS, beta = 0)
131
+ d_downB.addmm_(downA.t() @ h.t(), dY, alpha = downS, beta = 0)
132
+
133
+ # Up projection LoRA weights
134
+ # d_upA = X.t() @ (df @ upB.t())
135
+ # d_upB = (upA.t() @ X.t()) @ df
136
+ # d_upA *= upS
137
+ # d_upB *= upS
138
+ d_upA.addmm_(X.t(), df @ upB.t(), alpha = upS, beta = 0)
139
+ d_upB.addmm_(upA.t() @ X.t(), df, alpha = upS, beta = 0)
140
+
141
+ # Gate projection LoRA weights
142
+ # d_gateA = X.t() @ (de @ gateB.t())
143
+ # d_gateB = (gateA.t() @ X.t()) @ de
144
+ # d_gateA *= gateS
145
+ # d_gateB *= gateS
146
+ d_gateA.addmm_(X.t(), de @ gateB.t(), alpha = gateS, beta = 0)
147
+ d_gateB.addmm_(gateA.t() @ X.t(), de, alpha = gateS, beta = 0)
148
+
149
+ # dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
150
+ # dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
151
+ upW = fast_dequantize(upW.t(), upW_quant)
152
+ dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None)
153
+ del upW
154
+ # dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
155
+ dX.addmm_(df @ upB.t(), upA.t(), alpha = upS)
156
+
157
+ gateW = fast_dequantize(gateW.t(), gateW_quant)
158
+ # dX += de @ gateW.t()
159
+ dX.addmm_(de, gateW.t())
160
+ del gateW
161
+ # dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
162
+ dX.addmm_(de @ gateB.t(), gateA.t(), alpha = gateS)
163
+
164
+ # gateW, gateW_quant, gateA, gateB, gateS,
165
+ # upW, upW_quant, upA, upB, upS,
166
+ # downW, downW_quant, downA, downB, downS,
167
+ return dX.view(batch, seq_len, hd), \
168
+ None, None, d_gateA.t(), d_gateB.t(), None, \
169
+ None, None, d_upA.t(), d_upB.t(), None, \
170
+ None, None, d_downA.t(), d_downB.t(), None, \
171
+ None, None, None, # _backward and _forward and inplace
172
+ pass
173
+ pass
174
+
175
+
176
+ from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
177
+ def apply_lora_mlp_swiglu(self, X, inplace = True):
178
+ gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
179
+ upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
180
+ downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
181
+ out = LoRA_MLP.apply(X,
182
+ gateW, gateW_quant, gateA, gateB, gateS,
183
+ upW, upW_quant, upA, upB, upS,
184
+ downW, downW_quant, downA, downB, downS,
185
+ swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel,
186
+ inplace,)
187
+ return out
188
+ pass
189
+
190
+
191
+ from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
192
+ def apply_lora_mlp_geglu_exact(self, X, inplace = True):
193
+ gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
194
+ upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
195
+ downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
196
+ out = LoRA_MLP.apply(X,
197
+ gateW, gateW_quant, gateA, gateB, gateS,
198
+ upW, upW_quant, upA, upB, upS,
199
+ downW, downW_quant, downA, downB, downS,
200
+ geglu_exact_forward_kernel, geglu_exact_backward_kernel,
201
+ inplace,)
202
+ return out
203
+ pass
204
+
205
+
206
+ from .geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel
207
+ def apply_lora_mlp_geglu_approx(self, X):
208
+ gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
209
+ upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
210
+ downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
211
+ out = LoRA_MLP.apply(X,
212
+ gateW, gateW_quant, gateA, gateB, gateS,
213
+ upW, upW_quant, upA, upB, upS,
214
+ downW, downW_quant, downA, downB, downS,
215
+ geglu_approx_forward_kernel, geglu_approx_backward_kernel,)
216
+ return out
217
+ pass
218
+
219
+
220
+ class LoRA_QKV(torch.autograd.Function):
221
+ """
222
+ ### LoRA weights
223
+ Wq = Wq + Aq @ Bq
224
+ Wk = Wk + Ak @ Bk
225
+ Wv = Wv + Av @ Bv
226
+ Q = X @ Wq = X @ Wq + X @ Aq @ Bq
227
+ K = X @ Wk = X @ Wk + X @ Ak @ Bk
228
+ V = X @ Wv = X @ Wv + X @ Av @ Bv
229
+
230
+ ### Backpropagation chain rule
231
+ See our blogpost for more details.
232
+
233
+ dC/dWq = X.T @ D(Wq)
234
+ dC/dWk = X.T @ D(Wk)
235
+ dC/dWv = X.T @ D(Wv)
236
+ We then sum them all to find dC/dX
237
+
238
+ ### Q projection LoRA weights
239
+ dC/dAq = X.T @ D(Wq) @ B.T
240
+ dC/dBq = A.T @ X.T @ D(Wq)
241
+
242
+ ### K projection LoRA weights
243
+ dC/dAk = X.T @ D(Wk) @ B.T
244
+ dC/dBk = A.T @ X.T @ D(Wk)
245
+
246
+ ### V projection LoRA weights
247
+ dC/dAv = X.T @ D(Wv) @ B.T
248
+ dC/dBv = A.T @ X.T @ D(Wv)
249
+ """
250
+ @staticmethod
251
+ @torch_amp_custom_fwd
252
+ def forward(ctx, X : torch.Tensor,
253
+ QW, QW_quant, QA, QB, QS,
254
+ KW, KW_quant, KA, KB, KS,
255
+ VW, VW_quant, VA, VB, VS,
256
+ inplace = True):
257
+ dtype = X.dtype
258
+
259
+ Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)
260
+ K = matmul_lora(X, KW, KW_quant, KA, KB, KS)
261
+ V = matmul_lora(X, VW, VW_quant, VA, VB, VS)
262
+
263
+ ctx.custom_saved_tensors = (
264
+ QW, QW_quant, QS,
265
+ KW, KW_quant, KS,
266
+ VW, VW_quant, VS,
267
+ )
268
+ ctx.save_for_backward(X, QA, QB, KA, KB, VA, VB,)
269
+ ctx.inplace = inplace
270
+ return Q, K, V
271
+ pass
272
+
273
+ @staticmethod
274
+ @torch_amp_custom_bwd
275
+ def backward(ctx, dQ, dK, dV):
276
+ QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = \
277
+ ctx.custom_saved_tensors
278
+ X, QA, QB, KA, KB, VA, VB, = ctx.saved_tensors
279
+
280
+ batch, seq_len, hd = X.shape
281
+ dQ = dQ.view(-1, dQ.shape[-1])
282
+ dK = dK.reshape(-1, dK.shape[-1]) # view doesn't work on K.T
283
+ dV = dV.view(-1, dV.shape[-1])
284
+ X = X .view(-1, X .shape[-1])
285
+ dtype = X.dtype
286
+
287
+ QA, QB, KA, KB, VA, VB = \
288
+ QA.to(dtype), QB.to(dtype), KA.to(dtype), KB.to(dtype), VA.to(dtype), VB.to(dtype)
289
+
290
+ QA, QB, KA, KB, VA, VB = \
291
+ QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()
292
+
293
+ ### Weight projection LoRA weights
294
+ # See our blogpost for more details.
295
+ d_QA = torch.empty_like(QA)
296
+ d_QB = torch.empty_like(QB)
297
+ d_KA = torch.empty_like(KA)
298
+ d_KB = torch.empty_like(KB)
299
+ d_VA = torch.empty_like(VA)
300
+ d_VB = torch.empty_like(VB)
301
+
302
+ # Q Projection
303
+ # d_QA = X.t() @ (dQ @ QB.t())
304
+ # d_QB = (QA.t() @ X.t()) @ dQ
305
+ # d_QA *= QS
306
+ # d_QB *= QS
307
+ d_QA.addmm_(X.t(), dQ @ QB.t(), alpha = QS, beta = 0)
308
+ d_QB.addmm_(QA.t() @ X.t(), dQ, alpha = QS, beta = 0)
309
+
310
+ # K Projection
311
+ # d_KA = X.t() @ (dK @ KB.t())
312
+ # d_KB = (KA.t() @ X.t()) @ dK
313
+ # d_KA *= KS
314
+ # d_KB *= KS
315
+ d_KA.addmm_(X.t(), dK @ KB.t(), alpha = KS, beta = 0)
316
+ d_KB.addmm_(KA.t() @ X.t(), dK, alpha = KS, beta = 0)
317
+
318
+ # V Projection
319
+ # d_VA = X.t() @ (dV @ VB.t())
320
+ # d_VB = (VA.t() @ X.t()) @ dV
321
+ # d_VA *= VS
322
+ # d_VB *= VS
323
+ d_VA.addmm_(X.t(), dV @ VB.t(), alpha = VS, beta = 0)
324
+ d_VB.addmm_(VA.t() @ X.t(), dV, alpha = VS, beta = 0)
325
+
326
+ # Combine derivatives to find dX
327
+ # dQ
328
+ QW = fast_dequantize(QW.t(), QW_quant)
329
+ dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None)
330
+ del QW
331
+ # dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
332
+ dX.addmm_(dQ @ QB.t(), QA.t(), alpha = QS)
333
+
334
+ # dK
335
+ KW = fast_dequantize(KW.t(), KW_quant)
336
+ # dX += dK @ KW.t()
337
+ dX.addmm_(dK, KW.t())
338
+ del KW
339
+ # dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
340
+ dX.addmm_(dK @ KB.t(), KA.t(), alpha = KS)
341
+
342
+ # dV
343
+ VW = fast_dequantize(VW.t(), VW_quant)
344
+ # dX += dV @ VW.t()
345
+ dX.addmm_(dV, VW.t())
346
+ del VW
347
+ # dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
348
+ dX.addmm_(dV @ VB.t(), VA.t(), alpha = VS)
349
+
350
+ # QW, QW_quant, QA, QB, QS,
351
+ # KW, KW_quant, KA, KB, KS,
352
+ # VW, VW_quant, VA, VB, VS,
353
+ return dX.view(batch, seq_len, hd), \
354
+ None, None, d_QA.t(), d_QB.t(), None, \
355
+ None, None, d_KA.t(), d_KB.t(), None, \
356
+ None, None, d_VA.t(), d_VB.t(), None, \
357
+ None,
358
+ pass
359
+ pass
360
+
361
+
362
+ def apply_lora_qkv(self, X, inplace = True):
363
+ QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
364
+ KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
365
+ VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
366
+ Q, K, V = LoRA_QKV.apply(X,
367
+ QW, QW_quant, QA, QB, QS,
368
+ KW, KW_quant, KA, KB, KS,
369
+ VW, VW_quant, VA, VB, VS,
370
+ inplace,
371
+ )
372
+ return Q, K, V
373
+ pass
374
+
375
+
376
+ class LoRA_W(torch.autograd.Function):
377
+ """
378
+ ### LoRA weights
379
+ Wq = Wq + Aq @ Bq
380
+ Wk = Wk + Ak @ Bk
381
+ Wv = Wv + Av @ Bv
382
+ Q = X @ Wq = X @ Wq + X @ Aq @ Bq
383
+ K = X @ Wk = X @ Wk + X @ Ak @ Bk
384
+ V = X @ Wv = X @ Wv + X @ Av @ Bv
385
+
386
+ ### Backpropagation chain rule
387
+ dC/dWq = X.T @ D(Wq)
388
+ dC/dWk = X.T @ D(Wk)
389
+ dC/dWv = X.T @ D(Wv)
390
+
391
+ ### Q projection LoRA weights
392
+ dC/dAq = X.T @ D(Wq) @ B.T
393
+ dC/dBq = A.T @ X.T @ D(Wq)
394
+
395
+ ### K projection LoRA weights
396
+ dC/dAk = X.T @ D(Wk) @ B.T
397
+ dC/dBk = A.T @ X.T @ D(Wk)
398
+
399
+ ### V projection LoRA weights
400
+ dC/dAv = X.T @ D(Wv) @ B.T
401
+ dC/dBv = A.T @ X.T @ D(Wv)
402
+ """
403
+ @staticmethod
404
+ @torch_amp_custom_fwd
405
+ def forward(ctx, X : torch.Tensor,
406
+ W, W_quant, A, B, S):
407
+ dtype = X.dtype
408
+ XW = matmul_lora(X, W, W_quant, A, B, S)
409
+ ctx.custom_saved_tensors = (W, W_quant, S,)
410
+ ctx.save_for_backward(A, B, X)
411
+ return XW
412
+ pass
413
+
414
+ @staticmethod
415
+ @torch_amp_custom_bwd
416
+ def backward(ctx, dY : torch.Tensor):
417
+ W, W_quant, S = ctx.custom_saved_tensors
418
+ A, B, X = ctx.saved_tensors
419
+
420
+ batch, seq_len, hd = X.shape
421
+ dY = dY.reshape(-1, dY.shape[-1]) # Must be reshape
422
+ X = X .reshape(-1, X .shape[-1]) # Must be reshape
423
+ dtype = X.dtype
424
+
425
+ A, B = A.to(dtype), B.to(dtype)
426
+
427
+ A, B = A.t(), B.t()
428
+
429
+ d_A = torch.empty_like(A)
430
+ d_B = torch.empty_like(B)
431
+
432
+ ### Weight projection LoRA weights
433
+ # Weight projection
434
+ # d_A = X.t() @ (dY @ B.t())
435
+ # d_B = (A.t() @ X.t()) @ dY
436
+ # d_A *= S
437
+ # d_B *= S
438
+ d_A.addmm_(X.t(), dY @ B.t(), alpha = S, beta = 0)
439
+ d_B.addmm_(A.t() @ X.t(), dY, alpha = S, beta = 0)
440
+
441
+ # Get derivative for dX
442
+ W = fast_dequantize(W.t(), W_quant)
443
+ dX = dY @ W.t()
444
+ del W
445
+ # dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
446
+ dX.addmm_(dY @ B.t(), A.t(), alpha = S)
447
+
448
+ # W, W_quant, A, B, S
449
+ return dX.view(batch, seq_len, hd), \
450
+ None, None, d_A.t(), d_B.t(), None
451
+ pass
452
+ pass
453
+
454
+
455
+ def apply_lora_o(self, X):
456
+ OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
457
+ O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
458
+ return O
459
+ pass
460
+
461
+
462
+ IDENTITY_DROPOUT = torch.nn.Identity
463
+ @torch._disable_dynamo
464
+ def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
465
+ raise NotImplementedError(
466
+ "Unsloth: Currently not supported yet - reshaping done incorrectly"
467
+ )
468
+ self._check_forward_args(x, *args, **kwargs)
469
+ adapter_names = kwargs.pop("adapter_names", None)
470
+
471
+ if self.disable_adapters:
472
+ if self.merged:
473
+ self.unmerge()
474
+ result = self.base_layer(x, *args, **kwargs)
475
+ elif adapter_names is not None:
476
+ result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
477
+ elif self.merged:
478
+ result = self.base_layer(x, *args, **kwargs)
479
+ else:
480
+ # Fastpath
481
+ if len(self.active_adapters) == 1:
482
+ active_adapter = self.active_adapters[0]
483
+ if active_adapter not in self.lora_A.keys(): return self.base_layer(x, *args, **kwargs)
484
+
485
+ dropout = self.lora_dropout[active_adapter]
486
+ if isinstance(dropout, IDENTITY_DROPOUT) and not self.use_dora[active_adapter]:
487
+ lora_A = self.lora_A[active_adapter].weight
488
+ lora_B = self.lora_B[active_adapter].weight
489
+ scaling = self.scaling[active_adapter]
490
+ W = self.base_layer.weight
491
+ return LoRA_W.apply(x, W, QUANT_STATE(W), lora_A, lora_B, scaling)
492
+ pass
493
+ pass
494
+
495
+ result = self.base_layer(x, *args, **kwargs)
496
+ # As per Tim Dettmers, for 4bit, we need to defensively clone here.
497
+ # The reason is that in some cases, an error can occur that backprop
498
+ # does not work on a manipulated view. This issue may be solved with
499
+ # newer PyTorch versions but this would need extensive testing to be
500
+ # sure.
501
+ result = result.clone()
502
+
503
+ for active_adapter in self.active_adapters:
504
+ if active_adapter not in self.lora_A.keys():
505
+ continue
506
+ lora_A = self.lora_A[active_adapter]
507
+ lora_B = self.lora_B[active_adapter]
508
+ dropout = self.lora_dropout[active_adapter]
509
+ scaling = self.scaling[active_adapter]
510
+
511
+ requires_conversion = not torch.is_autocast_enabled()
512
+ if requires_conversion:
513
+ expected_dtype = result.dtype
514
+ x = x.to(lora_A.weight.dtype)
515
+
516
+ if not self.use_dora[active_adapter]:
517
+ result = result + lora_B(lora_A(dropout(x))) * scaling
518
+ else:
519
+ if isinstance(dropout, torch.nn.Identity) or not self.training:
520
+ base_result = result
521
+ else:
522
+ x = dropout(x)
523
+ base_result = None
524
+
525
+ result = result + self.lora_magnitude_vector[active_adapter](
526
+ x,
527
+ lora_A=lora_A,
528
+ lora_B=lora_B,
529
+ scaling=scaling,
530
+ base_layer=self.get_base_layer(),
531
+ base_result=base_result,
532
+ )
533
+ if requires_conversion:
534
+ result = result.to(expected_dtype)
535
+
536
+ return result
537
+ pass
torch-ext/unsloth_kernels/flex_attention.py ADDED
@@ -0,0 +1,181 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from functools import lru_cache
17
+ from transformers.models.llama.modeling_llama import logger
18
+ import os
19
+
20
+ torch_compile_options = {
21
+ "epilogue_fusion" : True,
22
+ "max_autotune" : True,
23
+ "shape_padding" : True,
24
+ "trace.enabled" : os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1",
25
+ "triton.cudagraphs" : False,
26
+ }
27
+
28
+ # Flex Attention supported from torch 2.5 onwards only
29
+ try:
30
+ from torch.nn.attention.flex_attention import (
31
+ flex_attention as _flex_attention,
32
+ create_block_mask as _create_block_mask,
33
+ )
34
+ _flex_attention = torch.compile(_flex_attention, dynamic = True, options = torch_compile_options)
35
+ HAS_FLEX_ATTENTION = False
36
+ except:
37
+ HAS_FLEX_ATTENTION = False
38
+ pass
39
+
40
+
41
+ if not HAS_FLEX_ATTENTION:
42
+
43
+ # Logit softcapping
44
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
45
+ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
46
+ n_heads = self.config.num_attention_heads
47
+ head_dim = self.head_dim
48
+ n_kv_heads = self.config.num_key_value_heads
49
+ n_groups = self.num_key_value_groups
50
+
51
+ # Grouped query attention
52
+ K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
53
+ V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
54
+ K = K.reshape(bsz, n_heads, q_len, head_dim)
55
+ V = V.reshape(bsz, n_heads, q_len, head_dim)
56
+
57
+ # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
58
+ # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
59
+ # We default to using the config file itself
60
+ # s = self.config.hidden_size // self.config.num_attention_heads
61
+ s = self.config.query_pre_attn_scalar
62
+ t = self.config.attn_logit_softcapping
63
+
64
+ Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
65
+ A = torch.matmul(Q, K.transpose(2, 3))
66
+ A = t * torch.tanh(A / t) # Logit softcapping
67
+ A += causal_mask[:q_len, :q_len]
68
+ # Much slower in torch compile!
69
+ # A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
70
+ A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
71
+ A = torch.matmul(A, V)
72
+ A = A.transpose(1, 2).contiguous()
73
+ A = A.reshape(bsz, q_len, n_heads*head_dim)
74
+ return A
75
+ pass
76
+
77
+ create_flex_attention_causal_mask = None
78
+ create_flex_attention_sliding_window_mask = None
79
+ else:
80
+ # See https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
81
+ # for more examples
82
+ # BSD 3-Clause License Copyright (c) 2023, Driss Guessous, Horace He et al
83
+ import functools, math
84
+
85
+ def generate_tanh_softcap(t):
86
+ def tanh_softcap(x, b, h, q_idx, kv_idx):
87
+ return t * torch.tanh(x / t)
88
+ return tanh_softcap
89
+ pass
90
+ def causal_masker(b, h, q_idx, kv_idx):
91
+ return q_idx >= kv_idx
92
+ pass
93
+
94
+ @functools.lru_cache
95
+ def sliding_window_masker(size = 4096):
96
+ def sliding_window(b, h, q_idx, kv_idx):
97
+ causal_mask = q_idx >= kv_idx
98
+ window_mask = q_idx - kv_idx <= size
99
+ return causal_mask & window_mask
100
+ return sliding_window
101
+ pass
102
+
103
+ @functools.lru_cache
104
+ def create_block_mask(mask, n = 128):
105
+ return _create_block_mask(
106
+ mask, 1, 1, n, n,
107
+ BLOCK_SIZE = 128,
108
+ _compile = True,
109
+ )
110
+ pass
111
+
112
+ def create_flex_attention_causal_mask(max_seq_length = 8192):
113
+ causal_mask = create_block_mask(causal_masker, max_seq_length)
114
+ return causal_mask
115
+ pass
116
+
117
+ def create_flex_attention_sliding_window_mask(max_seq_length = 8192, sliding_window = 4096):
118
+ sliding_masker = sliding_window_masker(sliding_window)
119
+ causal_mask = create_block_mask(sliding_masker, max_seq_length)
120
+ return causal_mask
121
+ pass
122
+
123
+ @functools.lru_cache
124
+ def flex_attention(s, t):
125
+ scale = 1.0 / math.sqrt(s)
126
+ score_mod = generate_tanh_softcap(t)
127
+ return functools.partial(
128
+ _flex_attention, score_mod = score_mod, scale = scale, enable_gqa = True,
129
+ )
130
+ pass
131
+
132
+ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
133
+ n_heads = self.config.num_attention_heads
134
+ head_dim = self.head_dim
135
+ s = self.config.query_pre_attn_scalar
136
+ t = self.config.attn_logit_softcapping
137
+ fx = flex_attention(s, t)
138
+ A = fx(query = Q, key = K, value = V, block_mask = causal_mask)
139
+ A = A.transpose(1, 2).contiguous()
140
+ A = A.reshape(bsz, q_len, n_heads*head_dim)
141
+ return A
142
+ pass
143
+ pass
144
+
145
+
146
+ torch_matmul = torch.matmul
147
+ torch_tanh = torch.tanh
148
+ torch_nn_functional_softmax = torch.nn.functional.softmax
149
+ def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
150
+ n_heads = self.config.num_attention_heads
151
+ head_dim = self.head_dim
152
+ n_kv_heads = self.config.num_key_value_heads
153
+ n_groups = self.num_key_value_groups
154
+
155
+ # Grouped query attention
156
+ K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
157
+ V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
158
+ K = K.reshape(bsz, n_heads, q_len, head_dim)
159
+ V = V.reshape(bsz, n_heads, q_len, head_dim)
160
+
161
+ # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
162
+ # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
163
+ # We default to using the config file itself
164
+ # s = self.config.hidden_size // self.config.num_attention_heads
165
+ s = self.config.query_pre_attn_scalar
166
+ t = self.config.attn_logit_softcapping
167
+
168
+ Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
169
+ A = torch_matmul(Q, K.transpose(2, 3))
170
+
171
+ # Logit softcapping
172
+ A /= t; torch_tanh(A, out = A); A *= t;
173
+ A += causal_mask[:q_len, :q_len]
174
+ # Much slower in torch compile!
175
+ # A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
176
+ A = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
177
+ A = torch_matmul(A, V)
178
+ A = A.transpose(1, 2).contiguous()
179
+ A = A.reshape(bsz, q_len, n_heads*head_dim)
180
+ return A
181
+ pass
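The expand/reshape at the top of `slow_inference_attention_softcapping` (and its compiled twin) is the standard grouped-query-attention trick: each KV head is repeated `n_groups` times along the head dimension. A small equivalence check (an editor's sketch, not part of the committed file):

```python
import torch

bsz, n_kv_heads, n_groups, q_len, head_dim = 1, 2, 4, 3, 8
K = torch.randn(bsz, n_kv_heads, q_len, head_dim)

# Same expansion the kernel file performs
K_rep = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
K_rep = K_rep.reshape(bsz, n_kv_heads * n_groups, q_len, head_dim)

# Equivalent to repeating each KV head n_groups times along dim 1
assert torch.equal(K_rep, K.repeat_interleave(n_groups, dim=1))
```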
torch-ext/unsloth_kernels/geglu.py ADDED
@@ -0,0 +1,213 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import triton
16
+ import triton.language as tl
17
+ import torch
18
+ from .utils import (
19
+ calculate_settings,
20
+ triton_tanh,
21
+ torch_cuda_device,
22
+ )
23
+
24
+
25
+ @triton.jit
26
+ def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
27
+ block_idx = tl.program_id(0)
28
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
29
+ mask = offsets < n_elements
30
+
31
+ # f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
32
+ # h = f * up
33
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
34
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
35
+
36
+ f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
37
+ f_row = f_row.to(g_row.dtype) # Exact copy from HF
38
+ h_row = f_row * g_row
39
+
40
+ # Store h
41
+ tl.store(h + offsets, h_row, mask = mask)
42
+ pass
43
+
44
+
45
+ def geglu_exact_forward_kernel(gate, up):
46
+ batch, seq_len, hd = gate.shape
47
+ n_elements = gate.numel()
48
+ device = gate.device
49
+ out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
50
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
51
+ with torch_cuda_device(device):
52
+ _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
53
+ return out
54
+ pass
55
+
56
+
57
+ @triton.jit
58
+ def _exact_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
59
+ """
60
+ f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
61
+ h = f * up
62
+
63
+ df/de (with help of Wolfram :)
64
+ df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
65
+
66
+ Reuse via
67
+ f = 1/2 * (1 + erf(1/sqrt(2) * e)) * e
68
+ """
69
+ block_idx = tl.program_id(0)
70
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
71
+ mask = offsets < n_elements
72
+
73
+ DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
74
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
75
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
76
+
77
+ # Break e_row away for re-use
78
+ # f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
79
+ f_partial_row = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
80
+ f_row = f_partial_row * e_row
81
+
82
+ f_row = f_row.to(DW_row.dtype)
83
+ # h = f * g
84
+ h_row = f_row * g_row
85
+ # df = DW * f
86
+ df_row = DW_row * f_row
87
+ # dg = DW * g
88
+ dg_row = DW_row * g_row
89
+
90
+ # df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
91
+ t = 0.3989422804014327 # 1/sqrt(2*pi)
92
+ df_de = f_partial_row + t * e_row * tl.exp(-0.5 * e_row * e_row)
93
+
94
+ de_row = dg_row.to(tl.float32) * df_de
95
+ de_row = de_row.to(DW_row.dtype)
96
+
97
+ # Store derivatives in buffers
98
+ tl.store(DW + offsets, h_row, mask = mask) # h = f * g
99
+ tl.store(e + offsets, df_row, mask = mask) # df = DW * f
100
+ tl.store(g + offsets, de_row, mask = mask) # de
101
+ pass
102
+
103
+
104
+ def geglu_exact_backward_kernel(DW, e, g):
105
+ batch_seq_len, hd = e.shape
106
+ n_elements = e.numel()
107
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
108
+ with torch_cuda_device(e.device):
109
+ _exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
110
+ return DW, e, g
111
+ pass
112
+
113
+
114
+ @triton.jit
115
+ def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
116
+ block_idx = tl.program_id(0)
117
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
118
+ mask = offsets < n_elements
119
+
120
+ # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
121
+ # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
122
+ # h = f * up
123
+ s = 0.7978845608028654 # math.sqrt(2 / math.pi)
124
+
125
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
126
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
127
+
128
+ f_row = 0.5 * e_row * (
129
+ triton_tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) \
130
+ + 1.0
131
+ )
132
+ f_row = f_row.to(g_row.dtype) # Exact copy from HF
133
+ h_row = f_row * g_row
134
+
135
+ # Store h
136
+ tl.store(h + offsets, h_row, mask = mask)
137
+ pass
138
+
139
+
140
+ def geglu_approx_forward_kernel(gate, up):
141
+ batch, seq_len, hd = gate.shape
142
+ n_elements = gate.numel()
143
+ device = gate.device
144
+ out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
145
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
146
+ with torch_cuda_device(device):
147
+ _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
148
+ return out
149
+ pass
150
+
151
+
152
+ @triton.jit
153
+ def _approx_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
154
+ """
155
+ f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
156
+ h = f * up
157
+
158
+ df/de (with help from https://arxiv.org/pdf/2305.12073.pdf :))
159
+ df/de = 1/2 * [1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )] +
160
+ 1/2 * sech^2 [ sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ] * \
161
+ ( sqrt(2/pi) * x * (1 + 0.044715 * x^2 * 3 ) )
162
+
163
+ Notice sech^2(x) = 1 - tanh^2(x)
164
+ So reuse tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )
165
+
166
+ See https://www.desmos.com/calculator/nqprfoni6x
167
+ """
168
+ block_idx = tl.program_id(0)
169
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
170
+ mask = offsets < n_elements
171
+
172
+ DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
173
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
174
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
175
+
176
+ # See https://www.desmos.com/calculator/nqprfoni6x
177
+ s = 0.7978845608028654 # math.sqrt(2 / math.pi)
178
+ a = s * e_row # a = sqrt(2 / pi) * x
179
+ b = a * 0.044715 * e_row * e_row # b = a * 0.044715 * x^2
180
+ T = 1.0 + triton_tanh(a + b)
181
+ T2 = 0.5 * T
182
+ # Q = 0.5 * -T * (T - 2.0) * (a + 3.0 * b)
183
+ Q2 = -T2 * (T - 2.0) * (a + 3.0 * b)
184
+ df_de = T2 + Q2 # 1/2 * (T + Q)
185
+
186
+ # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
187
+ f_row = T2 * e_row
188
+ f_row = f_row.to(DW_row.dtype)
189
+ # h = f * g
190
+ h_row = f_row * g_row
191
+ # df = DW * f
192
+ df_row = DW_row * f_row
193
+ # dg = DW * g
194
+ dg_row = DW_row * g_row
195
+
196
+ de_row = dg_row.to(tl.float32) * df_de
197
+ de_row = de_row.to(DW_row.dtype)
198
+
199
+ # Store derivatives in buffers
200
+ tl.store(DW + offsets, h_row, mask = mask) # h = f * g
201
+ tl.store(e + offsets, df_row, mask = mask) # df = DW * f
202
+ tl.store(g + offsets, de_row, mask = mask) # de
203
+ pass
204
+
205
+
206
+ def geglu_approx_backward_kernel(DW, e, g):
207
+ batch_seq_len, hd = e.shape
208
+ n_elements = e.numel()
209
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
210
+ with torch_cuda_device(e.device):
211
+ _approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
212
+ return DW, e, g
213
+ pass
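The `_exact_*` kernels above implement GeGLU with the erf form of GELU, and the `_approx_*` kernels use the tanh form; both match PyTorch's reference GELU. A CPU cross-check of the forward formulas (an editor's sketch, not part of the committed file; the Triton kernels themselves need a CUDA device):

```python
import torch
import torch.nn.functional as F

gate, up = torch.randn(2, 4, 16), torch.randn(2, 4, 16)

# Exact form: f = 0.5 * e * (1 + erf(e / sqrt(2))), h = f * up
exact = 0.5 * gate * (1.0 + torch.erf(gate / 2**0.5)) * up
assert torch.allclose(exact, F.gelu(gate, approximate="none") * up, atol=1e-5)

# Tanh approximation: f = 0.5 * e * (1 + tanh(sqrt(2/pi) * e * (1 + 0.044715 * e^2)))
s = 0.7978845608028654  # sqrt(2 / pi)
approx = 0.5 * gate * (1.0 + torch.tanh(s * gate * (1.0 + 0.044715 * gate * gate))) * up
assert torch.allclose(approx, F.gelu(gate, approximate="tanh") * up, atol=1e-5)
```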
torch-ext/unsloth_kernels/layernorm.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ # Copyright 2024-present Andrej Karpathy & the llm.c team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import triton
17
+ import triton.language as tl
18
+ import torch
19
+ from .utils import calculate_settings, torch_cuda_device
20
+ from unsloth_zoo.patching_utils import (
21
+ patch_layernorm,
22
+ )
23
+
24
+
25
+ @triton.jit
26
+ def layernorm_forward(
27
+ Y, Y_row_stride,
28
+ X, X_row_stride,
29
+ W,
30
+ b,
31
+ r,
32
+ mu,
33
+ n_cols : tl.constexpr,
34
+ eps : tl.constexpr,
35
+ BLOCK_SIZE : tl.constexpr
36
+ ):
37
+ row_idx = tl.program_id(0)
38
+ col_offsets = tl.arange(0, BLOCK_SIZE)
39
+ mask = col_offsets < n_cols
40
+
41
+ Y += row_idx * Y_row_stride
42
+ X += row_idx * X_row_stride
43
+ r += row_idx
44
+ mu += row_idx
45
+
46
+ # According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
47
+ # are in float32!
48
+ X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
49
+ W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
50
+ b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
51
+
52
+ mean_X = tl.sum(X_row, axis = 0) / n_cols
53
+ # (X[0] - mean) == -mean so we need to mask it out
54
+ XX = tl.where(mask, X_row - mean_X, 0)
55
+ row_var = tl.sum(XX * XX, axis = 0) / n_cols
56
+ inv_var = tl.math.rsqrt(row_var + eps)
57
+ tl.store (r, inv_var)
58
+ tl.store (mu, mean_X)
59
+ output = (XX * inv_var) * W_row + b_row
60
+ tl.store(Y + col_offsets, output, mask = mask)
61
+ pass
62
+
63
+
64
+ @triton.jit
65
+ def layernorm_backward(
66
+ dY, dY_row_stride,
67
+ X, X_row_stride,
68
+ W,
69
+ b,
70
+ r,
71
+ mu,
72
+ n_cols : tl.constexpr,
73
+ eps : tl.constexpr,
74
+ BLOCK_SIZE : tl.constexpr
75
+ ):
76
+ # Approximately follows https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
77
+ row_idx = tl.program_id(0)
78
+ col_offsets = tl.arange(0, BLOCK_SIZE)
79
+ mask = col_offsets < n_cols
80
+
81
+ dY += row_idx * dY_row_stride
82
+ X += row_idx * X_row_stride
83
+ r += row_idx
84
+ mu += row_idx
85
+
86
+ # According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
87
+ # are in float32!
88
+ dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
89
+ X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
90
+ W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
91
+ b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
92
+
93
+ inv_var = tl.load(r) .to(tl.float32)
94
+ mean = tl.load(mu).to(tl.float32)
95
+ normed = (X_row - mean) * inv_var
96
+ dY_W = dY_row * W_row
97
+ dX_row = dY_W - tl.sum(dY_W, axis = 0) / n_cols - normed * tl.sum(dY_W * normed, axis = 0) / n_cols
98
+ dX_row = dX_row * inv_var
99
+ tl.store(dY + col_offsets, dX_row, mask = mask)
100
+ pass
101
+
102
+
103
+ class Fast_Layernorm(torch.autograd.Function):
104
+ @staticmethod
105
+ def forward(ctx, X, W, b, eps):
106
+ shape = X.shape
107
+ dim = shape[-1]
108
+ X = X.view(-1, dim)
109
+ n_rows, n_cols = X.shape
110
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
111
+ device = X.device
112
+ Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)
113
+ r = torch.empty(n_rows, dtype = torch.float32, device = device)
114
+ mu = torch.empty(n_rows, dtype = torch.float32, device = device)
115
+
116
+ with torch_cuda_device(device):
117
+ layernorm_forward[(n_rows,)](
118
+ Y, Y.stride(0),
119
+ X, X.stride(0),
120
+ W,
121
+ b,
122
+ r,
123
+ mu,
124
+ n_cols, eps,
125
+ BLOCK_SIZE = BLOCK_SIZE,
126
+ num_warps = num_warps,
127
+ )
128
+ ctx.eps = eps
129
+ ctx.BLOCK_SIZE = BLOCK_SIZE
130
+ ctx.num_warps = num_warps
131
+ ctx.save_for_backward(X, W, b, r, mu)
132
+ return Y.view(*shape)
133
+ pass
134
+
135
+ @staticmethod
136
+ def backward(ctx, dY):
137
+ shape = dY.shape
138
+ dim = shape[-1]
139
+ dY = dY.view(-1, dim)
140
+ X, W, b, r, mu = ctx.saved_tensors
141
+ n_rows, n_cols = dY.shape
142
+
143
+ with torch_cuda_device(dY.device):
144
+ layernorm_backward[(n_rows,)](
145
+ dY, dY.stride(0),
146
+ X, X .stride(0),
147
+ W,
148
+ b,
149
+ r,
150
+ mu,
151
+ n_cols, ctx.eps,
152
+ BLOCK_SIZE = ctx.BLOCK_SIZE,
153
+ num_warps = ctx.num_warps,
154
+ )
155
+ dX = dY.view(*shape)
156
+ return dX, None, None, None, None
157
+ pass
158
+ pass
159
+
160
+
161
+ def fast_layernorm(layernorm, X):
162
+ assert(layernorm.elementwise_affine is True)
163
+ W = layernorm.weight
164
+ bias = layernorm.bias
165
+ eps = layernorm.variance_epsilon if \
166
+ hasattr(layernorm, "variance_epsilon") \
167
+ else layernorm.eps
168
+ out = Fast_Layernorm.apply(X, W, bias, eps)
169
+ return out
170
+ pass
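
A minimal usage sketch (editorial, not part of the committed diff): `fast_layernorm` takes an affine `torch.nn.LayerNorm` module plus its input. The shapes, dtype, and import path below are illustrative assumptions.

```python
# Hypothetical example -- assumes the package is importable as `unsloth_kernels`
# and that a CUDA device is available.
import torch
from unsloth_kernels import fast_layernorm

ln = torch.nn.LayerNorm(4096).cuda().half()      # elementwise_affine must be True
X  = torch.randn(2, 128, 4096, device="cuda", dtype=torch.float16)

Y_fast = fast_layernorm(ln, X)                   # Triton kernel path
Y_ref  = ln(X)                                   # eager PyTorch reference
print((Y_fast - Y_ref).abs().max())              # expected to agree within fp16 tolerance
```
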
torch-ext/unsloth_kernels/rms_layernorm.py ADDED
@@ -0,0 +1,261 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import triton
16
+ import triton.language as tl
17
+ import torch
18
+ from .utils import calculate_settings, torch_cuda_device
19
+
20
+ @triton.jit
21
+ def _rms_layernorm_forward(
22
+ Y, Y_row_stride,
23
+ X, X_row_stride,
24
+ W, W_row_stride,
25
+ r, r_row_stride : tl.constexpr,
26
+ n_cols : tl.constexpr,
27
+ eps : tl.constexpr,
28
+ BLOCK_SIZE : tl.constexpr,
29
+ ):
30
+ """
31
+ Fast RMS Layernorm kernel
32
+ Inspiration from a Triton tutorial:
33
+ https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
34
+ """
35
+ row_idx = tl.program_id(0)
36
+ col_offsets = tl.arange(0, BLOCK_SIZE)
37
+ mask = col_offsets < n_cols
38
+
39
+ Y += row_idx * Y_row_stride
40
+ X += row_idx * X_row_stride
41
+ r += row_idx * r_row_stride
42
+
43
+ X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
44
+ W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)
45
+
46
+ row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
47
+ inv_var = tl.math.rsqrt(row_var + eps)
48
+ tl.store(r, inv_var)
49
+ normed = X_row * inv_var
50
+ normed = normed.to(W_row.dtype) # Exact copy from HF
51
+ output = normed * W_row
52
+ tl.store(Y + col_offsets, output, mask = mask)
53
+ pass
54
+
55
+
56
+ def _rms_layernorm_backward(
57
+ dY, dY_row_stride,
58
+ dX, dX_row_stride,
59
+ X, X_row_stride,
60
+ W, W_row_stride,
61
+ r, r_row_stride : tl.constexpr,
62
+ # dW, dW_row_stride,
63
+ n_cols : tl.constexpr,
64
+ eps : tl.constexpr,
65
+ GEMMA : tl.constexpr,
66
+ BLOCK_SIZE : tl.constexpr,
67
+ ):
68
+ """
69
+ Fast RMS Layernorm kernel for the backward pass
70
+ Inspiration from a Triton tutorial:
71
+ https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
72
+ """
73
+ row_idx = tl.program_id(0)
74
+ col_offsets = tl.arange(0, BLOCK_SIZE)
75
+ mask = col_offsets < n_cols
76
+
77
+ dY += row_idx * dY_row_stride
78
+ X += row_idx * X_row_stride
79
+ r += row_idx * r_row_stride
80
+
81
+ if GEMMA: dX += row_idx * dY_row_stride
82
+ else: dX = dY
83
+
84
+ dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
85
+ X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
86
+ W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
87
+
88
+ # Get saved row variance
89
+ inv_var = tl.load(r).to(tl.float32)
90
+ normed = X_row * inv_var
91
+
92
+ if GEMMA: dY_W = dY_row * (W_row + 1.0)
93
+ else: dY_W = dY_row * W_row
94
+
95
+ rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
96
+ output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
97
+ tl.store(dX + col_offsets, output, mask = mask)
98
+ pass
99
+ _rms_layernorm_backward = triton.jit(_rms_layernorm_backward)
100
+ _rms_layernorm_backward = triton.heuristics(
101
+ {
102
+ "GEMMA": lambda args: bool(args["GEMMA"]),
103
+ }
104
+ )(_rms_layernorm_backward)
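
For reference, with $\hat{x} = x\,r$, $r = \big(\tfrac{1}{n}\sum_j x_j^2 + \epsilon\big)^{-1/2}$ and $g = \mathrm{d}y \odot w$ (or $g = \mathrm{d}y \odot (w + 1)$ when `GEMMA` is set), the gradient stored by the backward kernel is

$$\mathrm{d}x = \frac{r}{n}\Big(n\,g - \hat{x}\sum_j g_j\,\hat{x}_j\Big),$$

matching the `output` expression above.
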
105
+
106
+
107
+ @triton.jit
108
+ def _gemma_rms_layernorm_forward(
109
+ Y, Y_row_stride,
110
+ X, X_row_stride,
111
+ W, W_row_stride,
112
+ r, r_row_stride : tl.constexpr,
113
+ n_cols : tl.constexpr,
114
+ eps : tl.constexpr,
115
+ BLOCK_SIZE : tl.constexpr,
116
+ ):
117
+ # Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31
118
+ # and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33
119
+ # exactly. Essentially all in float32!
120
+ row_idx = tl.program_id(0)
121
+ col_offsets = tl.arange(0, BLOCK_SIZE)
122
+ mask = col_offsets < n_cols
123
+
124
+ Y += row_idx * Y_row_stride
125
+ X += row_idx * X_row_stride
126
+ r += row_idx * r_row_stride
127
+
128
+ X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
129
+ W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
130
+
131
+ row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
132
+ inv_var = tl.math.rsqrt(row_var + eps)
133
+ tl.store(r, inv_var)
134
+ normed = X_row * inv_var
135
+ output = normed * (W_row + 1.0)
136
+
137
+ tl.store(Y + col_offsets, output, mask = mask)
138
+ pass
139
+
140
+
141
+ class Fast_RMS_Layernorm(torch.autograd.Function):
142
+ @staticmethod
143
+ def forward(ctx, X : torch.Tensor, W : torch.Tensor, eps : float, gemma : bool = False):
144
+ shape = X.shape
145
+ dim : int = shape[-1]
146
+ X = X.view(-1, dim)
147
+ n_rows : int
148
+ n_cols : int
149
+ n_rows, n_cols = X.shape
150
+ BLOCK_SIZE : int
151
+ num_warps : int
152
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
153
+ device = X.device
154
+
155
+ Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)
156
+ r = torch.empty(n_rows, dtype = torch.float32, device = device)
157
+
158
+ fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
159
+ with torch_cuda_device(device):
160
+ fx[(n_rows,)](
161
+ Y, Y.stride(0),
162
+ X, X.stride(0),
163
+ W, W.stride(0),
164
+ r, r.stride(0),
165
+ n_cols, eps,
166
+ BLOCK_SIZE = BLOCK_SIZE,
167
+ num_warps = num_warps,
168
+ )
169
+ ctx.eps = eps
170
+ ctx.BLOCK_SIZE = BLOCK_SIZE
171
+ ctx.num_warps = num_warps
172
+ ctx.GEMMA = gemma
173
+ ctx.save_for_backward(X, W, r)
174
+ return Y.view(*shape)
175
+ pass
176
+
177
+ @staticmethod
178
+ def backward(ctx, dY : torch.Tensor):
179
+ shape = dY.shape
180
+ dim : int = shape[-1]
181
+ dY = dY.view(-1, dim)
182
+ X, W, r = ctx.saved_tensors
183
+ n_rows : int
184
+ n_cols : int
185
+ n_rows, n_cols = dY.shape
186
+ # dW = X
187
+ dX = torch.empty_like(dY) if ctx.GEMMA else dY
188
+
189
+ with torch_cuda_device(dY.device):
190
+ _rms_layernorm_backward[(n_rows,)](
191
+ dY, dY.stride(0),
192
+ dX, dX.stride(0),
193
+ X, X .stride(0),
194
+ W, W .stride(0),
195
+ r, r .stride(0),
196
+ # dW, dW.stride(0),
197
+ n_cols, ctx.eps,
198
+ GEMMA = ctx.GEMMA,
199
+ BLOCK_SIZE = ctx.BLOCK_SIZE,
200
+ num_warps = ctx.num_warps,
201
+ )
202
+ dX = dX.view(*shape)
203
+ return dX, None, None, None
204
+ pass
205
+ pass
206
+
207
+
208
+ # [TODO] Unsure why RMS Layernorm is not torch.compiling properly
209
+ @torch.compiler.disable
210
+ def fast_rms_layernorm(layernorm, X : torch.Tensor, gemma : bool = False):
211
+ W : torch.Tensor = layernorm.weight
212
+ eps : float = layernorm.variance_epsilon if \
213
+ hasattr(layernorm, "variance_epsilon") \
214
+ else layernorm.eps
215
+ out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)
216
+ return out
217
+ pass
218
+
219
+
220
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
221
+ class Unsloth_LlamaRMSNorm(LlamaRMSNorm):
222
+ def forward(self, X):
223
+ return fast_rms_layernorm(self, X, gemma = False)
224
+ pass
225
+ pass
226
+
227
+ try:
228
+ from transformers.models.mllama.modeling_mllama import MllamaTextRMSNorm
229
+ class Unsloth_MllamaTextRMSNorm(MllamaTextRMSNorm):
230
+ def forward(self, X):
231
+ return fast_rms_layernorm(self, X, gemma = False)
232
+ pass
233
+ pass
234
+ except:
235
+ pass
236
+ pass
237
+
238
+ def patch_rms_layernorm():
239
+ import transformers.models.llama.modeling_llama
240
+ transformers.models.llama.modeling_llama.LlamaRMSNorm = Unsloth_LlamaRMSNorm
241
+ try:
242
+ import transformers.models.mllama.modeling_mllama
243
+ transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = Unsloth_MllamaTextRMSNorm
244
+ except:
245
+ pass
246
+ return
247
+ pass
248
+
249
+
250
+ def unpatch_rms_layernorm():
251
+ import transformers.models.llama.modeling_llama
252
+ transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
253
+ try:
254
+ import transformers.models.mllama.modeling_mllama
255
+ transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = MllamaTextRMSNorm
256
+ except:
257
+ pass
258
+ return
259
+ pass
260
+
261
+
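
A rough usage sketch (editorial, not part of the diff): the wrapper takes an HF RMSNorm module directly, and `patch_rms_layernorm` swaps the Llama/Mllama norm classes so newly constructed models pick up the Triton path. Shapes and the import path are assumptions.

```python
# Hypothetical example -- assumes `unsloth_kernels` is importable and CUDA is available.
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from unsloth_kernels import fast_rms_layernorm

norm = LlamaRMSNorm(hidden_size=4096, eps=1e-5).cuda()
X = torch.randn(2, 128, 4096, device="cuda", dtype=torch.bfloat16)

Y = fast_rms_layernorm(norm, X)                  # Llama-style: out = normed * weight
# Gemma stores its scale as (w - 1) and normalises in float32, so pass gemma=True there:
# Y = fast_rms_layernorm(gemma_norm, X, gemma=True)
```
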
torch-ext/unsloth_kernels/rope_embedding.py ADDED
@@ -0,0 +1,202 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import triton
16
+ import triton.language as tl
17
+ import torch
18
+ from .utils import calculate_settings, torch_cuda_device
19
+ ROPE_GROUP_SIZE : int = 4
20
+
21
+ def _rope_embedding(
22
+ Q, Q_row_stride,
23
+ cos, cos_row_stride,
24
+ sin, sin_row_stride,
25
+ seqlen,
26
+ head_dim : tl.constexpr,
27
+ n_heads : tl.constexpr,
28
+ BACKWARD_PASS : tl.constexpr,
29
+ BLOCK_SIZE : tl.constexpr,
30
+ ):
31
+ """
32
+ Calculates the RoPE Embedding quickly
33
+ RoPE is Q * cos + rotate_half(Q) * sin
34
+ See our blog post for more info
35
+ """
36
+ ROPE_GROUP_SIZE = 4
37
+ row_position = tl.program_id(0)
38
+ group_head_position = tl.program_id(1)
39
+ col_offsets = tl.arange(0, BLOCK_SIZE)
40
+ half_head_dim = head_dim // 2
41
+ mask = col_offsets < half_head_dim
42
+
43
+ sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
44
+ half_head_dim*0 + col_offsets, mask = mask, other = 0)
45
+ cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
46
+ half_head_dim*0 + col_offsets, mask = mask, other = 0)
47
+
48
+ if BACKWARD_PASS:
49
+ # See our blog post for more info.
50
+ sin1 = -sin1
51
+ pass
52
+
53
+ # [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8
54
+ head_start = group_head_position * ROPE_GROUP_SIZE
55
+ head_end = min((head_start + ROPE_GROUP_SIZE), n_heads)
56
+
57
+ # 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238)
58
+ for k in range(head_start, head_end):
59
+ offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets
60
+ offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim
61
+
62
+ # For Gemma - sometimes RoPE must be done in float32 and not bfloat16
63
+ Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
64
+ Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)
65
+
66
+ tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)
67
+ tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)
68
+ pass
69
+ pass
70
+ _rope_embedding = triton.jit(_rope_embedding)
71
+ _rope_embedding = triton.heuristics(
72
+ {
73
+ "BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]),
74
+ }
75
+ )(_rope_embedding)
76
+
77
+
78
+ class Fast_RoPE_Embedding(torch.autograd.Function):
79
+ @staticmethod
80
+ def forward(ctx, Q, cos, sin):
81
+ cos, sin = cos.squeeze(), sin.squeeze()
82
+ batch : int
83
+ seq_len : int
84
+ n_heads : int
85
+ head_dim : int
86
+ batch, seq_len, n_heads, head_dim = Q.shape
87
+ Q = Q.view(batch*seq_len, n_heads*head_dim)
88
+ n_rows : int
89
+ n_cols : int
90
+ n_rows, n_cols = Q.shape
91
+ assert(seq_len <= cos.shape[0])
92
+
93
+ # [TODO] Changing blocksize to head_dim//2 seems to have
94
+ # some concurrency / un-deterministic issues.
95
+ BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)
96
+
97
+ # group_size = 4 # 4 or 8, too large group_size can hurt performance.
98
+ div : int
99
+ mod : int
100
+ div, mod = divmod(n_heads, ROPE_GROUP_SIZE)
101
+ n_groups : int = div + (mod != 0)
102
+
103
+ with torch_cuda_device(Q.device):
104
+ _rope_embedding[(n_rows, n_groups, )](
105
+ Q, Q.stride(0),
106
+ cos, cos.stride(0),
107
+ sin, sin.stride(0),
108
+ seq_len,
109
+ head_dim, n_heads,
110
+ BACKWARD_PASS = False,
111
+ BLOCK_SIZE = BLOCK_SIZE,
112
+ num_warps = num_warps,
113
+ )
114
+ ctx.BLOCK_SIZE = BLOCK_SIZE
115
+ ctx.num_warps = num_warps
116
+ ctx.n_groups = n_groups
117
+ ctx.cos = cos
118
+ ctx.sin = sin
119
+ return Q.view(batch, seq_len, n_heads, head_dim)
120
+ pass
121
+
122
+ @staticmethod
123
+ def backward(ctx, dY):
124
+ batch : int
125
+ seq_len : int
126
+ n_heads : int
127
+ head_dim : int
128
+ batch, seq_len, n_heads, head_dim = dY.shape
129
+ dY = dY.reshape(batch*seq_len, n_heads*head_dim)
130
+ # Must be reshape not view
131
+ n_rows : int
132
+ n_cols : int
133
+ n_rows, n_cols = dY.shape
134
+
135
+ cos = ctx.cos
136
+ sin = ctx.sin
137
+
138
+ with torch_cuda_device(dY.device):
139
+ _rope_embedding[(n_rows, ctx.n_groups, )](
140
+ dY, dY .stride(0),
141
+ cos, cos.stride(0),
142
+ sin, sin.stride(0),
143
+ seq_len, head_dim, n_heads,
144
+ BACKWARD_PASS = True,
145
+ BLOCK_SIZE = ctx.BLOCK_SIZE,
146
+ num_warps = ctx.num_warps,
147
+ )
148
+ dY = dY.view(batch, seq_len, n_heads, head_dim)
149
+ return dY, None, None,
150
+ pass
151
+ pass
152
+
153
+ # [TODO] Unsure why RoPE Embedding is not torch.compiling properly
154
+ @torch.compiler.disable
155
+ def fast_rope_embedding(Q, K, cos, sin):
156
+ Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)
157
+ K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2)
158
+ return Q, K
159
+ pass
160
+
161
+
162
+ class Slow_RoPE_Embedding(torch.autograd.Function):
163
+ @staticmethod
164
+ def forward(ctx, Q, cos, sin, position_ids):
165
+ if position_ids is not None:
166
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
167
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
168
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
169
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
170
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
171
+
172
+ # Q * cos + rotate_half(Q) * sin
173
+ half = Q.shape[-1]//2
174
+ RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
175
+ Q *= cos
176
+ Q.addcmul_(RH_Q, sin)
177
+ # RH_Q *= sin
178
+ # Q += RH_Q
179
+ ctx.save_for_backward(cos, sin)
180
+ return Q
181
+ pass
182
+
183
+ @staticmethod
184
+ def backward(ctx, dY):
185
+ cos, sin = ctx.saved_tensors
186
+ # Q * cos + rotate_half.T(Q) * sin
187
+ half = dY.shape[-1]//2
188
+ RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
189
+ dY *= cos
190
+ dY.addcmul_(RH_dY, sin)
191
+ # RH_dY *= sin
192
+ # dY += RH_dY
193
+ return dY, None, None, None
194
+ pass
195
+ pass
196
+
197
+
198
+ def inplace_rope_embedding(Q, K, cos, sin, position_ids):
199
+ Q = Slow_RoPE_Embedding.apply(Q, cos, sin, position_ids)
200
+ K = Slow_RoPE_Embedding.apply(K, cos, sin, position_ids)
201
+ return Q, K
202
+ pass
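
A rough usage sketch (editorial, not part of the diff): `fast_rope_embedding` expects `Q`/`K` in the usual `(batch, n_heads, seq_len, head_dim)` layout and HF-style `cos`/`sin` tables of shape `(seq_len, head_dim)`. Everything below is an illustrative assumption.

```python
# Hypothetical example -- builds the rotary tables the same way HF does
# (half the frequencies, duplicated across the head dimension).
import torch
from unsloth_kernels import fast_rope_embedding

bsz, n_heads, seq_len, head_dim = 2, 32, 128, 128
Q = torch.randn(bsz, n_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
K = torch.randn_like(Q)

inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device="cuda").float() / head_dim))
t = torch.arange(seq_len, device="cuda").float()
freqs = torch.outer(t, inv_freq)                 # (seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)          # (seq_len, head_dim)
cos, sin = emb.cos(), emb.sin()

Q, K = fast_rope_embedding(Q, K, cos, sin)       # kernel writes the rotated values in place
```
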
torch-ext/unsloth_kernels/swiglu.py ADDED
@@ -0,0 +1,101 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import triton
16
+ import triton.language as tl
17
+ import torch
18
+ from .utils import calculate_settings, torch_cuda_device
19
+
20
+
21
+ @triton.jit
22
+ def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
23
+ block_idx = tl.program_id(0)
24
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
25
+ mask = offsets < n_elements
26
+
27
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
28
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
29
+
30
+ # f = e * sigmoid(e)
31
+ f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
32
+ f_row = f_row.to(g_row.dtype) # Exact copy from HF
33
+ # h = f * g
34
+ h_row = f_row * g_row
35
+
36
+ # Store h
37
+ tl.store(h + offsets, h_row, mask = mask)
38
+ pass
39
+
40
+
41
+ def swiglu_fg_kernel(e, g):
42
+ batch, seq_len, hd = e.shape
43
+ n_elements = e.numel()
44
+ h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = e.device)
45
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
46
+ with torch_cuda_device(e.device):
47
+ _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
48
+ return h
49
+ pass
50
+
51
+
52
+ @triton.jit
53
+ def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
54
+ """
55
+ e = e.float()
56
+ se = 1.0 / (1.0 + torch.exp(-e))
57
+ f = (se * e).to(dtype)
58
+ h = f * g
59
+ df = DW * f
60
+ dg = DW * g
61
+ de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
62
+ """
63
+ block_idx = tl.program_id(0)
64
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
65
+ mask = offsets < n_elements
66
+
67
+ DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
68
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
69
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
70
+
71
+ # e = e.float()
72
+ # se = 1.0 / (1.0 + torch.exp(-e))
73
+ se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row))
74
+ # f = (se * e).to(dtype)
75
+ f_row = se_row * e_row
76
+ f_row = f_row.to(DW_row.dtype)
77
+ # h = f * g
78
+ h_row = f_row * g_row
79
+ # df = DW * f
80
+ df_row = DW_row * f_row
81
+ # dg = DW * g
82
+ dg_row = DW_row * g_row
83
+ # de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
84
+ de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))
85
+ de_row = de_row.to(DW_row.dtype)
86
+
87
+ # Store derivatives in buffers
88
+ tl.store(DW + offsets, h_row, mask = mask) # h = f * g
89
+ tl.store(e + offsets, df_row, mask = mask) # df = DW * f
90
+ tl.store(g + offsets, de_row, mask = mask) # de
91
+ pass
92
+
93
+
94
+ def swiglu_DWf_DW_dfg_kernel(DW, e, g):
95
+ batch_seq_len, hd = e.shape
96
+ n_elements = e.numel()
97
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
98
+ with torch_cuda_device(e.device):
99
+ _DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
100
+ return DW, e, g
101
+ pass
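
A rough usage sketch (editorial, not part of the diff): `swiglu_fg_kernel` fuses `silu(gate) * up` for a SwiGLU MLP block; the reference line mirrors the float32-sigmoid-then-cast order used inside the kernel. Shapes and the import path are assumptions.

```python
# Hypothetical example -- e is gate_proj(X), g is up_proj(X) in a SwiGLU MLP.
import torch
from unsloth_kernels import swiglu_fg_kernel

e = torch.randn(2, 128, 11008, device="cuda", dtype=torch.float16)
g = torch.randn_like(e)

h_fast = swiglu_fg_kernel(e, g)                                  # h = silu(e) * g
h_ref  = torch.nn.functional.silu(e.float()).to(e.dtype) * g     # same fp32 -> fp16 rounding
print((h_fast - h_ref).abs().max())
```
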
torch-ext/unsloth_kernels/utils.py ADDED
@@ -0,0 +1,497 @@
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import triton
16
+ MAX_FUSED_SIZE : int = 65536
17
+ next_power_of_2 = triton.next_power_of_2
18
+ import functools
19
+
20
+ # torch.cuda.amp.custom_fwd is deprecated >= 2.4
21
+ import torch
22
+ torch_Tensor = torch.Tensor
23
+ from packaging.version import Version
24
+ if Version(torch.__version__) < Version("2.4.0"):
25
+ torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
26
+ torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
27
+ else:
28
+ torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
29
+ torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
30
+ pass
31
+
32
+
33
+ # tl.math.tanh now is libdevice.tanh
34
+ from packaging.version import Version
35
+ import triton
36
+ import triton.language as tl
37
+ if Version(triton.__version__) >= Version("3.0.0"):
38
+ from triton.language.extra import libdevice
39
+ triton_tanh = libdevice.tanh
40
+ triton_cast = tl.cast
41
+ else:
42
+ triton_tanh = tl.math.tanh
43
+ # No casting in old Triton versions
44
+ @triton.jit
45
+ def triton_cast(x, dtype):
46
+ return x.to(dtype)
47
+ pass
48
+ pass
49
+
50
+
51
+ def calculate_settings(n : int) -> (int, int,):
52
+ BLOCK_SIZE : int = next_power_of_2(n)
53
+ if BLOCK_SIZE > MAX_FUSED_SIZE:
54
+ raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
55
+ f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
56
+ num_warps : int = 4
57
+ if BLOCK_SIZE >= 32768: num_warps = 32
58
+ elif BLOCK_SIZE >= 8192: num_warps = 16
59
+ elif BLOCK_SIZE >= 2048: num_warps = 8
60
+ return BLOCK_SIZE, num_warps
61
+ pass
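
As a concrete illustration, the values below follow directly from the branches above (the import path is an assumption):

```python
from unsloth_kernels.utils import calculate_settings  # assumed import path

# BLOCK_SIZE is n rounded up to a power of two; num_warps scales with it.
calculate_settings(128)    # -> (128, 4)
calculate_settings(4096)   # -> (4096, 8)
calculate_settings(11008)  # -> (16384, 16)
calculate_settings(40000)  # -> (65536, 32)
```
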
62
+
63
+
64
+ import bitsandbytes as bnb
65
+ import ctypes
66
+
67
+ # https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files
68
+ HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3")
69
+ get_ptr = bnb.functional.get_ptr
70
+
71
+ if torch.cuda.device_count() > 1:
72
+ torch_cuda_device = torch.cuda.device
73
+ else:
74
+ from contextlib import nullcontext
75
+ def torch_cuda_device(device): return nullcontext()
76
+ pass
77
+ _cuda_getCurrentRawStream = torch._C._cuda_getCurrentRawStream
78
+ c_void_p = ctypes.c_void_p
79
+ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
80
+ return c_void_p(_cuda_getCurrentRawStream(tensor.device.index))
81
+ pass
82
+
83
+ # Get array of CUDA streams and other buffers
84
+ global CUDA_STREAMS
85
+ global WEIGHT_BUFFERS
86
+ global ABSMAX_BUFFERS
87
+
88
+ _CUDA_STREAMS = {
89
+ (index := torch.cuda.device(i).idx) : ctypes.c_void_p(torch._C._cuda_getCurrentRawStream(index))
90
+ for i in range(torch.cuda.device_count())
91
+ }
92
+ CUDA_STREAMS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
93
+ WEIGHT_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
94
+ ABSMAX_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
95
+ for k, v in _CUDA_STREAMS.items(): CUDA_STREAMS[k] = v
96
+ CUDA_STREAMS = tuple(CUDA_STREAMS)
97
+ del _CUDA_STREAMS
98
+
99
+ # Bitsandbytes operations
100
+ ctypes_c_int = ctypes.c_int
101
+ ctypes_c_int32 = ctypes.c_int32
102
+ cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
103
+ cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
104
+ cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
105
+ cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
106
+ cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
107
+ torch_mm = torch.mm
108
+ torch_mv = torch.mv
109
+ torch_matmul = torch.matmul
110
+ torch_addmm = torch.addmm
111
+ torch_empty = torch.empty
112
+
113
+ def QUANT_STATE(W): return getattr(W, "quant_state", None)
114
+
115
+ def get_lora_parameters(proj):
116
+ # For DPO or disabled adapters
117
+ base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj)
118
+ W = base_layer.weight
119
+
120
+ # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
121
+ if getattr(proj, "disable_adapters", True) or proj.merged:
122
+ return W, getattr(W, "quant_state", None), None, None, None
123
+ pass
124
+
125
+ adapter = getattr(proj, "active_adapters", None)
126
+ if adapter is None: adapter = getattr(proj, "active_adapter", ("default"))
127
+ adapter = adapter[0]
128
+
129
+ return (
130
+ W,
131
+ getattr(W, "quant_state", None),
132
+ proj.lora_A [adapter].weight,
133
+ proj.lora_B [adapter].weight,
134
+ proj.scaling[adapter],
135
+ )
136
+ pass
137
+
138
+
139
+ def get_lora_parameters_bias(proj):
140
+ # For DPO or disabled adapters
141
+ base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj)
142
+ W = base_layer.weight
143
+
144
+ # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
145
+ if getattr(proj, "disable_adapters", True) or proj.merged:
146
+ return W, getattr(W, "quant_state", None), None, None, None, base_layer.bias
147
+ pass
148
+
149
+ adapter = getattr(proj, "active_adapters", None)
150
+ if adapter is None: adapter = getattr(proj, "active_adapter", ("default"))
151
+ adapter = adapter[0]
152
+
153
+ return (
154
+ W,
155
+ getattr(W, "quant_state", None),
156
+ proj.lora_A [adapter].weight,
157
+ proj.lora_B [adapter].weight,
158
+ proj.scaling[adapter],
159
+ base_layer.bias,
160
+ )
161
+ pass
162
+
163
+ if HAS_CUDA_STREAM:
164
+ @torch.inference_mode
165
+ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
166
+ if quant_state is None: return W
167
+ if type(quant_state) is not list:
168
+ # New quant_state as a class
169
+ # https://github.com/TimDettmers/bitsandbytes/pull/763/files
170
+ absmax = quant_state.absmax
171
+ shape = quant_state.shape
172
+ dtype = quant_state.dtype
173
+ blocksize = quant_state.blocksize
174
+ offset = quant_state.offset
175
+ state2 = quant_state.state2
176
+ absmax2 = state2.absmax
177
+ code2 = state2.code
178
+ blocksize2 = state2.blocksize
179
+ else:
180
+ # Old quant_state as a list of lists
181
+ absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
182
+ offset, state2 = compressed_stats
183
+ absmax2, code2, blocksize2, _, _, _, _ = state2
184
+ pass
185
+ global CUDA_STREAMS
186
+ device = W.device
187
+ device_index = device.index
188
+ CUDA_STREAM = CUDA_STREAMS[device_index]
189
+
190
+ n_elements_absmax = absmax.numel()
191
+
192
+ # Create weight matrix
193
+ if use_global_buffer:
194
+
195
+ # Use same buffers for faster inference
196
+ size = shape[0]*shape[1]
197
+ global WEIGHT_BUFFERS
198
+ global ABSMAX_BUFFERS
199
+ WEIGHT_BUFFER = WEIGHT_BUFFERS[device_index]
200
+ ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]
201
+ if WEIGHT_BUFFER is None:
202
+ WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(size, dtype = dtype, device = device, requires_grad = False)
203
+ ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)
204
+
205
+ if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size)
206
+ if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax)
207
+
208
+ out = WEIGHT_BUFFER[:size].view(shape)
209
+ out_absmax = ABSMAX_BUFFER[:n_elements_absmax]
210
+ else:
211
+ if out is None:
212
+ out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False)
213
+ else:
214
+ assert(out.shape == shape)
215
+ assert(out.dtype == dtype)
216
+ out_absmax = torch_empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)
217
+ pass
218
+
219
+ # NF4 dequantization of statistics
220
+ ptr_out_absmax = get_ptr(out_absmax)
221
+ with torch_cuda_device(device):
222
+ cdequantize_blockwise_fp32(
223
+ get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
224
+ ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), CUDA_STREAM
225
+ )
226
+ out_absmax += offset
227
+
228
+ # Dequantize W
229
+ fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
230
+ cdequantize_blockwise_bf16_nf4
231
+ fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
232
+ ctypes_c_int(blocksize), ctypes_c_int(out.numel()), CUDA_STREAM,)
233
+ pass
234
+ # Careful returning transposed data
235
+ is_transposed = (True if W.shape[0] == 1 else False)
236
+ return out.t() if is_transposed else out
237
+ pass
238
+ else:
239
+ @torch.inference_mode
240
+ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
241
+ if quant_state is None: return W
242
+ if type(quant_state) is not list:
243
+ # New quant_state as a class
244
+ # https://github.com/TimDettmers/bitsandbytes/pull/763/files
245
+ absmax = quant_state.absmax
246
+ shape = quant_state.shape
247
+ dtype = quant_state.dtype
248
+ blocksize = quant_state.blocksize
249
+ offset = quant_state.offset
250
+ state2 = quant_state.state2
251
+ absmax2 = state2.absmax
252
+ code2 = state2.code
253
+ blocksize2 = state2.blocksize
254
+ else:
255
+ # Old quant_state as a list of lists
256
+ absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
257
+ offset, state2 = compressed_stats
258
+ absmax2, code2, blocksize2, _, _, _, _ = state2
259
+ pass
260
+
261
+ n_elements_absmax = absmax.numel()
262
+ device = W.device
263
+
264
+ # Create weight matrix
265
+ if out is None:
266
+ out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False)
267
+ else:
268
+ assert(out.shape == shape)
269
+ assert(out.dtype == dtype)
270
+ out_absmax = torch_empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)
271
+
272
+ # Do dequantization
273
+ ptr_out_absmax = get_ptr(out_absmax)
274
+ cdequantize_blockwise_fp32(
275
+ get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
276
+ ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax),
277
+ )
278
+ out_absmax += offset
279
+
280
+ fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
281
+ cdequantize_blockwise_bf16_nf4
282
+ fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
283
+ ctypes_c_int(blocksize), ctypes_c_int(out.numel()),)
284
+
285
+ # Careful returning transposed data
286
+ is_transposed = (True if W.shape[0] == 1 else False)
287
+ return out.t() if is_transposed else out
288
+ pass
289
+ pass
290
+
291
+
292
+ if HAS_CUDA_STREAM:
293
+ def fast_gemv(X, W, quant_state, out = None):
294
+ if quant_state is None: return torch_matmul(X, W, out = out)
295
+ # For fast X @ W where seq_len == 1
296
+ # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
297
+ _, q_len, hd = X.shape
298
+ # assert(q_len == 1)
299
+
300
+ if type(quant_state) is not list:
301
+ # https://github.com/TimDettmers/bitsandbytes/pull/763/files
302
+ absmax = quant_state.absmax
303
+ shape = quant_state.shape
304
+ dtype = quant_state.dtype
305
+ blocksize = quant_state.blocksize
306
+ stats = quant_state.code
307
+ offset = quant_state.offset
308
+ state2 = quant_state.state2
309
+ absmax2 = state2.absmax
310
+ code2 = state2.code
311
+ blocksize2 = state2.blocksize
312
+ else:
313
+ absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
314
+ offset, state2 = compressed_stats
315
+ absmax2, code2, blocksize2, _, _, _, _ = state2
316
+ pass
317
+ global CUDA_STREAMS
318
+ device = W.device
319
+ device_index = device.index
320
+ CUDA_STREAM = CUDA_STREAMS[device_index]
321
+
322
+ # assert(dtype == X.dtype)
323
+ bout = shape[0]
324
+
325
+ if out is None:
326
+ out = torch_empty((1, 1, bout,), dtype = dtype, device = device)
327
+ # else:
328
+ # assert(out.shape == (1, 1, bout,))
329
+ # pass
330
+
331
+ n = 1
332
+ m = shape[0]
333
+ k = shape[1]
334
+ lda = shape[0]
335
+ ldc = shape[0]
336
+ ldb = (hd+1)//2
337
+ m = ctypes_c_int32(m)
338
+ n = ctypes_c_int32(n)
339
+ k = ctypes_c_int32(k)
340
+ lda = ctypes_c_int32(lda)
341
+ ldb = ctypes_c_int32(ldb)
342
+ ldc = ctypes_c_int32(ldc)
343
+
344
+ df = torch_empty(absmax.shape, dtype = torch.float32, device = device)
345
+ with torch_cuda_device(device):
346
+ cdequantize_blockwise_fp32(
347
+ get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
348
+ ctypes_c_int(blocksize2), ctypes_c_int(df.numel()), CUDA_STREAM,
349
+ )
350
+ df += offset
351
+ absmax = df
352
+
353
+ fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
354
+ cgemm_4bit_inference_naive_bf16
355
+
356
+ blocksize = ctypes_c_int32(blocksize)
357
+ fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
358
+ lda, ldb, ldc, blocksize, CUDA_STREAM,)
359
+ pass
360
+
361
+ return out
362
+ pass
363
+ else:
364
+ def fast_gemv(X, W, quant_state, out = None):
365
+ if quant_state is None: return torch.matmul(X, W, out = out)
366
+ # For fast X @ W where seq_len == 1
367
+ # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
368
+ _, q_len, hd = X.shape
369
+ # assert(q_len == 1)
370
+
371
+ if type(quant_state) is not list:
372
+ # https://github.com/TimDettmers/bitsandbytes/pull/763/files
373
+ absmax = quant_state.absmax
374
+ shape = quant_state.shape
375
+ dtype = quant_state.dtype
376
+ blocksize = quant_state.blocksize
377
+ stats = quant_state.code
378
+ offset = quant_state.offset
379
+ state2 = quant_state.state2
380
+ absmax2 = state2.absmax
381
+ code2 = state2.code
382
+ blocksize2 = state2.blocksize
383
+ else:
384
+ absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
385
+ offset, state2 = compressed_stats
386
+ absmax2, code2, blocksize2, _, _, _, _ = state2
387
+ pass
388
+ # assert(dtype == X.dtype)
389
+ bout = shape[0]
390
+ device = W.device
391
+
392
+ if out is None:
393
+ out = torch_empty((1, 1, bout,), dtype = dtype, device = device)
394
+ # else:
395
+ # assert(out.shape == (1, 1, bout,))
396
+ # pass
397
+
398
+ n = 1
399
+ m = shape[0]
400
+ k = shape[1]
401
+ lda = shape[0]
402
+ ldc = shape[0]
403
+ ldb = (hd+1)//2
404
+ m = ctypes_c_int32(m)
405
+ n = ctypes_c_int32(n)
406
+ k = ctypes_c_int32(k)
407
+ lda = ctypes_c_int32(lda)
408
+ ldb = ctypes_c_int32(ldb)
409
+ ldc = ctypes_c_int32(ldc)
410
+
411
+ df = torch_empty(absmax.shape, dtype = torch.float32, device = device)
412
+ cdequantize_blockwise_fp32(
413
+ get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
414
+ ctypes_c_int(blocksize2), ctypes_c_int(df.numel()),
415
+ )
416
+ df += offset
417
+ absmax = df
418
+
419
+ fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
420
+ cgemm_4bit_inference_naive_bf16
421
+
422
+ blocksize = ctypes_c_int32(blocksize)
423
+ fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
424
+ lda, ldb, ldc, blocksize,)
425
+
426
+ return out
427
+ pass
428
+ pass
429
+
430
+
431
+ def fast_linear_forward(proj, X, temp_lora = None, out = None):
432
+
433
+ W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
434
+ bsz, q_len, in_dim = X.shape
435
+ if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
436
+
437
+ if W_quant is None:
438
+ out = torch_matmul(X, W.t(), out = out)
439
+ elif bsz == 1 and q_len == 1:
440
+ out = fast_gemv(X, W, W_quant, out = out)
441
+ else:
442
+ W = fast_dequantize(W.t(), W_quant, use_global_buffer = True)
443
+ out = torch_matmul(X, W, out = out)
444
+ pass
445
+
446
+ # Add in LoRA weights
447
+ if lora_A is not None:
448
+ out_dim = out.shape[2]
449
+ dtype = X.dtype
450
+
451
+ if not hasattr(lora_A, "_fast_lora"):
452
+ lora_A._fast_lora = lora_A.to(dtype)
453
+ lora_B._fast_lora = lora_B.to(dtype)
454
+ pass
455
+
456
+ if bsz == 1:
457
+ out = out.view(out_dim)
458
+ temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
459
+ out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
460
+ else:
461
+ out = out.view(bsz, out_dim)
462
+ temp_lora = torch_mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora)
463
+ out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
464
+ pass
465
+ out = out.view(bsz, 1, out_dim)
466
+ pass
467
+
468
+ if bias is not None: out += bias
469
+
470
+ return out
471
+ pass
472
+
473
+
474
+ def matmul_lora(X, W, W_quant, A, B, s, out = None):
475
+ dtype = X.dtype
476
+ W = fast_dequantize(W.t(), W_quant, use_global_buffer = True)
477
+
478
+ if X.dim() == 3:
479
+ batch, seq_len, d = X.shape
480
+ X = X.view(-1, X.shape[-1])
481
+ reshape = True
482
+ else:
483
+ reshape = False
484
+ pass
485
+ out = torch_matmul(X, W, out = out)
486
+ if W_quant is not None: del W
487
+
488
+ if A is not None:
489
+ # LoRA is enabled
490
+ A, B = A.t(), B.t()
491
+ XA = torch_matmul(X, A.to(dtype))
492
+ out.addmm_(XA, B.to(dtype), alpha = s)
493
+ # out += (X @ A.to(dtype)) @ (s * B.to(dtype))
494
+ pass
495
+
496
+ return out.view(batch, seq_len, -1) if reshape else out
497
+ pass
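
Finally, a rough sketch (editorial, not part of the diff) of how `matmul_lora` composes with `get_lora_parameters` for a PEFT-wrapped 4-bit projection; the function name and import path are assumptions.

```python
# Hypothetical example -- proj is assumed to be a peft LoRA layer wrapping a 4-bit base linear.
from unsloth_kernels.utils import get_lora_parameters, matmul_lora  # assumed import path

def lora_forward(proj, X):
    W, W_quant, A, B, s = get_lora_parameters(proj)
    # out = X @ dequant(W).T + s * (X @ A.T) @ B.T   (LoRA term skipped when A is None)
    return matmul_lora(X, W, W_quant, A, B, s)
```
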