Upload folder using huggingface_hub
Browse files- README.md +6 -0
- build.toml +5 -0
- flake.nix +17 -0
- torch-ext/unsloth_kernels/__init__.py +23 -0
- torch-ext/unsloth_kernels/cross_entropy_loss.py +420 -0
- torch-ext/unsloth_kernels/fast_lora.py +537 -0
- torch-ext/unsloth_kernels/flex_attention.py +181 -0
- torch-ext/unsloth_kernels/geglu.py +213 -0
- torch-ext/unsloth_kernels/layernorm.py +170 -0
- torch-ext/unsloth_kernels/rms_layernorm.py +261 -0
- torch-ext/unsloth_kernels/rope_embedding.py +202 -0
- torch-ext/unsloth_kernels/swiglu.py +101 -0
- torch-ext/unsloth_kernels/utils.py +497 -0
README.md
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Unsloth Kernels
|
2 |
+
|
3 |
+
Unsloth Kernels is a collection of kernels for the Unsloth project.
|
4 |
+
|
5 |
+
## Installation
|
6 |
+
|
build.toml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[general]
|
2 |
+
name = "unsloth_kernels"
|
3 |
+
|
4 |
+
[torch]
|
5 |
+
universal = true
|
flake.nix
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
description = "Flake for ReLU kernel";
|
3 |
+
|
4 |
+
inputs = {
|
5 |
+
kernel-builder.url = "path:../..";
|
6 |
+
};
|
7 |
+
|
8 |
+
outputs =
|
9 |
+
{
|
10 |
+
self,
|
11 |
+
kernel-builder,
|
12 |
+
}:
|
13 |
+
kernel-builder.lib.genFlakeOutputs {
|
14 |
+
path = ./.;
|
15 |
+
rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
|
16 |
+
};
|
17 |
+
}
|
torch-ext/unsloth_kernels/__init__.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .cross_entropy_loss import fast_cross_entropy_loss
|
2 |
+
from .fast_lora import fast_lora_forward
|
3 |
+
from .flex_attention import slow_inference_attention_softcapping
|
4 |
+
from .layernorm import fast_layernorm
|
5 |
+
from .rope_embedding import inplace_rope_embedding, fast_rope_embedding
|
6 |
+
from .rms_layernorm import fast_rms_layernorm
|
7 |
+
from .swiglu import swiglu_fg_kernel
|
8 |
+
from .geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel, geglu_exact_forward_kernel, geglu_exact_backward_kernel
|
9 |
+
from .swiglu import swiglu_fg_kernel
|
10 |
+
|
11 |
+
__all__ = ["fast_cross_entropy_loss",
|
12 |
+
"fast_lora_forward",
|
13 |
+
"slow_inference_attention_softcapping",
|
14 |
+
"fast_layernorm",
|
15 |
+
"inplace_rope_embedding",
|
16 |
+
"fast_rms_layernorm",
|
17 |
+
"swiglu_fg_kernel",
|
18 |
+
"geglu_approx_forward_kernel",
|
19 |
+
"geglu_approx_backward_kernel",
|
20 |
+
"geglu_exact_forward_kernel",
|
21 |
+
"geglu_exact_backward_kernel",
|
22 |
+
"fast_rope_embedding"
|
23 |
+
]
|
torch-ext/unsloth_kernels/cross_entropy_loss.py
ADDED
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import triton
|
16 |
+
import triton.language as tl
|
17 |
+
import torch
|
18 |
+
from .utils import (
|
19 |
+
calculate_settings,
|
20 |
+
MAX_FUSED_SIZE,
|
21 |
+
triton_tanh,
|
22 |
+
triton_cast,
|
23 |
+
torch_cuda_device,
|
24 |
+
)
|
25 |
+
from transformers.models.llama.modeling_llama import logger
|
26 |
+
from packaging.version import Version
|
27 |
+
|
28 |
+
from unsloth_zoo.loss_utils import (
|
29 |
+
patch_loss_functions as _patch_loss_functions,
|
30 |
+
post_patch_loss_function,
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
def _cross_entropy_forward(
|
35 |
+
logits_ptr ,
|
36 |
+
logits_row_stride ,
|
37 |
+
loss_ptr ,
|
38 |
+
logsumexp_ptr ,
|
39 |
+
labels_ptr ,
|
40 |
+
VOCAB_SIZE : tl.constexpr,
|
41 |
+
BLOCK_SIZE : tl.constexpr,
|
42 |
+
DO_SOFTCAPPING : tl.constexpr,
|
43 |
+
SOFTCAP : tl.constexpr,
|
44 |
+
DO_LOGIT_SCALING : tl.constexpr,
|
45 |
+
LOGIT_SCALE : tl.constexpr,
|
46 |
+
):
|
47 |
+
"""
|
48 |
+
Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
|
49 |
+
Pi = exp(xi) / sum(exp(xi))
|
50 |
+
CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
|
51 |
+
= -y [ x - log[sum(exp(x))] ]
|
52 |
+
= y * (log[sum(exp(x))] - x)
|
53 |
+
If y == 0: CE_i = 0
|
54 |
+
If y == 1: CE_i = logsumexp - x
|
55 |
+
|
56 |
+
logsumexp is also stable
|
57 |
+
Take y = log[sum(exp(x))]
|
58 |
+
exp(y) = sum(exp(x))
|
59 |
+
exp(y) = sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x
|
60 |
+
exp(y) = exp(c)*sum(exp(x - c))
|
61 |
+
y = log(exp(c)*sum(exp(x - c)))
|
62 |
+
y = c + log[sum(exp(x - c))]
|
63 |
+
This means we can set c = max(x) to make sure
|
64 |
+
exp(x - c) always is exp(x - max(x)).
|
65 |
+
This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.
|
66 |
+
"""
|
67 |
+
row_idx = tl.program_id(0)
|
68 |
+
logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
|
69 |
+
loss_ptr += row_idx
|
70 |
+
logsumexp_ptr += row_idx
|
71 |
+
labels_ptr += row_idx
|
72 |
+
|
73 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
74 |
+
mask = col_offsets < VOCAB_SIZE
|
75 |
+
|
76 |
+
label_idx = tl.load(labels_ptr).to(tl.int32)
|
77 |
+
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
|
78 |
+
|
79 |
+
# Go logit scaling for Cohere: t * x
|
80 |
+
if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
|
81 |
+
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
82 |
+
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
|
83 |
+
|
84 |
+
c = tl.max(logits, 0)
|
85 |
+
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
|
86 |
+
|
87 |
+
if label_idx != -100:
|
88 |
+
x = tl.load(logits_ptr + label_idx).to(tl.float32)
|
89 |
+
# Go logit scaling for Cohere: t * x
|
90 |
+
if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
|
91 |
+
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
92 |
+
if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
|
93 |
+
loss = logsumexp - x
|
94 |
+
else:
|
95 |
+
loss = 0.0
|
96 |
+
tl.store(logsumexp_ptr, logsumexp)
|
97 |
+
tl.store(loss_ptr, loss)
|
98 |
+
pass
|
99 |
+
_cross_entropy_forward = triton.jit(_cross_entropy_forward)
|
100 |
+
_cross_entropy_forward = triton.heuristics(
|
101 |
+
{
|
102 |
+
"DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
|
103 |
+
"DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
|
104 |
+
}
|
105 |
+
)(_cross_entropy_forward)
|
106 |
+
|
107 |
+
|
108 |
+
def _chunked_cross_entropy_forward(
|
109 |
+
logits_ptr ,
|
110 |
+
logits_row_stride ,
|
111 |
+
loss_ptr ,
|
112 |
+
logsumexp_ptr ,
|
113 |
+
labels_ptr ,
|
114 |
+
VOCAB_SIZE : tl.constexpr,
|
115 |
+
N_CHUNKS : tl.constexpr,
|
116 |
+
BLOCK_SIZE : tl.constexpr,
|
117 |
+
DO_SOFTCAPPING : tl.constexpr,
|
118 |
+
SOFTCAP : tl.constexpr,
|
119 |
+
DO_LOGIT_SCALING : tl.constexpr,
|
120 |
+
LOGIT_SCALE : tl.constexpr,
|
121 |
+
):
|
122 |
+
"""
|
123 |
+
256K vocab divided in 4 chunks
|
124 |
+
|
125 |
+
|-65536-| |-65536-| |-65536-| |-65536-|
|
126 |
+
|-------| |-------| |-------| |-------|
|
127 |
+
|-------| |-------| |-------| |-------|
|
128 |
+
|
129 |
+
If y == 0: CE_i = 0
|
130 |
+
If y == 1: CE_i = logsumexp - x
|
131 |
+
|
132 |
+
Notice we can do logsumexp for each chunk and then
|
133 |
+
logsumexp[chunk_sum(logsumexp)] == logsumexp
|
134 |
+
|
135 |
+
chunk_sum = log[chunk_sum(logsumexp)]
|
136 |
+
= log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]
|
137 |
+
= log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]
|
138 |
+
= log[sum(exp(a)) + ... + sum(exp(z))]
|
139 |
+
= logsumexp(x)
|
140 |
+
|
141 |
+
This means we can perform a logsumexp for each chunk, then do a
|
142 |
+
final logsumexp reduction!
|
143 |
+
|
144 |
+
Ie do: logsumexp(chunked_logsumexp) - x
|
145 |
+
"""
|
146 |
+
row_idx = tl.program_id(0)
|
147 |
+
chunk_idx = tl.program_id(1)
|
148 |
+
logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
|
149 |
+
loss_ptr += row_idx
|
150 |
+
logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
|
151 |
+
labels_ptr += row_idx
|
152 |
+
|
153 |
+
col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
154 |
+
mask = col_offsets < VOCAB_SIZE
|
155 |
+
|
156 |
+
label_idx = tl.load(labels_ptr).to(tl.int32)
|
157 |
+
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
|
158 |
+
|
159 |
+
# Go logit scaling for Cohere: t * x
|
160 |
+
if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
|
161 |
+
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
162 |
+
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
|
163 |
+
|
164 |
+
c = tl.max(logits, 0)
|
165 |
+
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
|
166 |
+
|
167 |
+
if chunk_idx == 0:
|
168 |
+
# logsumexp(chunked_logsumexp) - x
|
169 |
+
# Do the -x separately
|
170 |
+
if label_idx != -100:
|
171 |
+
x = tl.load(logits_ptr + label_idx).to(tl.float32)
|
172 |
+
# Go logit scaling for Cohere: t * x
|
173 |
+
if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
|
174 |
+
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
175 |
+
if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
|
176 |
+
loss = -1.0 * x
|
177 |
+
else:
|
178 |
+
loss = 0.0
|
179 |
+
tl.store(loss_ptr, loss)
|
180 |
+
pass
|
181 |
+
tl.store(logsumexp_ptr, logsumexp)
|
182 |
+
pass
|
183 |
+
_chunked_cross_entropy_forward = triton.jit(_chunked_cross_entropy_forward)
|
184 |
+
_chunked_cross_entropy_forward = triton.heuristics(
|
185 |
+
{
|
186 |
+
"DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
|
187 |
+
"DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
|
188 |
+
}
|
189 |
+
)(_chunked_cross_entropy_forward)
|
190 |
+
|
191 |
+
|
192 |
+
def _cross_entropy_backward(
|
193 |
+
logits_ptr ,
|
194 |
+
logits_row_stride ,
|
195 |
+
dloss_ptr ,
|
196 |
+
dloss_row_stride ,
|
197 |
+
logsumexp_ptr ,
|
198 |
+
labels_ptr ,
|
199 |
+
VOCAB_SIZE : tl.constexpr,
|
200 |
+
BLOCK_SIZE : tl.constexpr,
|
201 |
+
DO_SOFTCAPPING : tl.constexpr,
|
202 |
+
SOFTCAP : tl.constexpr,
|
203 |
+
DO_LOGIT_SCALING : tl.constexpr,
|
204 |
+
LOGIT_SCALE : tl.constexpr,
|
205 |
+
):
|
206 |
+
"""
|
207 |
+
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
|
208 |
+
dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
|
209 |
+
|
210 |
+
From https://en.wikipedia.org/wiki/LogSumExp
|
211 |
+
d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
|
212 |
+
|
213 |
+
dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
|
214 |
+
dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
|
215 |
+
dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
|
216 |
+
|
217 |
+
If y == 0: dC/dx = 0
|
218 |
+
If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
|
219 |
+
If y == 1 and x != label: dC/dx = exp[x - logsumexp]
|
220 |
+
"""
|
221 |
+
row_idx = tl.program_id(0)
|
222 |
+
block_idx = tl.program_id(1)
|
223 |
+
|
224 |
+
logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
|
225 |
+
dloss_ptr += row_idx * dloss_row_stride
|
226 |
+
col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
227 |
+
mask = col_offsets < VOCAB_SIZE
|
228 |
+
label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
|
229 |
+
|
230 |
+
if label_idx != -100:
|
231 |
+
dloss = tl.load(dloss_ptr)
|
232 |
+
else:
|
233 |
+
dloss = 0.0
|
234 |
+
|
235 |
+
x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
|
236 |
+
|
237 |
+
# Do logit scaling for Cohere
|
238 |
+
if DO_LOGIT_SCALING:
|
239 |
+
# d/dx [s * x] = s
|
240 |
+
x = x * LOGIT_SCALE
|
241 |
+
pass
|
242 |
+
|
243 |
+
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
244 |
+
partial = x
|
245 |
+
if DO_SOFTCAPPING:
|
246 |
+
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
|
247 |
+
partial = triton_tanh(x / SOFTCAP)
|
248 |
+
x = SOFTCAP * partial
|
249 |
+
pass
|
250 |
+
|
251 |
+
logsumexp = tl.load(logsumexp_ptr + row_idx)
|
252 |
+
y = tl.exp(x - logsumexp)
|
253 |
+
y = tl.where(
|
254 |
+
col_offsets == label_idx,
|
255 |
+
y - 1.0, # exp(x - logsumexp) - 1
|
256 |
+
y, # exp(x - logsumexp)
|
257 |
+
)
|
258 |
+
|
259 |
+
if DO_LOGIT_SCALING:
|
260 |
+
# d/dx [s * x] = s
|
261 |
+
y = y * LOGIT_SCALE
|
262 |
+
pass
|
263 |
+
|
264 |
+
if DO_SOFTCAPPING:
|
265 |
+
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
|
266 |
+
y = y * (1.0 - partial*partial)
|
267 |
+
pass
|
268 |
+
|
269 |
+
# If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
|
270 |
+
tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
|
271 |
+
pass
|
272 |
+
_cross_entropy_backward = triton.jit(_cross_entropy_backward)
|
273 |
+
_cross_entropy_backward = triton.heuristics(
|
274 |
+
{
|
275 |
+
"DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
|
276 |
+
"DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
|
277 |
+
}
|
278 |
+
)(_cross_entropy_backward)
|
279 |
+
|
280 |
+
|
281 |
+
MAX_FUSED_SIZE = 65536 # 2**16
|
282 |
+
class Fast_CrossEntropyLoss(torch.autograd.Function):
|
283 |
+
@staticmethod
|
284 |
+
def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : float = 0):
|
285 |
+
n_rows : int
|
286 |
+
vocab_size : int
|
287 |
+
n_rows, vocab_size = logits.shape
|
288 |
+
device = logits.device
|
289 |
+
|
290 |
+
div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
|
291 |
+
n_chunks : int = div + (mod != 0)
|
292 |
+
losses = torch.empty(n_rows, dtype = torch.float32, device = device)
|
293 |
+
|
294 |
+
DO_SOFTCAPPING : bool = bool(logit_softcapping != 0)
|
295 |
+
DO_LOGIT_SCALING : bool = bool(logit_scaling != 0)
|
296 |
+
|
297 |
+
BLOCK_SIZE : int
|
298 |
+
num_warps : int
|
299 |
+
if n_chunks == 1:
|
300 |
+
# For small vocabs <= 65336 like Llama, Mistral
|
301 |
+
BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
|
302 |
+
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = device)
|
303 |
+
|
304 |
+
with torch_cuda_device(device):
|
305 |
+
_cross_entropy_forward[(n_rows,)](
|
306 |
+
logits, logits.stride(0),
|
307 |
+
losses,
|
308 |
+
logsumexp,
|
309 |
+
labels,
|
310 |
+
VOCAB_SIZE = vocab_size,
|
311 |
+
BLOCK_SIZE = BLOCK_SIZE,
|
312 |
+
DO_SOFTCAPPING = DO_SOFTCAPPING,
|
313 |
+
SOFTCAP = logit_softcapping,
|
314 |
+
DO_LOGIT_SCALING = DO_LOGIT_SCALING,
|
315 |
+
LOGIT_SCALE = logit_scaling,
|
316 |
+
num_warps = num_warps,
|
317 |
+
)
|
318 |
+
else:
|
319 |
+
# For large vocabs > 65336 like Gemma 256K
|
320 |
+
logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = device)
|
321 |
+
|
322 |
+
with torch_cuda_device(device):
|
323 |
+
_chunked_cross_entropy_forward[(n_rows, n_chunks,)](
|
324 |
+
logits, logits.stride(0),
|
325 |
+
losses,
|
326 |
+
logsumexp,
|
327 |
+
labels,
|
328 |
+
VOCAB_SIZE = vocab_size,
|
329 |
+
N_CHUNKS = n_chunks,
|
330 |
+
BLOCK_SIZE = MAX_FUSED_SIZE,
|
331 |
+
DO_SOFTCAPPING = DO_SOFTCAPPING,
|
332 |
+
SOFTCAP = logit_softcapping,
|
333 |
+
DO_LOGIT_SCALING = DO_LOGIT_SCALING,
|
334 |
+
LOGIT_SCALE = logit_scaling,
|
335 |
+
num_warps = 32,
|
336 |
+
)
|
337 |
+
# logsumexp(chunked_logsumexp) - x
|
338 |
+
# Do the -x separately
|
339 |
+
logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum
|
340 |
+
losses += logsumexp
|
341 |
+
losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out!
|
342 |
+
pass
|
343 |
+
|
344 |
+
ctx.save_for_backward(logits, logsumexp, labels)
|
345 |
+
ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
|
346 |
+
ctx.logit_softcapping = logit_softcapping
|
347 |
+
ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING
|
348 |
+
ctx.logit_scaling = logit_scaling
|
349 |
+
return losses
|
350 |
+
pass
|
351 |
+
|
352 |
+
|
353 |
+
@staticmethod
|
354 |
+
def backward(ctx, dlosses):
|
355 |
+
logits, logsumexp, labels = ctx.saved_tensors
|
356 |
+
n_rows : int
|
357 |
+
vocab_size : int
|
358 |
+
n_rows, vocab_size = logits.shape
|
359 |
+
|
360 |
+
BLOCK_SIZE : int = 4096
|
361 |
+
div : int
|
362 |
+
mod : int
|
363 |
+
div, mod = divmod(vocab_size, BLOCK_SIZE)
|
364 |
+
n_blocks : int = div + (mod != 0)
|
365 |
+
|
366 |
+
with torch_cuda_device(dlosses.device):
|
367 |
+
_cross_entropy_backward[(n_rows, n_blocks,)](
|
368 |
+
logits, logits.stride(0),
|
369 |
+
dlosses, dlosses.stride(0),
|
370 |
+
logsumexp,
|
371 |
+
labels,
|
372 |
+
VOCAB_SIZE = vocab_size,
|
373 |
+
BLOCK_SIZE = BLOCK_SIZE,
|
374 |
+
DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
|
375 |
+
SOFTCAP = ctx.logit_softcapping,
|
376 |
+
DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING,
|
377 |
+
LOGIT_SCALE = ctx.logit_scaling,
|
378 |
+
num_warps = 8,
|
379 |
+
)
|
380 |
+
return logits, None, None, None,
|
381 |
+
pass
|
382 |
+
pass
|
383 |
+
|
384 |
+
|
385 |
+
def fast_cross_entropy_loss(
|
386 |
+
logits,
|
387 |
+
labels,
|
388 |
+
logit_softcapping = 0,
|
389 |
+
logit_scaling = 0,
|
390 |
+
n_items = None,
|
391 |
+
):
|
392 |
+
"""
|
393 |
+
Arguments:
|
394 |
+
logits: (batch, seq_len, vocab_size)
|
395 |
+
labels: (batch, seq_len,)
|
396 |
+
Returns:
|
397 |
+
losses: float
|
398 |
+
"""
|
399 |
+
batch, seq_len, d = logits.shape
|
400 |
+
assert(labels.shape == (batch, seq_len))
|
401 |
+
|
402 |
+
loss = Fast_CrossEntropyLoss.apply(
|
403 |
+
logits.view(batch*seq_len, d),
|
404 |
+
labels.view(-1),
|
405 |
+
logit_softcapping,
|
406 |
+
logit_scaling,
|
407 |
+
)
|
408 |
+
if n_items is None:
|
409 |
+
n_items = torch.count_nonzero(labels != -100)
|
410 |
+
return loss.sum() / n_items
|
411 |
+
pass
|
412 |
+
if (Version(torch.__version__) < Version("2.4.0")) and \
|
413 |
+
not hasattr(fast_cross_entropy_loss, "__wrapped__"):
|
414 |
+
fast_cross_entropy_loss = torch._disable_dynamo(fast_cross_entropy_loss)
|
415 |
+
pass
|
416 |
+
|
417 |
+
# Patch CE Losses in transformers
|
418 |
+
def patch_loss_functions(torch_compile = True):
|
419 |
+
_patch_loss_functions(fast_cross_entropy_loss, torch_compile = torch_compile)
|
420 |
+
pass
|
torch-ext/unsloth_kernels/fast_lora.py
ADDED
@@ -0,0 +1,537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from .utils import (
|
17 |
+
fast_dequantize,
|
18 |
+
QUANT_STATE,
|
19 |
+
get_lora_parameters,
|
20 |
+
get_lora_parameters_bias,
|
21 |
+
matmul_lora,
|
22 |
+
torch_amp_custom_fwd,
|
23 |
+
torch_amp_custom_bwd,
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
class LoRA_MLP(torch.autograd.Function):
|
28 |
+
"""
|
29 |
+
### LoRA weights
|
30 |
+
G = G + Ag @ Bg
|
31 |
+
U = U + Au @ Bu
|
32 |
+
W = W + Aw @ Bw
|
33 |
+
|
34 |
+
### SwiGLU(X)
|
35 |
+
e = X @ G
|
36 |
+
f = e * sigmoid(e)
|
37 |
+
g = X @ U
|
38 |
+
h = f * g
|
39 |
+
i = h @ W
|
40 |
+
|
41 |
+
### Backpropagation chain rule
|
42 |
+
See our blog post for more details
|
43 |
+
|
44 |
+
df = sigmoid(e) * (1 - f) + f
|
45 |
+
dC/dW = h.T @ dY
|
46 |
+
dC/dU = X.T @ (D @ W.T * f)
|
47 |
+
dC/dG = X.T @ (D @ W.T * df * g)
|
48 |
+
|
49 |
+
### Down projection LoRA weights
|
50 |
+
dC/dAw = dC/dW @ B.T
|
51 |
+
dC/dBw = A.T @ dC/dW
|
52 |
+
dC/dAw = h.T @ dY @ B.T
|
53 |
+
dC/dBw = A.T @ h.T @ dY
|
54 |
+
|
55 |
+
### Up projection LoRA weights
|
56 |
+
dC/dAu = X.T @ (D @ W.T * f) @ B.T
|
57 |
+
dC/dBu = A.T @ X.T @ (D @ W.T * f)
|
58 |
+
|
59 |
+
### Gate projection LoRA weights
|
60 |
+
dC/dAg = X.T @ (D @ W.T * df * g) @ B.T
|
61 |
+
dC/dBg = A.T @ X.T @ (D @ W.T * df * g)
|
62 |
+
|
63 |
+
Don't forget to see our blog post for more details!
|
64 |
+
"""
|
65 |
+
@staticmethod
|
66 |
+
@torch_amp_custom_fwd
|
67 |
+
def forward(ctx, X : torch.Tensor,
|
68 |
+
gateW, gateW_quant, gateA, gateB, gateS,
|
69 |
+
upW, upW_quant, upA, upB, upS,
|
70 |
+
downW, downW_quant, downA, downB, downS,
|
71 |
+
_forward_function, _backward_function,
|
72 |
+
inplace = True,):
|
73 |
+
dtype = X.dtype
|
74 |
+
|
75 |
+
e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
|
76 |
+
g = matmul_lora(X, upW, upW_quant, upA, upB, upS)
|
77 |
+
h = _forward_function(e, g)
|
78 |
+
i = matmul_lora(h, downW, downW_quant, downA, downB, downS)
|
79 |
+
|
80 |
+
ctx.custom_saved_tensors = (
|
81 |
+
gateW, gateW_quant, gateS,
|
82 |
+
upW, upW_quant, upS,
|
83 |
+
downW, downW_quant, downS,
|
84 |
+
_backward_function,
|
85 |
+
)
|
86 |
+
ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB,
|
87 |
+
X, e, g)
|
88 |
+
ctx.inplace = inplace
|
89 |
+
return i
|
90 |
+
pass
|
91 |
+
|
92 |
+
|
93 |
+
@staticmethod
|
94 |
+
@torch_amp_custom_bwd
|
95 |
+
def backward(ctx, dY : torch.Tensor):
|
96 |
+
gateW, gateW_quant, gateS, upW, upW_quant, upS, downW, downW_quant, downS, \
|
97 |
+
_backward_function = ctx.custom_saved_tensors
|
98 |
+
gateA, gateB, upA, upB, downA, downB, \
|
99 |
+
X, e, g = ctx.saved_tensors
|
100 |
+
|
101 |
+
batch, seq_len, hd = X.shape
|
102 |
+
dY = dY.view(-1, dY.shape[-1])
|
103 |
+
X = X .view(-1, X .shape[-1])
|
104 |
+
e = e .view(-1, e .shape[-1])
|
105 |
+
g = g .view(-1, g .shape[-1])
|
106 |
+
dtype = X.dtype
|
107 |
+
|
108 |
+
gateA, gateB, upA, upB, downA, downB = \
|
109 |
+
gateA.to(dtype), gateB.to(dtype), upA.to(dtype), upB.to(dtype), downA.to(dtype), downB.to(dtype)
|
110 |
+
|
111 |
+
gateA, gateB, upA, upB, downA, downB = \
|
112 |
+
gateA.t(), gateB.t(), upA.t(), upB.t(), downA.t(), downB.t()
|
113 |
+
|
114 |
+
DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
|
115 |
+
DW, e, g = _backward_function(DW, e, g)
|
116 |
+
h, df, de = DW, e, g
|
117 |
+
|
118 |
+
d_downA = torch.empty_like(downA)
|
119 |
+
d_downB = torch.empty_like(downB)
|
120 |
+
d_gateA = torch.empty_like(gateA)
|
121 |
+
d_gateB = torch.empty_like(gateB)
|
122 |
+
d_upA = torch.empty_like(upA)
|
123 |
+
d_upB = torch.empty_like(upB)
|
124 |
+
|
125 |
+
# Down projection LoRA weights
|
126 |
+
# d_downA = h.t() @ (dY @ downB.t())
|
127 |
+
# d_downB = (downA.t() @ h.t()) @ dY
|
128 |
+
# d_downA *= downS
|
129 |
+
# d_downB *= downS
|
130 |
+
d_downA.addmm_(h.t(), dY @ downB.t(), alpha = downS, beta = 0)
|
131 |
+
d_downB.addmm_(downA.t() @ h.t(), dY, alpha = downS, beta = 0)
|
132 |
+
|
133 |
+
# Up projection LoRA weights
|
134 |
+
# d_upA = X.t() @ (df @ upB.t())
|
135 |
+
# d_upB = (upA.t() @ X.t()) @ df
|
136 |
+
# d_upA *= upS
|
137 |
+
# d_upB *= upS
|
138 |
+
d_upA.addmm_(X.t(), df @ upB.t(), alpha = upS, beta = 0)
|
139 |
+
d_upB.addmm_(upA.t() @ X.t(), df, alpha = upS, beta = 0)
|
140 |
+
|
141 |
+
# Gate projection LoRA weights
|
142 |
+
# d_gateA = X.t() @ (de @ gateB.t())
|
143 |
+
# d_gateB = (gateA.t() @ X.t()) @ de
|
144 |
+
# d_gateA *= gateS
|
145 |
+
# d_gateB *= gateS
|
146 |
+
d_gateA.addmm_(X.t(), de @ gateB.t(), alpha = gateS, beta = 0)
|
147 |
+
d_gateB.addmm_(gateA.t() @ X.t(), de, alpha = gateS, beta = 0)
|
148 |
+
|
149 |
+
# dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
|
150 |
+
# dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
|
151 |
+
upW = fast_dequantize(upW.t(), upW_quant)
|
152 |
+
dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None)
|
153 |
+
del upW
|
154 |
+
# dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
|
155 |
+
dX.addmm_(df @ upB.t(), upA.t(), alpha = upS)
|
156 |
+
|
157 |
+
gateW = fast_dequantize(gateW.t(), gateW_quant)
|
158 |
+
# dX += de @ gateW.t()
|
159 |
+
dX.addmm_(de, gateW.t())
|
160 |
+
del gateW
|
161 |
+
# dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
|
162 |
+
dX.addmm_(de @ gateB.t(), gateA.t(), alpha = gateS)
|
163 |
+
|
164 |
+
# gateW, gateW_quant, gateA, gateB, gateS,
|
165 |
+
# upW, upW_quant, upA, upB, upS,
|
166 |
+
# downW, downW_quant, downA, downB, downS,
|
167 |
+
return dX.view(batch, seq_len, hd), \
|
168 |
+
None, None, d_gateA.t(), d_gateB.t(), None, \
|
169 |
+
None, None, d_upA.t(), d_upB.t(), None, \
|
170 |
+
None, None, d_downA.t(), d_downB.t(), None, \
|
171 |
+
None, None, None, # _backward and _forward and inplace
|
172 |
+
pass
|
173 |
+
pass
|
174 |
+
|
175 |
+
|
176 |
+
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
|
177 |
+
def apply_lora_mlp_swiglu(self, X, inplace = True):
|
178 |
+
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
179 |
+
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
|
180 |
+
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
181 |
+
out = LoRA_MLP.apply(X,
|
182 |
+
gateW, gateW_quant, gateA, gateB, gateS,
|
183 |
+
upW, upW_quant, upA, upB, upS,
|
184 |
+
downW, downW_quant, downA, downB, downS,
|
185 |
+
swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel,
|
186 |
+
inplace,)
|
187 |
+
return out
|
188 |
+
pass
|
189 |
+
|
190 |
+
|
191 |
+
from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
|
192 |
+
def apply_lora_mlp_geglu_exact(self, X, inplace = True):
|
193 |
+
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
194 |
+
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
|
195 |
+
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
196 |
+
out = LoRA_MLP.apply(X,
|
197 |
+
gateW, gateW_quant, gateA, gateB, gateS,
|
198 |
+
upW, upW_quant, upA, upB, upS,
|
199 |
+
downW, downW_quant, downA, downB, downS,
|
200 |
+
geglu_exact_forward_kernel, geglu_exact_backward_kernel,
|
201 |
+
inplace,)
|
202 |
+
return out
|
203 |
+
pass
|
204 |
+
|
205 |
+
|
206 |
+
from .geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel
|
207 |
+
def apply_lora_mlp_geglu_approx(self, X):
|
208 |
+
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
209 |
+
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
|
210 |
+
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
211 |
+
out = LoRA_MLP.apply(X,
|
212 |
+
gateW, gateW_quant, gateA, gateB, gateS,
|
213 |
+
upW, upW_quant, upA, upB, upS,
|
214 |
+
downW, downW_quant, downA, downB, downS,
|
215 |
+
geglu_approx_forward_kernel, geglu_approx_backward_kernel,)
|
216 |
+
return out
|
217 |
+
pass
|
218 |
+
|
219 |
+
|
220 |
+
class LoRA_QKV(torch.autograd.Function):
|
221 |
+
"""
|
222 |
+
### LoRA weights
|
223 |
+
Wq = Wq + Aq @ Bq
|
224 |
+
Wk = Wk + Ak @ Bk
|
225 |
+
Wv = Wv + Av @ Bv
|
226 |
+
Q = X @ Wq = X @ Wq + X @ Aq @ Bq
|
227 |
+
K = X @ Wk = X @ Wk + X @ Ak @ Bk
|
228 |
+
V = X @ Wv = X @ Wv + X @ Av @ Bv
|
229 |
+
|
230 |
+
### Backpropagation chain rule
|
231 |
+
See our blogpost for more details.
|
232 |
+
|
233 |
+
dC/dWq = X.T @ D(Wq)
|
234 |
+
dC/dWk = X.T @ D(Wk)
|
235 |
+
dC/dWv = X.T @ D(Wv)
|
236 |
+
We then sum them all find dC/dX
|
237 |
+
|
238 |
+
### Q projection LoRA weights
|
239 |
+
dC/dAq = X.T @ D(Wq) @ B.T
|
240 |
+
dC/dBq = A.T @ X.T @ D(Wq)
|
241 |
+
|
242 |
+
### K projection LoRA weights
|
243 |
+
dC/dAk = X.T @ D(Wk) @ B.T
|
244 |
+
dC/dBk = A.T @ X.T @ D(Wk)
|
245 |
+
|
246 |
+
### V projection LoRA weights
|
247 |
+
dC/dAv = X.T @ D(Wv) @ B.T
|
248 |
+
dC/dBv = A.T @ X.T @ D(Wv)
|
249 |
+
"""
|
250 |
+
@staticmethod
|
251 |
+
@torch_amp_custom_fwd
|
252 |
+
def forward(ctx, X : torch.Tensor,
|
253 |
+
QW, QW_quant, QA, QB, QS,
|
254 |
+
KW, KW_quant, KA, KB, KS,
|
255 |
+
VW, VW_quant, VA, VB, VS,
|
256 |
+
inplace = True):
|
257 |
+
dtype = X.dtype
|
258 |
+
|
259 |
+
Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)
|
260 |
+
K = matmul_lora(X, KW, KW_quant, KA, KB, KS)
|
261 |
+
V = matmul_lora(X, VW, VW_quant, VA, VB, VS)
|
262 |
+
|
263 |
+
ctx.custom_saved_tensors = (
|
264 |
+
QW, QW_quant, QS,
|
265 |
+
KW, KW_quant, KS,
|
266 |
+
VW, VW_quant, VS,
|
267 |
+
)
|
268 |
+
ctx.save_for_backward(X, QA, QB, KA, KB, VA, VB,)
|
269 |
+
ctx.inplace = inplace
|
270 |
+
return Q, K, V
|
271 |
+
pass
|
272 |
+
|
273 |
+
@staticmethod
|
274 |
+
@torch_amp_custom_bwd
|
275 |
+
def backward(ctx, dQ, dK, dV):
|
276 |
+
QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = \
|
277 |
+
ctx.custom_saved_tensors
|
278 |
+
X, QA, QB, KA, KB, VA, VB, = ctx.saved_tensors
|
279 |
+
|
280 |
+
batch, seq_len, hd = X.shape
|
281 |
+
dQ = dQ.view(-1, dQ.shape[-1])
|
282 |
+
dK = dK.reshape(-1, dK.shape[-1]) # view doesn't work on K.T
|
283 |
+
dV = dV.view(-1, dV.shape[-1])
|
284 |
+
X = X .view(-1, X .shape[-1])
|
285 |
+
dtype = X.dtype
|
286 |
+
|
287 |
+
QA, QB, KA, KB, VA, VB = \
|
288 |
+
QA.to(dtype), QB.to(dtype), KA.to(dtype), KB.to(dtype), VA.to(dtype), VB.to(dtype)
|
289 |
+
|
290 |
+
QA, QB, KA, KB, VA, VB = \
|
291 |
+
QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()
|
292 |
+
|
293 |
+
### Weight projection LoRA weights
|
294 |
+
# See our blogpost for more details.
|
295 |
+
d_QA = torch.empty_like(QA)
|
296 |
+
d_QB = torch.empty_like(QB)
|
297 |
+
d_KA = torch.empty_like(KA)
|
298 |
+
d_KB = torch.empty_like(KB)
|
299 |
+
d_VA = torch.empty_like(VA)
|
300 |
+
d_VB = torch.empty_like(VB)
|
301 |
+
|
302 |
+
# Q Projection
|
303 |
+
# d_QA = X.t() @ (dQ @ QB.t())
|
304 |
+
# d_QB = (QA.t() @ X.t()) @ dQ
|
305 |
+
# d_QA *= QS
|
306 |
+
# d_QB *= QS
|
307 |
+
d_QA.addmm_(X.t(), dQ @ QB.t(), alpha = QS, beta = 0)
|
308 |
+
d_QB.addmm_(QA.t() @ X.t(), dQ, alpha = QS, beta = 0)
|
309 |
+
|
310 |
+
# K Projection
|
311 |
+
# d_KA = X.t() @ (dK @ KB.t())
|
312 |
+
# d_KB = (KA.t() @ X.t()) @ dK
|
313 |
+
# d_KA *= KS
|
314 |
+
# d_KB *= KS
|
315 |
+
d_KA.addmm_(X.t(), dK @ KB.t(), alpha = KS, beta = 0)
|
316 |
+
d_KB.addmm_(KA.t() @ X.t(), dK, alpha = KS, beta = 0)
|
317 |
+
|
318 |
+
# V Projection
|
319 |
+
# d_VA = X.t() @ (dV @ VB.t())
|
320 |
+
# d_VB = (VA.t() @ X.t()) @ dV
|
321 |
+
# d_VA *= VS
|
322 |
+
# d_VB *= VS
|
323 |
+
d_VA.addmm_(X.t(), dV @ VB.t(), alpha = VS, beta = 0)
|
324 |
+
d_VB.addmm_(VA.t() @ X.t(), dV, alpha = VS, beta = 0)
|
325 |
+
|
326 |
+
# Combine derivatives to find dX
|
327 |
+
# dQ
|
328 |
+
QW = fast_dequantize(QW.t(), QW_quant)
|
329 |
+
dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None)
|
330 |
+
del QW
|
331 |
+
# dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
|
332 |
+
dX.addmm_(dQ @ QB.t(), QA.t(), alpha = QS)
|
333 |
+
|
334 |
+
# dK
|
335 |
+
KW = fast_dequantize(KW.t(), KW_quant)
|
336 |
+
# dX += dK @ KW.t()
|
337 |
+
dX.addmm_(dK, KW.t())
|
338 |
+
del KW
|
339 |
+
# dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
|
340 |
+
dX.addmm_(dK @ KB.t(), KA.t(), alpha = KS)
|
341 |
+
|
342 |
+
# dV
|
343 |
+
VW = fast_dequantize(VW.t(), VW_quant)
|
344 |
+
# dX += dV @ VW.t()
|
345 |
+
dX.addmm_(dV, VW.t())
|
346 |
+
del VW
|
347 |
+
# dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
|
348 |
+
dX.addmm_(dV @ VB.t(), VA.t(), alpha = VS)
|
349 |
+
|
350 |
+
# QW, QW_quant, QA, QB, QS,
|
351 |
+
# KW, KW_quant, KA, KB, KS,
|
352 |
+
# VW, VW_quant, VA, VB, VS,
|
353 |
+
return dX.view(batch, seq_len, hd), \
|
354 |
+
None, None, d_QA.t(), d_QB.t(), None, \
|
355 |
+
None, None, d_KA.t(), d_KB.t(), None, \
|
356 |
+
None, None, d_VA.t(), d_VB.t(), None, \
|
357 |
+
None,
|
358 |
+
pass
|
359 |
+
pass
|
360 |
+
|
361 |
+
|
362 |
+
def apply_lora_qkv(self, X, inplace = True):
|
363 |
+
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
|
364 |
+
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
|
365 |
+
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
|
366 |
+
Q, K, V = LoRA_QKV.apply(X,
|
367 |
+
QW, QW_quant, QA, QB, QS,
|
368 |
+
KW, KW_quant, KA, KB, KS,
|
369 |
+
VW, VW_quant, VA, VB, VS,
|
370 |
+
inplace,
|
371 |
+
)
|
372 |
+
return Q, K, V
|
373 |
+
pass
|
374 |
+
|
375 |
+
|
376 |
+
class LoRA_W(torch.autograd.Function):
|
377 |
+
"""
|
378 |
+
### LoRA weights
|
379 |
+
Wq = Wq + Aq @ Bq
|
380 |
+
Wk = Wk + Ak @ Bk
|
381 |
+
Wv = Wv + Av @ Bv
|
382 |
+
Q = X @ Wq = X @ Wq + X @ Aq @ Bq
|
383 |
+
K = X @ Wk = X @ Wk + X @ Ak @ Bk
|
384 |
+
V = X @ Wv = X @ Wv + X @ Av @ Bv
|
385 |
+
|
386 |
+
### Backpropagation chain rule
|
387 |
+
dC/dWq = X.T @ D(Wq)
|
388 |
+
dC/dWk = X.T @ D(Wk)
|
389 |
+
dC/dWv = X.T @ D(Wv)
|
390 |
+
|
391 |
+
### Q projection LoRA weights
|
392 |
+
dC/dAq = X.T @ D(Wq) @ B.T
|
393 |
+
dC/dBq = A.T @ X.T @ D(Wq)
|
394 |
+
|
395 |
+
### K projection LoRA weights
|
396 |
+
dC/dAk = X.T @ D(Wk) @ B.T
|
397 |
+
dC/dBk = A.T @ X.T @ D(Wk)
|
398 |
+
|
399 |
+
### V projection LoRA weights
|
400 |
+
dC/dAv = X.T @ D(Wv) @ B.T
|
401 |
+
dC/dBv = A.T @ X.T @ D(Wv)
|
402 |
+
"""
|
403 |
+
@staticmethod
|
404 |
+
@torch_amp_custom_fwd
|
405 |
+
def forward(ctx, X : torch.Tensor,
|
406 |
+
W, W_quant, A, B, S):
|
407 |
+
dtype = X.dtype
|
408 |
+
XW = matmul_lora(X, W, W_quant, A, B, S)
|
409 |
+
ctx.custom_saved_tensors = (W, W_quant, S,)
|
410 |
+
ctx.save_for_backward(A, B, X)
|
411 |
+
return XW
|
412 |
+
pass
|
413 |
+
|
414 |
+
@staticmethod
|
415 |
+
@torch_amp_custom_bwd
|
416 |
+
def backward(ctx, dY : torch.Tensor):
|
417 |
+
W, W_quant, S = ctx.custom_saved_tensors
|
418 |
+
A, B, X = ctx.saved_tensors
|
419 |
+
|
420 |
+
batch, seq_len, hd = X.shape
|
421 |
+
dY = dY.reshape(-1, dY.shape[-1]) # Must be reshape
|
422 |
+
X = X .reshape(-1, X .shape[-1]) # Must be reshape
|
423 |
+
dtype = X.dtype
|
424 |
+
|
425 |
+
A, B = A.to(dtype), B.to(dtype)
|
426 |
+
|
427 |
+
A, B = A.t(), B.t()
|
428 |
+
|
429 |
+
d_A = torch.empty_like(A)
|
430 |
+
d_B = torch.empty_like(B)
|
431 |
+
|
432 |
+
### Weight projection LoRA weights
|
433 |
+
# Weight projection
|
434 |
+
# d_A = X.t() @ (dY @ B.t())
|
435 |
+
# d_B = (A.t() @ X.t()) @ dY
|
436 |
+
# d_A *= S
|
437 |
+
# d_B *= S
|
438 |
+
d_A.addmm_(X.t(), dY @ B.t(), alpha = S, beta = 0)
|
439 |
+
d_B.addmm_(A.t() @ X.t(), dY, alpha = S, beta = 0)
|
440 |
+
|
441 |
+
# Get derivative for dX
|
442 |
+
W = fast_dequantize(W.t(), W_quant)
|
443 |
+
dX = dY @ W.t()
|
444 |
+
del W
|
445 |
+
# dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
|
446 |
+
dX.addmm_(dY @ B.t(), A.t(), alpha = S)
|
447 |
+
|
448 |
+
# W, W_quant, A, B, S
|
449 |
+
return dX.view(batch, seq_len, hd), \
|
450 |
+
None, None, d_A.t(), d_B.t(), None
|
451 |
+
pass
|
452 |
+
pass
|
453 |
+
|
454 |
+
|
455 |
+
def apply_lora_o(self, X):
|
456 |
+
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
|
457 |
+
O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
|
458 |
+
return O
|
459 |
+
pass
|
460 |
+
|
461 |
+
|
462 |
+
IDENTITY_DROPOUT = torch.nn.Identity
|
463 |
+
@torch._disable_dynamo
|
464 |
+
def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
465 |
+
raise NotImplementedError(
|
466 |
+
"Unsloth: Currently not supported yet - reshaping done incorrectly"
|
467 |
+
)
|
468 |
+
self._check_forward_args(x, *args, **kwargs)
|
469 |
+
adapter_names = kwargs.pop("adapter_names", None)
|
470 |
+
|
471 |
+
if self.disable_adapters:
|
472 |
+
if self.merged:
|
473 |
+
self.unmerge()
|
474 |
+
result = self.base_layer(x, *args, **kwargs)
|
475 |
+
elif adapter_names is not None:
|
476 |
+
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
|
477 |
+
elif self.merged:
|
478 |
+
result = self.base_layer(x, *args, **kwargs)
|
479 |
+
else:
|
480 |
+
# Fastpath
|
481 |
+
if len(self.active_adapters) == 1:
|
482 |
+
active_adapter = self.active_adapters[0]
|
483 |
+
if active_adapter not in self.lora_A.keys(): return self.base_layer(x, *args, **kwargs)
|
484 |
+
|
485 |
+
dropout = self.lora_dropout[active_adapter]
|
486 |
+
if isinstance(dropout, IDENTITY_DROPOUT) and not self.use_dora[active_adapter]:
|
487 |
+
lora_A = self.lora_A[active_adapter].weight
|
488 |
+
lora_B = self.lora_B[active_adapter].weight
|
489 |
+
scaling = self.scaling[active_adapter]
|
490 |
+
W = self.base_layer.weight
|
491 |
+
return LoRA_W.apply(x, W, QUANT_STATE(W), lora_A, lora_B, scaling)
|
492 |
+
pass
|
493 |
+
pass
|
494 |
+
|
495 |
+
result = self.base_layer(x, *args, **kwargs)
|
496 |
+
# As per Tim Dettmers, for 4bit, we need to defensively clone here.
|
497 |
+
# The reason is that in some cases, an error can occur that backprop
|
498 |
+
# does not work on a manipulated view. This issue may be solved with
|
499 |
+
# newer PyTorch versions but this would need extensive testing to be
|
500 |
+
# sure.
|
501 |
+
result = result.clone()
|
502 |
+
|
503 |
+
for active_adapter in self.active_adapters:
|
504 |
+
if active_adapter not in self.lora_A.keys():
|
505 |
+
continue
|
506 |
+
lora_A = self.lora_A[active_adapter]
|
507 |
+
lora_B = self.lora_B[active_adapter]
|
508 |
+
dropout = self.lora_dropout[active_adapter]
|
509 |
+
scaling = self.scaling[active_adapter]
|
510 |
+
|
511 |
+
requires_conversion = not torch.is_autocast_enabled()
|
512 |
+
if requires_conversion:
|
513 |
+
expected_dtype = result.dtype
|
514 |
+
x = x.to(lora_A.weight.dtype)
|
515 |
+
|
516 |
+
if not self.use_dora[active_adapter]:
|
517 |
+
result = result + lora_B(lora_A(dropout(x))) * scaling
|
518 |
+
else:
|
519 |
+
if isinstance(dropout, torch.nn.Identity) or not self.training:
|
520 |
+
base_result = result
|
521 |
+
else:
|
522 |
+
x = dropout(x)
|
523 |
+
base_result = None
|
524 |
+
|
525 |
+
result = result + self.lora_magnitude_vector[active_adapter](
|
526 |
+
x,
|
527 |
+
lora_A=lora_A,
|
528 |
+
lora_B=lora_B,
|
529 |
+
scaling=scaling,
|
530 |
+
base_layer=self.get_base_layer(),
|
531 |
+
base_result=base_result,
|
532 |
+
)
|
533 |
+
if requires_conversion:
|
534 |
+
result = result.to(expected_dtype)
|
535 |
+
|
536 |
+
return result
|
537 |
+
pass
|
torch-ext/unsloth_kernels/flex_attention.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from functools import lru_cache
|
17 |
+
from transformers.models.llama.modeling_llama import logger
|
18 |
+
import os
|
19 |
+
|
20 |
+
torch_compile_options = {
|
21 |
+
"epilogue_fusion" : True,
|
22 |
+
"max_autotune" : True,
|
23 |
+
"shape_padding" : True,
|
24 |
+
"trace.enabled" : os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1",
|
25 |
+
"triton.cudagraphs" : False,
|
26 |
+
}
|
27 |
+
|
28 |
+
# Flex Attention supported from torch 2.5 onwards only
|
29 |
+
try:
|
30 |
+
from torch.nn.attention.flex_attention import (
|
31 |
+
flex_attention as _flex_attention,
|
32 |
+
create_block_mask as _create_block_mask,
|
33 |
+
)
|
34 |
+
_flex_attention = torch.compile(_flex_attention, dynamic = True, options = torch_compile_options)
|
35 |
+
HAS_FLEX_ATTENTION = False
|
36 |
+
except:
|
37 |
+
HAS_FLEX_ATTENTION = False
|
38 |
+
pass
|
39 |
+
|
40 |
+
|
41 |
+
if not HAS_FLEX_ATTENTION:
|
42 |
+
|
43 |
+
# Logit softcapping
|
44 |
+
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
|
45 |
+
def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
|
46 |
+
n_heads = self.config.num_attention_heads
|
47 |
+
head_dim = self.head_dim
|
48 |
+
n_kv_heads = self.config.num_key_value_heads
|
49 |
+
n_groups = self.num_key_value_groups
|
50 |
+
|
51 |
+
# Grouped query attention
|
52 |
+
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
|
53 |
+
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
|
54 |
+
K = K.reshape(bsz, n_heads, q_len, head_dim)
|
55 |
+
V = V.reshape(bsz, n_heads, q_len, head_dim)
|
56 |
+
|
57 |
+
# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
|
58 |
+
# Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
|
59 |
+
# We default to using the config file itself
|
60 |
+
# s = self.config.hidden_size // self.config.num_attention_heads
|
61 |
+
s = self.config.query_pre_attn_scalar
|
62 |
+
t = self.config.attn_logit_softcapping
|
63 |
+
|
64 |
+
Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
|
65 |
+
A = torch.matmul(Q, K.transpose(2, 3))
|
66 |
+
A = t * torch.tanh(A / t) # Logit softcapping
|
67 |
+
A += causal_mask[:q_len, :q_len]
|
68 |
+
# Much slower in torch compile!
|
69 |
+
# A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
|
70 |
+
A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
|
71 |
+
A = torch.matmul(A, V)
|
72 |
+
A = A.transpose(1, 2).contiguous()
|
73 |
+
A = A.reshape(bsz, q_len, n_heads*head_dim)
|
74 |
+
return A
|
75 |
+
pass
|
76 |
+
|
77 |
+
create_flex_attention_causal_mask = None
|
78 |
+
create_flex_attention_sliding_window_mask = None
|
79 |
+
else:
|
80 |
+
# See https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
|
81 |
+
# for more examples
|
82 |
+
# BSD 3-Clause License Copyright (c) 2023, Driss Guessous, Horace He et al
|
83 |
+
import functools, math
|
84 |
+
|
85 |
+
def generate_tanh_softcap(t):
|
86 |
+
def tanh_softcap(x, b, h, q_idx, kv_idx):
|
87 |
+
return t * torch.tanh(x / t)
|
88 |
+
return tanh_softcap
|
89 |
+
pass
|
90 |
+
def causal_masker(b, h, q_idx, kv_idx):
|
91 |
+
return q_idx >= kv_idx
|
92 |
+
pass
|
93 |
+
|
94 |
+
@functools.lru_cache
|
95 |
+
def sliding_window_masker(size = 4096):
|
96 |
+
def sliding_window(b, h, q_idx, kv_idx):
|
97 |
+
causal_mask = q_idx >= kv_idx
|
98 |
+
window_mask = q_idx - kv_idx <= size
|
99 |
+
return causal_mask & window_mask
|
100 |
+
return sliding_window
|
101 |
+
pass
|
102 |
+
|
103 |
+
@functools.lru_cache
|
104 |
+
def create_block_mask(mask, n = 128):
|
105 |
+
return _create_block_mask(
|
106 |
+
mask, 1, 1, n, n,
|
107 |
+
BLOCK_SIZE = 128,
|
108 |
+
_compile = True,
|
109 |
+
)
|
110 |
+
pass
|
111 |
+
|
112 |
+
def create_flex_attention_causal_mask(max_seq_length = 8192):
|
113 |
+
causal_mask = create_block_mask(causal_masker, max_seq_length)
|
114 |
+
return causal_mask
|
115 |
+
pass
|
116 |
+
|
117 |
+
def create_flex_attention_sliding_window_mask(max_seq_length = 8192, sliding_window = 4096):
|
118 |
+
sliding_masker = sliding_window_masker(sliding_window)
|
119 |
+
causal_mask = create_block_mask(sliding_masker, max_seq_length)
|
120 |
+
return causal_mask
|
121 |
+
pass
|
122 |
+
|
123 |
+
@functools.lru_cache
|
124 |
+
def flex_attention(s, t):
|
125 |
+
scale = 1.0 / math.sqrt(s)
|
126 |
+
score_mod = generate_tanh_softcap(t)
|
127 |
+
return functools.partial(
|
128 |
+
_flex_attention, score_mod = score_mod, scale = scale, enable_gqa = True,
|
129 |
+
)
|
130 |
+
pass
|
131 |
+
|
132 |
+
def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
|
133 |
+
n_heads = self.config.num_attention_heads
|
134 |
+
head_dim = self.head_dim
|
135 |
+
s = self.config.query_pre_attn_scalar
|
136 |
+
t = self.config.attn_logit_softcapping
|
137 |
+
fx = flex_attention(s, t)
|
138 |
+
A = fx(query = Q, key = K, value = V, block_mask = causal_mask)
|
139 |
+
A = A.transpose(1, 2).contiguous()
|
140 |
+
A = A.reshape(bsz, q_len, n_heads*head_dim)
|
141 |
+
return A
|
142 |
+
pass
|
143 |
+
pass
|
144 |
+
|
145 |
+
|
146 |
+
torch_matmul = torch.matmul
|
147 |
+
torch_tanh = torch.tanh
|
148 |
+
torch_nn_functional_softmax = torch.nn.functional.softmax
|
149 |
+
def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
|
150 |
+
n_heads = self.config.num_attention_heads
|
151 |
+
head_dim = self.head_dim
|
152 |
+
n_kv_heads = self.config.num_key_value_heads
|
153 |
+
n_groups = self.num_key_value_groups
|
154 |
+
|
155 |
+
# Grouped query attention
|
156 |
+
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
|
157 |
+
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
|
158 |
+
K = K.reshape(bsz, n_heads, q_len, head_dim)
|
159 |
+
V = V.reshape(bsz, n_heads, q_len, head_dim)
|
160 |
+
|
161 |
+
# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
|
162 |
+
# Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
|
163 |
+
# We default to using the config file itself
|
164 |
+
# s = self.config.hidden_size // self.config.num_attention_heads
|
165 |
+
s = self.config.query_pre_attn_scalar
|
166 |
+
t = self.config.attn_logit_softcapping
|
167 |
+
|
168 |
+
Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
|
169 |
+
A = torch_matmul(Q, K.transpose(2, 3))
|
170 |
+
|
171 |
+
# Logit softcapping
|
172 |
+
A /= t; torch_tanh(A, out = A); A *= t;
|
173 |
+
A += causal_mask[:q_len, :q_len]
|
174 |
+
# Much slower in torch compile!
|
175 |
+
# A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
|
176 |
+
A = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
|
177 |
+
A = torch_matmul(A, V)
|
178 |
+
A = A.transpose(1, 2).contiguous()
|
179 |
+
A = A.reshape(bsz, q_len, n_heads*head_dim)
|
180 |
+
return A
|
181 |
+
pass
|
torch-ext/unsloth_kernels/geglu.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import triton
|
16 |
+
import triton.language as tl
|
17 |
+
import torch
|
18 |
+
from .utils import (
|
19 |
+
calculate_settings,
|
20 |
+
triton_tanh,
|
21 |
+
torch_cuda_device,
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
@triton.jit
|
26 |
+
def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
|
27 |
+
block_idx = tl.program_id(0)
|
28 |
+
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
29 |
+
mask = offsets < n_elements
|
30 |
+
|
31 |
+
# f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
|
32 |
+
# h = f * up
|
33 |
+
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
34 |
+
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
|
35 |
+
|
36 |
+
f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
|
37 |
+
f_row = f_row.to(g_row.dtype) # Exact copy from HF
|
38 |
+
h_row = f_row * g_row
|
39 |
+
|
40 |
+
# Store h
|
41 |
+
tl.store(h + offsets, h_row, mask = mask)
|
42 |
+
pass
|
43 |
+
|
44 |
+
|
45 |
+
def geglu_exact_forward_kernel(gate, up):
|
46 |
+
batch, seq_len, hd = gate.shape
|
47 |
+
n_elements = gate.numel()
|
48 |
+
device = gate.device
|
49 |
+
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
|
50 |
+
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
51 |
+
with torch_cuda_device(device):
|
52 |
+
_exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
|
53 |
+
return out
|
54 |
+
pass
|
55 |
+
|
56 |
+
|
57 |
+
@triton.jit
|
58 |
+
def _exact_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
|
59 |
+
"""
|
60 |
+
f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
|
61 |
+
h = f * up
|
62 |
+
|
63 |
+
df/de (with help of Wolfram :)
|
64 |
+
df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
|
65 |
+
|
66 |
+
Reuse via
|
67 |
+
f = 1/2 * (1 + erf(1/sqrt(2) * e)) * e
|
68 |
+
"""
|
69 |
+
block_idx = tl.program_id(0)
|
70 |
+
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
71 |
+
mask = offsets < n_elements
|
72 |
+
|
73 |
+
DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
|
74 |
+
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
75 |
+
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
|
76 |
+
|
77 |
+
# Break e_row away for re-use
|
78 |
+
# f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
|
79 |
+
f_partial_row = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
|
80 |
+
f_row = f_partial_row * e_row
|
81 |
+
|
82 |
+
f_row = f_row.to(DW_row.dtype)
|
83 |
+
# h = f * g
|
84 |
+
h_row = f_row * g_row
|
85 |
+
# df = DW * f
|
86 |
+
df_row = DW_row * f_row
|
87 |
+
# dg = DW * g
|
88 |
+
dg_row = DW_row * g_row
|
89 |
+
|
90 |
+
# df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
|
91 |
+
t = 0.3989422804014327 # 1/sqrt(2*pi)
|
92 |
+
df_de = f_partial_row + t * e_row * tl.exp(-0.5 * e_row * e_row)
|
93 |
+
|
94 |
+
de_row = dg_row.to(tl.float32) * df_de
|
95 |
+
de_row = de_row.to(DW_row.dtype)
|
96 |
+
|
97 |
+
# Store derivatives in buffers
|
98 |
+
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
|
99 |
+
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
|
100 |
+
tl.store(g + offsets, de_row, mask = mask) # de
|
101 |
+
pass
|
102 |
+
|
103 |
+
|
104 |
+
def geglu_exact_backward_kernel(DW, e, g):
|
105 |
+
batch_seq_len, hd = e.shape
|
106 |
+
n_elements = e.numel()
|
107 |
+
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
108 |
+
with torch_cuda_device(e.device):
|
109 |
+
_exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
|
110 |
+
return DW, e, g
|
111 |
+
pass
|
112 |
+
|
113 |
+
|
114 |
+
@triton.jit
|
115 |
+
def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
|
116 |
+
block_idx = tl.program_id(0)
|
117 |
+
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
118 |
+
mask = offsets < n_elements
|
119 |
+
|
120 |
+
# f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
|
121 |
+
# f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
|
122 |
+
# h = f * up
|
123 |
+
s = 0.7978845608028654 # math.sqrt(2 / math.pi)
|
124 |
+
|
125 |
+
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
126 |
+
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
|
127 |
+
|
128 |
+
f_row = 0.5 * e_row * (
|
129 |
+
triton_tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) \
|
130 |
+
+ 1.0
|
131 |
+
)
|
132 |
+
f_row = f_row.to(g_row.dtype) # Exact copy from HF
|
133 |
+
h_row = f_row * g_row
|
134 |
+
|
135 |
+
# Store h
|
136 |
+
tl.store(h + offsets, h_row, mask = mask)
|
137 |
+
pass
|
138 |
+
|
139 |
+
|
140 |
+
def geglu_approx_forward_kernel(gate, up):
|
141 |
+
batch, seq_len, hd = gate.shape
|
142 |
+
n_elements = gate.numel()
|
143 |
+
device = gate.device
|
144 |
+
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
|
145 |
+
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
146 |
+
with torch_cuda_device(device):
|
147 |
+
_approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
|
148 |
+
return out
|
149 |
+
pass
|
150 |
+
|
151 |
+
|
152 |
+
@triton.jit
|
153 |
+
def _approx_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
|
154 |
+
"""
|
155 |
+
f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
|
156 |
+
h = f * up
|
157 |
+
|
158 |
+
df/de (with help from https://arxiv.org/pdf/2305.12073.pdf :))
|
159 |
+
df/de = 1/2 * [1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )] +
|
160 |
+
1/2 * sech^2 [ sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ] * \
|
161 |
+
( sqrt(2/pi) * x * (1 + 0.044715 * x^2 * 3 ) )
|
162 |
+
|
163 |
+
Notice sech^2(x) = 1 - tanh^2(x)
|
164 |
+
So reuse tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )
|
165 |
+
|
166 |
+
See https://www.desmos.com/calculator/nqprfoni6x
|
167 |
+
"""
|
168 |
+
block_idx = tl.program_id(0)
|
169 |
+
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
170 |
+
mask = offsets < n_elements
|
171 |
+
|
172 |
+
DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
|
173 |
+
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
174 |
+
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
|
175 |
+
|
176 |
+
# See https://www.desmos.com/calculator/nqprfoni6x
|
177 |
+
s = 0.7978845608028654 # math.sqrt(2 / math.pi)
|
178 |
+
a = s * e_row # a = sqrt(2 / pi) * x
|
179 |
+
b = a * 0.044715 * e_row * e_row # b = a * 0.044715 * x^2
|
180 |
+
T = 1.0 + triton_tanh(a + b)
|
181 |
+
T2 = 0.5 * T
|
182 |
+
# Q = 0.5 * -T * (T - 2.0) * (a + 3.0 * b)
|
183 |
+
Q2 = -T2 * (T - 2.0) * (a + 3.0 * b)
|
184 |
+
df_de = T2 + Q2 # 1/2 * (T + Q)
|
185 |
+
|
186 |
+
# f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
|
187 |
+
f_row = T2 * e_row
|
188 |
+
f_row = f_row.to(DW_row.dtype)
|
189 |
+
# h = f * g
|
190 |
+
h_row = f_row * g_row
|
191 |
+
# df = DW * f
|
192 |
+
df_row = DW_row * f_row
|
193 |
+
# dg = DW * g
|
194 |
+
dg_row = DW_row * g_row
|
195 |
+
|
196 |
+
de_row = dg_row.to(tl.float32) * df_de
|
197 |
+
de_row = de_row.to(DW_row.dtype)
|
198 |
+
|
199 |
+
# Store derivatives in buffers
|
200 |
+
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
|
201 |
+
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
|
202 |
+
tl.store(g + offsets, de_row, mask = mask) # de
|
203 |
+
pass
|
204 |
+
|
205 |
+
|
206 |
+
def geglu_approx_backward_kernel(DW, e, g):
|
207 |
+
batch_seq_len, hd = e.shape
|
208 |
+
n_elements = e.numel()
|
209 |
+
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
210 |
+
with torch_cuda_device(e.device):
|
211 |
+
_approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
|
212 |
+
return DW, e, g
|
213 |
+
pass
|
torch-ext/unsloth_kernels/layernorm.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
+
# Copyright 2024-present Andrej Karpathy & the llm.c team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import triton
|
17 |
+
import triton.language as tl
|
18 |
+
import torch
|
19 |
+
from .utils import calculate_settings, torch_cuda_device
|
20 |
+
from unsloth_zoo.patching_utils import (
|
21 |
+
patch_layernorm,
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
@triton.jit
|
26 |
+
def layernorm_forward(
|
27 |
+
Y, Y_row_stride,
|
28 |
+
X, X_row_stride,
|
29 |
+
W,
|
30 |
+
b,
|
31 |
+
r,
|
32 |
+
mu,
|
33 |
+
n_cols : tl.constexpr,
|
34 |
+
eps : tl.constexpr,
|
35 |
+
BLOCK_SIZE : tl.constexpr
|
36 |
+
):
|
37 |
+
row_idx = tl.program_id(0)
|
38 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
39 |
+
mask = col_offsets < n_cols
|
40 |
+
|
41 |
+
Y += row_idx * Y_row_stride
|
42 |
+
X += row_idx * X_row_stride
|
43 |
+
r += row_idx
|
44 |
+
mu += row_idx
|
45 |
+
|
46 |
+
# According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
|
47 |
+
# are in float32!
|
48 |
+
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
49 |
+
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
50 |
+
b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
|
51 |
+
|
52 |
+
mean_X = tl.sum(X_row, axis = 0) / n_cols
|
53 |
+
# (X[0] - mean) == -mean so we need to mask it out
|
54 |
+
XX = tl.where(mask, X_row - mean_X, 0)
|
55 |
+
row_var = tl.sum(XX * XX, axis = 0) / n_cols
|
56 |
+
inv_var = tl.math.rsqrt(row_var + eps)
|
57 |
+
tl.store (r, inv_var)
|
58 |
+
tl.store (mu, mean_X)
|
59 |
+
output = (XX * inv_var) * W_row + b_row
|
60 |
+
tl.store(Y + col_offsets, output, mask = mask)
|
61 |
+
pass
|
62 |
+
|
63 |
+
|
64 |
+
@triton.jit
|
65 |
+
def layernorm_backward(
|
66 |
+
dY, dY_row_stride,
|
67 |
+
X, X_row_stride,
|
68 |
+
W,
|
69 |
+
b,
|
70 |
+
r,
|
71 |
+
mu,
|
72 |
+
n_cols : tl.constexpr,
|
73 |
+
eps : tl.constexpr,
|
74 |
+
BLOCK_SIZE : tl.constexpr
|
75 |
+
):
|
76 |
+
# Approximately follows https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
77 |
+
row_idx = tl.program_id(0)
|
78 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
79 |
+
mask = col_offsets < n_cols
|
80 |
+
|
81 |
+
dY += row_idx * dY_row_stride
|
82 |
+
X += row_idx * X_row_stride
|
83 |
+
r += row_idx
|
84 |
+
mu += row_idx
|
85 |
+
|
86 |
+
# According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
|
87 |
+
# are in float32!
|
88 |
+
dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
|
89 |
+
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
90 |
+
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
91 |
+
b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
|
92 |
+
|
93 |
+
inv_var = tl.load(r) .to(tl.float32)
|
94 |
+
mean = tl.load(mu).to(tl.float32)
|
95 |
+
normed = (X_row - mean) * inv_var
|
96 |
+
dY_W = dY_row * W_row
|
97 |
+
dX_row = dY_W - tl.sum(dY_W, axis = 0) / n_cols - normed * tl.sum(dY_W * normed, axis = 0) / n_cols
|
98 |
+
dX_row = dX_row * inv_var
|
99 |
+
tl.store(dY + col_offsets, dX_row, mask = mask)
|
100 |
+
pass
|
101 |
+
|
102 |
+
|
103 |
+
class Fast_Layernorm(torch.autograd.Function):
|
104 |
+
@staticmethod
|
105 |
+
def forward(ctx, X, W, b, eps):
|
106 |
+
shape = X.shape
|
107 |
+
dim = shape[-1]
|
108 |
+
X = X.view(-1, dim)
|
109 |
+
n_rows, n_cols = X.shape
|
110 |
+
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
111 |
+
device = X.device
|
112 |
+
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)
|
113 |
+
r = torch.empty(n_rows, dtype = torch.float32, device = device)
|
114 |
+
mu = torch.empty(n_rows, dtype = torch.float32, device = device)
|
115 |
+
|
116 |
+
with torch_cuda_device(device):
|
117 |
+
layernorm_forward[(n_rows,)](
|
118 |
+
Y, Y.stride(0),
|
119 |
+
X, X.stride(0),
|
120 |
+
W,
|
121 |
+
b,
|
122 |
+
r,
|
123 |
+
mu,
|
124 |
+
n_cols, eps,
|
125 |
+
BLOCK_SIZE = BLOCK_SIZE,
|
126 |
+
num_warps = num_warps,
|
127 |
+
)
|
128 |
+
ctx.eps = eps
|
129 |
+
ctx.BLOCK_SIZE = BLOCK_SIZE
|
130 |
+
ctx.num_warps = num_warps
|
131 |
+
ctx.save_for_backward(X, W, b, r, mu)
|
132 |
+
return Y.view(*shape)
|
133 |
+
pass
|
134 |
+
|
135 |
+
@staticmethod
|
136 |
+
def backward(ctx, dY):
|
137 |
+
shape = dY.shape
|
138 |
+
dim = shape[-1]
|
139 |
+
dY = dY.view(-1, dim)
|
140 |
+
X, W, b, r, mu = ctx.saved_tensors
|
141 |
+
n_rows, n_cols = dY.shape
|
142 |
+
|
143 |
+
with torch_cuda_device(dY.device):
|
144 |
+
layernorm_backward[(n_rows,)](
|
145 |
+
dY, dY.stride(0),
|
146 |
+
X, X .stride(0),
|
147 |
+
W,
|
148 |
+
b,
|
149 |
+
r,
|
150 |
+
mu,
|
151 |
+
n_cols, ctx.eps,
|
152 |
+
BLOCK_SIZE = ctx.BLOCK_SIZE,
|
153 |
+
num_warps = ctx.num_warps,
|
154 |
+
)
|
155 |
+
dX = dY.view(*shape)
|
156 |
+
return dX, None, None, None, None
|
157 |
+
pass
|
158 |
+
pass
|
159 |
+
|
160 |
+
|
161 |
+
def fast_layernorm(layernorm, X):
|
162 |
+
assert(layernorm.elementwise_affine is True)
|
163 |
+
W = layernorm.weight
|
164 |
+
bias = layernorm.bias
|
165 |
+
eps = layernorm.variance_epsilon if \
|
166 |
+
hasattr(layernorm, "variance_epsilon") \
|
167 |
+
else layernorm.eps
|
168 |
+
out = Fast_Layernorm.apply(X, W, bias, eps)
|
169 |
+
return out
|
170 |
+
pass
|
torch-ext/unsloth_kernels/rms_layernorm.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import triton
|
16 |
+
import triton.language as tl
|
17 |
+
import torch
|
18 |
+
from .utils import calculate_settings, torch_cuda_device
|
19 |
+
|
20 |
+
@triton.jit
|
21 |
+
def _rms_layernorm_forward(
|
22 |
+
Y, Y_row_stride,
|
23 |
+
X, X_row_stride,
|
24 |
+
W, W_row_stride,
|
25 |
+
r, r_row_stride : tl.constexpr,
|
26 |
+
n_cols : tl.constexpr,
|
27 |
+
eps : tl.constexpr,
|
28 |
+
BLOCK_SIZE : tl.constexpr,
|
29 |
+
):
|
30 |
+
"""
|
31 |
+
Fast RMS Layernorm kernel
|
32 |
+
Inspiration from a Triton tutorial:
|
33 |
+
https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
34 |
+
"""
|
35 |
+
row_idx = tl.program_id(0)
|
36 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
37 |
+
mask = col_offsets < n_cols
|
38 |
+
|
39 |
+
Y += row_idx * Y_row_stride
|
40 |
+
X += row_idx * X_row_stride
|
41 |
+
r += row_idx * r_row_stride
|
42 |
+
|
43 |
+
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
44 |
+
W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)
|
45 |
+
|
46 |
+
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
|
47 |
+
inv_var = tl.math.rsqrt(row_var + eps)
|
48 |
+
tl.store(r, inv_var)
|
49 |
+
normed = X_row * inv_var
|
50 |
+
normed = normed.to(W_row.dtype) # Exact copy from HF
|
51 |
+
output = normed * W_row
|
52 |
+
tl.store(Y + col_offsets, output, mask = mask)
|
53 |
+
pass
|
54 |
+
|
55 |
+
|
56 |
+
def _rms_layernorm_backward(
|
57 |
+
dY, dY_row_stride,
|
58 |
+
dX, dX_row_stride,
|
59 |
+
X, X_row_stride,
|
60 |
+
W, W_row_stride,
|
61 |
+
r, r_row_stride : tl.constexpr,
|
62 |
+
# dW, dW_row_stride,
|
63 |
+
n_cols : tl.constexpr,
|
64 |
+
eps : tl.constexpr,
|
65 |
+
GEMMA : tl.constexpr,
|
66 |
+
BLOCK_SIZE : tl.constexpr,
|
67 |
+
):
|
68 |
+
"""
|
69 |
+
Fast RMS Layernorm kernel for the backward pass
|
70 |
+
Inspiration from a Triton tutorial:
|
71 |
+
https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
72 |
+
"""
|
73 |
+
row_idx = tl.program_id(0)
|
74 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
75 |
+
mask = col_offsets < n_cols
|
76 |
+
|
77 |
+
dY += row_idx * dY_row_stride
|
78 |
+
X += row_idx * X_row_stride
|
79 |
+
r += row_idx * r_row_stride
|
80 |
+
|
81 |
+
if GEMMA: dX += row_idx * dY_row_stride
|
82 |
+
else: dX = dY
|
83 |
+
|
84 |
+
dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
|
85 |
+
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
86 |
+
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
87 |
+
|
88 |
+
# Get saved row variance
|
89 |
+
inv_var = tl.load(r).to(tl.float32)
|
90 |
+
normed = X_row * inv_var
|
91 |
+
|
92 |
+
if GEMMA: dY_W = dY_row * (W_row + 1.0)
|
93 |
+
else: dY_W = dY_row * W_row
|
94 |
+
|
95 |
+
rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
|
96 |
+
output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
|
97 |
+
tl.store(dX + col_offsets, output, mask = mask)
|
98 |
+
pass
|
99 |
+
_rms_layernorm_backward = triton.jit(_rms_layernorm_backward)
|
100 |
+
_rms_layernorm_backward = triton.heuristics(
|
101 |
+
{
|
102 |
+
"GEMMA": lambda args: bool(args["GEMMA"]),
|
103 |
+
}
|
104 |
+
)(_rms_layernorm_backward)
|
105 |
+
|
106 |
+
|
107 |
+
@triton.jit
|
108 |
+
def _gemma_rms_layernorm_forward(
|
109 |
+
Y, Y_row_stride,
|
110 |
+
X, X_row_stride,
|
111 |
+
W, W_row_stride,
|
112 |
+
r, r_row_stride : tl.constexpr,
|
113 |
+
n_cols : tl.constexpr,
|
114 |
+
eps : tl.constexpr,
|
115 |
+
BLOCK_SIZE : tl.constexpr,
|
116 |
+
):
|
117 |
+
# Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31
|
118 |
+
# and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33
|
119 |
+
# exactly. Essentially all in float32!
|
120 |
+
row_idx = tl.program_id(0)
|
121 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
122 |
+
mask = col_offsets < n_cols
|
123 |
+
|
124 |
+
Y += row_idx * Y_row_stride
|
125 |
+
X += row_idx * X_row_stride
|
126 |
+
r += row_idx * r_row_stride
|
127 |
+
|
128 |
+
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
129 |
+
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
130 |
+
|
131 |
+
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
|
132 |
+
inv_var = tl.math.rsqrt(row_var + eps)
|
133 |
+
tl.store(r, inv_var)
|
134 |
+
normed = X_row * inv_var
|
135 |
+
output = normed * (W_row + 1.0)
|
136 |
+
|
137 |
+
tl.store(Y + col_offsets, output, mask = mask)
|
138 |
+
pass
|
139 |
+
|
140 |
+
|
141 |
+
class Fast_RMS_Layernorm(torch.autograd.Function):
|
142 |
+
@staticmethod
|
143 |
+
def forward(ctx, X : torch.Tensor, W : torch.Tensor, eps : float, gemma : bool = False):
|
144 |
+
shape = X.shape
|
145 |
+
dim : int = shape[-1]
|
146 |
+
X = X.view(-1, dim)
|
147 |
+
n_rows : int
|
148 |
+
n_cols : int
|
149 |
+
n_rows, n_cols = X.shape
|
150 |
+
BLOCK_SIZE : int
|
151 |
+
num_warps : int
|
152 |
+
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
153 |
+
device = X.device
|
154 |
+
|
155 |
+
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = device)
|
156 |
+
r = torch.empty(n_rows, dtype = torch.float32, device = device)
|
157 |
+
|
158 |
+
fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
|
159 |
+
with torch_cuda_device(device):
|
160 |
+
fx[(n_rows,)](
|
161 |
+
Y, Y.stride(0),
|
162 |
+
X, X.stride(0),
|
163 |
+
W, W.stride(0),
|
164 |
+
r, r.stride(0),
|
165 |
+
n_cols, eps,
|
166 |
+
BLOCK_SIZE = BLOCK_SIZE,
|
167 |
+
num_warps = num_warps,
|
168 |
+
)
|
169 |
+
ctx.eps = eps
|
170 |
+
ctx.BLOCK_SIZE = BLOCK_SIZE
|
171 |
+
ctx.num_warps = num_warps
|
172 |
+
ctx.GEMMA = gemma
|
173 |
+
ctx.save_for_backward(X, W, r)
|
174 |
+
return Y.view(*shape)
|
175 |
+
pass
|
176 |
+
|
177 |
+
@staticmethod
|
178 |
+
def backward(ctx, dY : torch.Tensor):
|
179 |
+
shape = dY.shape
|
180 |
+
dim : int = shape[-1]
|
181 |
+
dY = dY.view(-1, dim)
|
182 |
+
X, W, r = ctx.saved_tensors
|
183 |
+
n_rows : int
|
184 |
+
n_cols : int
|
185 |
+
n_rows, n_cols = dY.shape
|
186 |
+
# dW = X
|
187 |
+
dX = torch.empty_like(dY) if ctx.GEMMA else dY
|
188 |
+
|
189 |
+
with torch_cuda_device(dY.device):
|
190 |
+
_rms_layernorm_backward[(n_rows,)](
|
191 |
+
dY, dY.stride(0),
|
192 |
+
dX, dX.stride(0),
|
193 |
+
X, X .stride(0),
|
194 |
+
W, W .stride(0),
|
195 |
+
r, r .stride(0),
|
196 |
+
# dW, dW.stride(0),
|
197 |
+
n_cols, ctx.eps,
|
198 |
+
GEMMA = ctx.GEMMA,
|
199 |
+
BLOCK_SIZE = ctx.BLOCK_SIZE,
|
200 |
+
num_warps = ctx.num_warps,
|
201 |
+
)
|
202 |
+
dX = dX.view(*shape)
|
203 |
+
return dX, None, None, None
|
204 |
+
pass
|
205 |
+
pass
|
206 |
+
|
207 |
+
|
208 |
+
# [TODO] Unsure why RMS Layernorm is not torch.compiling properly
|
209 |
+
@torch.compiler.disable
|
210 |
+
def fast_rms_layernorm(layernorm, X : torch.Tensor, gemma : bool = False):
|
211 |
+
W : torch.Tensor = layernorm.weight
|
212 |
+
eps : float = layernorm.variance_epsilon if \
|
213 |
+
hasattr(layernorm, "variance_epsilon") \
|
214 |
+
else layernorm.eps
|
215 |
+
out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)
|
216 |
+
return out
|
217 |
+
pass
|
218 |
+
|
219 |
+
|
220 |
+
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
221 |
+
class Unsloth_LlamaRMSNorm(LlamaRMSNorm):
|
222 |
+
def forward(self, X):
|
223 |
+
return fast_rms_layernorm(self, X, gemma = False)
|
224 |
+
pass
|
225 |
+
pass
|
226 |
+
|
227 |
+
try:
|
228 |
+
from transformers.models.mllama.modeling_mllama import MllamaTextRMSNorm
|
229 |
+
class Unsloth_MllamaTextRMSNorm(MllamaTextRMSNorm):
|
230 |
+
def forward(self, X):
|
231 |
+
return fast_rms_layernorm(self, X, gemma = False)
|
232 |
+
pass
|
233 |
+
pass
|
234 |
+
except:
|
235 |
+
pass
|
236 |
+
pass
|
237 |
+
|
238 |
+
def patch_rms_layernorm():
|
239 |
+
import transformers.models.llama.modeling_llama
|
240 |
+
transformers.models.llama.modeling_llama.LlamaRMSNorm = Unsloth_LlamaRMSNorm
|
241 |
+
try:
|
242 |
+
import transformers.models.mllama.modeling_mllama
|
243 |
+
transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = Unsloth_MllamaTextRMSNorm
|
244 |
+
except:
|
245 |
+
pass
|
246 |
+
return
|
247 |
+
pass
|
248 |
+
|
249 |
+
|
250 |
+
def unpatch_rms_layernorm():
|
251 |
+
import transformers.models.llama.modeling_llama
|
252 |
+
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
253 |
+
try:
|
254 |
+
import transformers.models.mllama.modeling_mllama
|
255 |
+
transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = MllamaTextRMSNorm
|
256 |
+
except:
|
257 |
+
pass
|
258 |
+
return
|
259 |
+
pass
|
260 |
+
|
261 |
+
|
torch-ext/unsloth_kernels/rope_embedding.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import triton
|
16 |
+
import triton.language as tl
|
17 |
+
import torch
|
18 |
+
from .utils import calculate_settings, torch_cuda_device
|
19 |
+
ROPE_GROUP_SIZE : int = 4
|
20 |
+
|
21 |
+
def _rope_embedding(
|
22 |
+
Q, Q_row_stride,
|
23 |
+
cos, cos_row_stride,
|
24 |
+
sin, sin_row_stride,
|
25 |
+
seqlen,
|
26 |
+
head_dim : tl.constexpr,
|
27 |
+
n_heads : tl.constexpr,
|
28 |
+
BACKWARD_PASS : tl.constexpr,
|
29 |
+
BLOCK_SIZE : tl.constexpr,
|
30 |
+
):
|
31 |
+
"""
|
32 |
+
Calculates the RoPE Embedding quickly
|
33 |
+
RoPE is Q * cos + rotate_half(Q) * sin
|
34 |
+
See our blog post for more info
|
35 |
+
"""
|
36 |
+
ROPE_GROUP_SIZE = 4
|
37 |
+
row_position = tl.program_id(0)
|
38 |
+
group_head_position = tl.program_id(1)
|
39 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
40 |
+
half_head_dim = head_dim // 2
|
41 |
+
mask = col_offsets < half_head_dim
|
42 |
+
|
43 |
+
sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
|
44 |
+
half_head_dim*0 + col_offsets, mask = mask, other = 0)
|
45 |
+
cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
|
46 |
+
half_head_dim*0 + col_offsets, mask = mask, other = 0)
|
47 |
+
|
48 |
+
if BACKWARD_PASS:
|
49 |
+
# See our blog post for more info.
|
50 |
+
sin1 = -sin1
|
51 |
+
pass
|
52 |
+
|
53 |
+
# [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8
|
54 |
+
head_start = group_head_position * ROPE_GROUP_SIZE
|
55 |
+
head_end = min((head_start + ROPE_GROUP_SIZE), n_heads)
|
56 |
+
|
57 |
+
# 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238)
|
58 |
+
for k in range(head_start, head_end):
|
59 |
+
offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets
|
60 |
+
offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim
|
61 |
+
|
62 |
+
# For Gemma - sometimes RoPE must be done in float32 and not bfloat16
|
63 |
+
Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
|
64 |
+
Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)
|
65 |
+
|
66 |
+
tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)
|
67 |
+
tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)
|
68 |
+
pass
|
69 |
+
pass
|
70 |
+
_rope_embedding = triton.jit(_rope_embedding)
|
71 |
+
_rope_embedding = triton.heuristics(
|
72 |
+
{
|
73 |
+
"BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]),
|
74 |
+
}
|
75 |
+
)(_rope_embedding)
|
76 |
+
|
77 |
+
|
78 |
+
class Fast_RoPE_Embedding(torch.autograd.Function):
|
79 |
+
@staticmethod
|
80 |
+
def forward(ctx, Q, cos, sin):
|
81 |
+
cos, sin = cos.squeeze(), sin.squeeze()
|
82 |
+
batch : int
|
83 |
+
seq_len : int
|
84 |
+
n_heads : int
|
85 |
+
head_dim : int
|
86 |
+
batch, seq_len, n_heads, head_dim = Q.shape
|
87 |
+
Q = Q.view(batch*seq_len, n_heads*head_dim)
|
88 |
+
n_rows : int
|
89 |
+
n_cols : int
|
90 |
+
n_rows, n_cols = Q.shape
|
91 |
+
assert(seq_len <= cos.shape[0])
|
92 |
+
|
93 |
+
# [TODO] Changing blocksize to head_dim//2 seems to have
|
94 |
+
# some concurrency / un-deterministic issues.
|
95 |
+
BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)
|
96 |
+
|
97 |
+
# group_size = 4 # 4 or 8, too large group_size can hurt performance.
|
98 |
+
div : int
|
99 |
+
mod : int
|
100 |
+
div, mod = divmod(n_heads, ROPE_GROUP_SIZE)
|
101 |
+
n_groups : int = div + (mod != 0)
|
102 |
+
|
103 |
+
with torch_cuda_device(Q.device):
|
104 |
+
_rope_embedding[(n_rows, n_groups, )](
|
105 |
+
Q, Q.stride(0),
|
106 |
+
cos, cos.stride(0),
|
107 |
+
sin, sin.stride(0),
|
108 |
+
seq_len,
|
109 |
+
head_dim, n_heads,
|
110 |
+
BACKWARD_PASS = False,
|
111 |
+
BLOCK_SIZE = BLOCK_SIZE,
|
112 |
+
num_warps = num_warps,
|
113 |
+
)
|
114 |
+
ctx.BLOCK_SIZE = BLOCK_SIZE
|
115 |
+
ctx.num_warps = num_warps
|
116 |
+
ctx.n_groups = n_groups
|
117 |
+
ctx.cos = cos
|
118 |
+
ctx.sin = sin
|
119 |
+
return Q.view(batch, seq_len, n_heads, head_dim)
|
120 |
+
pass
|
121 |
+
|
122 |
+
@staticmethod
|
123 |
+
def backward(ctx, dY):
|
124 |
+
batch : int
|
125 |
+
seq_len : int
|
126 |
+
n_heads : int
|
127 |
+
head_dim : int
|
128 |
+
batch, seq_len, n_heads, head_dim = dY.shape
|
129 |
+
dY = dY.reshape(batch*seq_len, n_heads*head_dim)
|
130 |
+
# Must be reshape not view
|
131 |
+
n_rows : int
|
132 |
+
n_cols : int
|
133 |
+
n_rows, n_cols = dY.shape
|
134 |
+
|
135 |
+
cos = ctx.cos
|
136 |
+
sin = ctx.sin
|
137 |
+
|
138 |
+
with torch_cuda_device(dY.device):
|
139 |
+
_rope_embedding[(n_rows, ctx.n_groups, )](
|
140 |
+
dY, dY .stride(0),
|
141 |
+
cos, cos.stride(0),
|
142 |
+
sin, sin.stride(0),
|
143 |
+
seq_len, head_dim, n_heads,
|
144 |
+
BACKWARD_PASS = True,
|
145 |
+
BLOCK_SIZE = ctx.BLOCK_SIZE,
|
146 |
+
num_warps = ctx.num_warps,
|
147 |
+
)
|
148 |
+
dY = dY.view(batch, seq_len, n_heads, head_dim)
|
149 |
+
return dY, None, None,
|
150 |
+
pass
|
151 |
+
pass
|
152 |
+
|
153 |
+
# [TODO] Unsure why RoPE Embedding is not torch.compiling properly
|
154 |
+
@torch.compiler.disable
|
155 |
+
def fast_rope_embedding(Q, K, cos, sin):
|
156 |
+
Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)
|
157 |
+
K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2)
|
158 |
+
return Q, K
|
159 |
+
pass
|
160 |
+
|
161 |
+
|
162 |
+
class Slow_RoPE_Embedding(torch.autograd.Function):
|
163 |
+
@staticmethod
|
164 |
+
def forward(ctx, Q, cos, sin, position_ids):
|
165 |
+
if position_ids is not None:
|
166 |
+
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
167 |
+
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
168 |
+
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
169 |
+
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
170 |
+
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
171 |
+
|
172 |
+
# Q * cos + rotate_half(Q) * sin
|
173 |
+
half = Q.shape[-1]//2
|
174 |
+
RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
|
175 |
+
Q *= cos
|
176 |
+
Q.addcmul_(RH_Q, sin)
|
177 |
+
# RH_Q *= sin
|
178 |
+
# Q += RH_Q
|
179 |
+
ctx.save_for_backward(cos, sin)
|
180 |
+
return Q
|
181 |
+
pass
|
182 |
+
|
183 |
+
@staticmethod
|
184 |
+
def backward(ctx, dY):
|
185 |
+
cos, sin = ctx.saved_tensors
|
186 |
+
# Q * cos + rotate_half.T(Q) * sin
|
187 |
+
half = dY.shape[-1]//2
|
188 |
+
RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
|
189 |
+
dY *= cos
|
190 |
+
dY.addcmul_(RH_dY, sin)
|
191 |
+
# RH_dY *= sin
|
192 |
+
# dY += RH_dY
|
193 |
+
return dY, None, None, None
|
194 |
+
pass
|
195 |
+
pass
|
196 |
+
|
197 |
+
|
198 |
+
def inplace_rope_embedding(Q, K, cos, sin, position_ids):
|
199 |
+
Q = Slow_RoPE_Embedding.apply(Q, cos, sin, position_ids)
|
200 |
+
K = Slow_RoPE_Embedding.apply(K, cos, sin, position_ids)
|
201 |
+
return Q, K
|
202 |
+
pass
|
torch-ext/unsloth_kernels/swiglu.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import triton
|
16 |
+
import triton.language as tl
|
17 |
+
import torch
|
18 |
+
from .utils import calculate_settings, torch_cuda_device
|
19 |
+
|
20 |
+
|
21 |
+
@triton.jit
|
22 |
+
def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
|
23 |
+
block_idx = tl.program_id(0)
|
24 |
+
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
25 |
+
mask = offsets < n_elements
|
26 |
+
|
27 |
+
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
28 |
+
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
|
29 |
+
|
30 |
+
# f = e * sigmoid(e)
|
31 |
+
f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
|
32 |
+
f_row = f_row.to(g_row.dtype) # Exact copy from HF
|
33 |
+
# h = f * g
|
34 |
+
h_row = f_row * g_row
|
35 |
+
|
36 |
+
# Store h
|
37 |
+
tl.store(h + offsets, h_row, mask = mask)
|
38 |
+
pass
|
39 |
+
|
40 |
+
|
41 |
+
def swiglu_fg_kernel(e, g):
|
42 |
+
batch, seq_len, hd = e.shape
|
43 |
+
n_elements = e.numel()
|
44 |
+
h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = e.device)
|
45 |
+
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
46 |
+
with torch_cuda_device(e.device):
|
47 |
+
_fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
|
48 |
+
return h
|
49 |
+
pass
|
50 |
+
|
51 |
+
|
52 |
+
@triton.jit
|
53 |
+
def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
|
54 |
+
"""
|
55 |
+
e = e.float()
|
56 |
+
se = 1.0 / (1.0 + torch.exp(-e))
|
57 |
+
f = (se * e).to(dtype)
|
58 |
+
h = f * g
|
59 |
+
df = DW * f
|
60 |
+
dg = DW * g
|
61 |
+
de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
|
62 |
+
"""
|
63 |
+
block_idx = tl.program_id(0)
|
64 |
+
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
65 |
+
mask = offsets < n_elements
|
66 |
+
|
67 |
+
DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
|
68 |
+
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
69 |
+
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
|
70 |
+
|
71 |
+
# e = e.float()
|
72 |
+
# se = 1.0 / (1.0 + torch.exp(-e))
|
73 |
+
se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row))
|
74 |
+
# f = (se * e).to(dtype)
|
75 |
+
f_row = se_row * e_row
|
76 |
+
f_row = f_row.to(DW_row.dtype)
|
77 |
+
# h = f * g
|
78 |
+
h_row = f_row * g_row
|
79 |
+
# df = DW * f
|
80 |
+
df_row = DW_row * f_row
|
81 |
+
# dg = DW * g
|
82 |
+
dg_row = DW_row * g_row
|
83 |
+
# de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
|
84 |
+
de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))
|
85 |
+
de_row = de_row.to(DW_row.dtype)
|
86 |
+
|
87 |
+
# Store derivatives in buffers
|
88 |
+
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
|
89 |
+
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
|
90 |
+
tl.store(g + offsets, de_row, mask = mask) # de
|
91 |
+
pass
|
92 |
+
|
93 |
+
|
94 |
+
def swiglu_DWf_DW_dfg_kernel(DW, e, g):
|
95 |
+
batch_seq_len, hd = e.shape
|
96 |
+
n_elements = e.numel()
|
97 |
+
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
98 |
+
with torch_cuda_device(e.device):
|
99 |
+
_DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
|
100 |
+
return DW, e, g
|
101 |
+
pass
|
torch-ext/unsloth_kernels/utils.py
ADDED
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import triton
|
16 |
+
MAX_FUSED_SIZE : int = 65536
|
17 |
+
next_power_of_2 = triton.next_power_of_2
|
18 |
+
import functools
|
19 |
+
|
20 |
+
# torch.cuda.amp.custom_fwd is deprecated >= 2.4
|
21 |
+
import torch
|
22 |
+
torch_Tensor = torch.Tensor
|
23 |
+
from packaging.version import Version
|
24 |
+
if Version(torch.__version__) < Version("2.4.0"):
|
25 |
+
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
|
26 |
+
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
|
27 |
+
else:
|
28 |
+
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
|
29 |
+
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
|
30 |
+
pass
|
31 |
+
|
32 |
+
|
33 |
+
# tl.math.tanh now is libdevice.tanh
|
34 |
+
from packaging.version import Version
|
35 |
+
import triton
|
36 |
+
import triton.language as tl
|
37 |
+
if Version(triton.__version__) >= Version("3.0.0"):
|
38 |
+
from triton.language.extra import libdevice
|
39 |
+
triton_tanh = libdevice.tanh
|
40 |
+
triton_cast = tl.cast
|
41 |
+
else:
|
42 |
+
triton_tanh = tl.math.tanh
|
43 |
+
# No casting in old Triton versions
|
44 |
+
@triton.jit
|
45 |
+
def triton_cast(x, dtype):
|
46 |
+
return x.to(dtype)
|
47 |
+
pass
|
48 |
+
pass
|
49 |
+
|
50 |
+
|
51 |
+
def calculate_settings(n : int) -> (int, int,):
|
52 |
+
BLOCK_SIZE : int = next_power_of_2(n)
|
53 |
+
if BLOCK_SIZE > MAX_FUSED_SIZE:
|
54 |
+
raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
|
55 |
+
f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
|
56 |
+
num_warps : int = 4
|
57 |
+
if BLOCK_SIZE >= 32768: num_warps = 32
|
58 |
+
elif BLOCK_SIZE >= 8192: num_warps = 16
|
59 |
+
elif BLOCK_SIZE >= 2048: num_warps = 8
|
60 |
+
return BLOCK_SIZE, num_warps
|
61 |
+
pass
|
62 |
+
|
63 |
+
|
64 |
+
import bitsandbytes as bnb
|
65 |
+
import ctypes
|
66 |
+
|
67 |
+
# https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files
|
68 |
+
HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3")
|
69 |
+
get_ptr = bnb.functional.get_ptr
|
70 |
+
|
71 |
+
if torch.cuda.device_count() > 1:
|
72 |
+
torch_cuda_device = torch.cuda.device
|
73 |
+
else:
|
74 |
+
from contextlib import nullcontext
|
75 |
+
def torch_cuda_device(device): return nullcontext()
|
76 |
+
pass
|
77 |
+
_cuda_getCurrentRawStream = torch._C._cuda_getCurrentRawStream
|
78 |
+
c_void_p = ctypes.c_void_p
|
79 |
+
def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
|
80 |
+
return c_void_p(_cuda_getCurrentRawStream(tensor.device.index))
|
81 |
+
pass
|
82 |
+
|
83 |
+
# Get array of CUDA streams and other buffers
|
84 |
+
global CUDA_STREAMS
|
85 |
+
global WEIGHT_BUFFERS
|
86 |
+
global ABSMAX_BUFFERS
|
87 |
+
|
88 |
+
_CUDA_STREAMS = {
|
89 |
+
(index := torch.cuda.device(i).idx) : ctypes.c_void_p(torch._C._cuda_getCurrentRawStream(index))
|
90 |
+
for i in range(torch.cuda.device_count())
|
91 |
+
}
|
92 |
+
CUDA_STREAMS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
|
93 |
+
WEIGHT_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
|
94 |
+
ABSMAX_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1)
|
95 |
+
for k, v in _CUDA_STREAMS.items(): CUDA_STREAMS[k] = v
|
96 |
+
CUDA_STREAMS = tuple(CUDA_STREAMS)
|
97 |
+
del _CUDA_STREAMS
|
98 |
+
|
99 |
+
# Bitsandbytes operations
|
100 |
+
ctypes_c_int = ctypes.c_int
|
101 |
+
ctypes_c_int32 = ctypes.c_int32
|
102 |
+
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
|
103 |
+
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
|
104 |
+
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
|
105 |
+
cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
|
106 |
+
cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
|
107 |
+
torch_mm = torch.mm
|
108 |
+
torch_mv = torch.mv
|
109 |
+
torch_matmul = torch.matmul
|
110 |
+
torch_addmm = torch.addmm
|
111 |
+
torch_empty = torch.empty
|
112 |
+
|
113 |
+
def QUANT_STATE(W): return getattr(W, "quant_state", None)
|
114 |
+
|
115 |
+
def get_lora_parameters(proj):
|
116 |
+
# For DPO or disabled adapters
|
117 |
+
base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj)
|
118 |
+
W = base_layer.weight
|
119 |
+
|
120 |
+
# if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
|
121 |
+
if getattr(proj, "disable_adapters", True) or proj.merged:
|
122 |
+
return W, getattr(W, "quant_state", None), None, None, None
|
123 |
+
pass
|
124 |
+
|
125 |
+
adapter = getattr(proj, "active_adapters", None)
|
126 |
+
if adapter is None: adapter = getattr(proj, "active_adapter", ("default"))
|
127 |
+
adapter = adapter[0]
|
128 |
+
|
129 |
+
return (
|
130 |
+
W,
|
131 |
+
getattr(W, "quant_state", None),
|
132 |
+
proj.lora_A [adapter].weight,
|
133 |
+
proj.lora_B [adapter].weight,
|
134 |
+
proj.scaling[adapter],
|
135 |
+
)
|
136 |
+
pass
|
137 |
+
|
138 |
+
|
139 |
+
def get_lora_parameters_bias(proj):
|
140 |
+
# For DPO or disabled adapters
|
141 |
+
base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj)
|
142 |
+
W = base_layer.weight
|
143 |
+
|
144 |
+
# if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
|
145 |
+
if getattr(proj, "disable_adapters", True) or proj.merged:
|
146 |
+
return W, getattr(W, "quant_state", None), None, None, None, base_layer.bias
|
147 |
+
pass
|
148 |
+
|
149 |
+
adapter = getattr(proj, "active_adapters", None)
|
150 |
+
if adapter is None: adapter = getattr(proj, "active_adapter", ("default"))
|
151 |
+
adapter = adapter[0]
|
152 |
+
|
153 |
+
return (
|
154 |
+
W,
|
155 |
+
getattr(W, "quant_state", None),
|
156 |
+
proj.lora_A [adapter].weight,
|
157 |
+
proj.lora_B [adapter].weight,
|
158 |
+
proj.scaling[adapter],
|
159 |
+
base_layer.bias,
|
160 |
+
)
|
161 |
+
pass
|
162 |
+
|
163 |
+
if HAS_CUDA_STREAM:
|
164 |
+
@torch.inference_mode
|
165 |
+
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
|
166 |
+
if quant_state is None: return W
|
167 |
+
if type(quant_state) is not list:
|
168 |
+
# New quant_state as a class
|
169 |
+
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
|
170 |
+
absmax = quant_state.absmax
|
171 |
+
shape = quant_state.shape
|
172 |
+
dtype = quant_state.dtype
|
173 |
+
blocksize = quant_state.blocksize
|
174 |
+
offset = quant_state.offset
|
175 |
+
state2 = quant_state.state2
|
176 |
+
absmax2 = state2.absmax
|
177 |
+
code2 = state2.code
|
178 |
+
blocksize2 = state2.blocksize
|
179 |
+
else:
|
180 |
+
# Old quant_state as a list of lists
|
181 |
+
absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
|
182 |
+
offset, state2 = compressed_stats
|
183 |
+
absmax2, code2, blocksize2, _, _, _, _ = state2
|
184 |
+
pass
|
185 |
+
global CUDA_STREAMS
|
186 |
+
device = W.device
|
187 |
+
device_index = device.index
|
188 |
+
CUDA_STREAM = CUDA_STREAMS[device_index]
|
189 |
+
|
190 |
+
n_elements_absmax = absmax.numel()
|
191 |
+
|
192 |
+
# Create weight matrix
|
193 |
+
if use_global_buffer:
|
194 |
+
|
195 |
+
# Use same buffers for faster inference
|
196 |
+
size = shape[0]*shape[1]
|
197 |
+
global WEIGHT_BUFFERS
|
198 |
+
global ABSMAX_BUFFERS
|
199 |
+
WEIGHT_BUFFER = WEIGHT_BUFFERS[device_index]
|
200 |
+
ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]
|
201 |
+
if WEIGHT_BUFFER is None:
|
202 |
+
WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(size, dtype = dtype, device = device, requires_grad = False)
|
203 |
+
ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)
|
204 |
+
|
205 |
+
if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size)
|
206 |
+
if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax)
|
207 |
+
|
208 |
+
out = WEIGHT_BUFFER[:size].view(shape)
|
209 |
+
out_absmax = ABSMAX_BUFFER[:n_elements_absmax]
|
210 |
+
else:
|
211 |
+
if out is None:
|
212 |
+
out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False)
|
213 |
+
else:
|
214 |
+
assert(out.shape == shape)
|
215 |
+
assert(out.dtype == dtype)
|
216 |
+
out_absmax = torch_empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)
|
217 |
+
pass
|
218 |
+
|
219 |
+
# NF4 dequantization of statistics
|
220 |
+
ptr_out_absmax = get_ptr(out_absmax)
|
221 |
+
with torch_cuda_device(device):
|
222 |
+
cdequantize_blockwise_fp32(
|
223 |
+
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
|
224 |
+
ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), CUDA_STREAM
|
225 |
+
)
|
226 |
+
out_absmax += offset
|
227 |
+
|
228 |
+
# Dequantize W
|
229 |
+
fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
|
230 |
+
cdequantize_blockwise_bf16_nf4
|
231 |
+
fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
|
232 |
+
ctypes_c_int(blocksize), ctypes_c_int(out.numel()), CUDA_STREAM,)
|
233 |
+
pass
|
234 |
+
# Careful returning transposed data
|
235 |
+
is_transposed = (True if W.shape[0] == 1 else False)
|
236 |
+
return out.t() if is_transposed else out
|
237 |
+
pass
|
238 |
+
else:
|
239 |
+
@torch.inference_mode
|
240 |
+
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
|
241 |
+
if quant_state is None: return W
|
242 |
+
if type(quant_state) is not list:
|
243 |
+
# New quant_state as a class
|
244 |
+
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
|
245 |
+
absmax = quant_state.absmax
|
246 |
+
shape = quant_state.shape
|
247 |
+
dtype = quant_state.dtype
|
248 |
+
blocksize = quant_state.blocksize
|
249 |
+
offset = quant_state.offset
|
250 |
+
state2 = quant_state.state2
|
251 |
+
absmax2 = state2.absmax
|
252 |
+
code2 = state2.code
|
253 |
+
blocksize2 = state2.blocksize
|
254 |
+
else:
|
255 |
+
# Old quant_state as a list of lists
|
256 |
+
absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
|
257 |
+
offset, state2 = compressed_stats
|
258 |
+
absmax2, code2, blocksize2, _, _, _, _ = state2
|
259 |
+
pass
|
260 |
+
|
261 |
+
n_elements_absmax = absmax.numel()
|
262 |
+
device = W.device
|
263 |
+
|
264 |
+
# Create weight matrix
|
265 |
+
if out is None:
|
266 |
+
out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False)
|
267 |
+
else:
|
268 |
+
assert(out.shape == shape)
|
269 |
+
assert(out.dtype == dtype)
|
270 |
+
out_absmax = torch_empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)
|
271 |
+
|
272 |
+
# Do dequantization
|
273 |
+
ptr_out_absmax = get_ptr(out_absmax)
|
274 |
+
cdequantize_blockwise_fp32(
|
275 |
+
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
|
276 |
+
ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax),
|
277 |
+
)
|
278 |
+
out_absmax += offset
|
279 |
+
|
280 |
+
fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
|
281 |
+
cdequantize_blockwise_bf16_nf4
|
282 |
+
fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
|
283 |
+
ctypes_c_int(blocksize), ctypes_c_int(out.numel()),)
|
284 |
+
|
285 |
+
# Careful returning transposed data
|
286 |
+
is_transposed = (True if W.shape[0] == 1 else False)
|
287 |
+
return out.t() if is_transposed else out
|
288 |
+
pass
|
289 |
+
pass
|
290 |
+
|
291 |
+
|
292 |
+
if HAS_CUDA_STREAM:
|
293 |
+
def fast_gemv(X, W, quant_state, out = None):
|
294 |
+
if quant_state is None: return torch_matmul(X, W, out = out)
|
295 |
+
# For fast X @ W where seq_len == 1
|
296 |
+
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
|
297 |
+
_, q_len, hd = X.shape
|
298 |
+
# assert(q_len == 1)
|
299 |
+
|
300 |
+
if type(quant_state) is not list:
|
301 |
+
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
|
302 |
+
absmax = quant_state.absmax
|
303 |
+
shape = quant_state.shape
|
304 |
+
dtype = quant_state.dtype
|
305 |
+
blocksize = quant_state.blocksize
|
306 |
+
stats = quant_state.code
|
307 |
+
offset = quant_state.offset
|
308 |
+
state2 = quant_state.state2
|
309 |
+
absmax2 = state2.absmax
|
310 |
+
code2 = state2.code
|
311 |
+
blocksize2 = state2.blocksize
|
312 |
+
else:
|
313 |
+
absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
|
314 |
+
offset, state2 = compressed_stats
|
315 |
+
absmax2, code2, blocksize2, _, _, _, _ = state2
|
316 |
+
pass
|
317 |
+
global CUDA_STREAMS
|
318 |
+
device = W.device
|
319 |
+
device_index = device.index
|
320 |
+
CUDA_STREAM = CUDA_STREAMS[device_index]
|
321 |
+
|
322 |
+
# assert(dtype == X.dtype)
|
323 |
+
bout = shape[0]
|
324 |
+
|
325 |
+
if out is None:
|
326 |
+
out = torch_empty((1, 1, bout,), dtype = dtype, device = device)
|
327 |
+
# else:
|
328 |
+
# assert(out.shape == (1, 1, bout,))
|
329 |
+
# pass
|
330 |
+
|
331 |
+
n = 1
|
332 |
+
m = shape[0]
|
333 |
+
k = shape[1]
|
334 |
+
lda = shape[0]
|
335 |
+
ldc = shape[0]
|
336 |
+
ldb = (hd+1)//2
|
337 |
+
m = ctypes_c_int32(m)
|
338 |
+
n = ctypes_c_int32(n)
|
339 |
+
k = ctypes_c_int32(k)
|
340 |
+
lda = ctypes_c_int32(lda)
|
341 |
+
ldb = ctypes_c_int32(ldb)
|
342 |
+
ldc = ctypes_c_int32(ldc)
|
343 |
+
|
344 |
+
df = torch_empty(absmax.shape, dtype = torch.float32, device = device)
|
345 |
+
with torch_cuda_device(device):
|
346 |
+
cdequantize_blockwise_fp32(
|
347 |
+
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
|
348 |
+
ctypes_c_int(blocksize2), ctypes_c_int(df.numel()), CUDA_STREAM,
|
349 |
+
)
|
350 |
+
df += offset
|
351 |
+
absmax = df
|
352 |
+
|
353 |
+
fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
|
354 |
+
cgemm_4bit_inference_naive_bf16
|
355 |
+
|
356 |
+
blocksize = ctypes_c_int32(blocksize)
|
357 |
+
fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
|
358 |
+
lda, ldb, ldc, blocksize, CUDA_STREAM,)
|
359 |
+
pass
|
360 |
+
|
361 |
+
return out
|
362 |
+
pass
|
363 |
+
else:
|
364 |
+
def fast_gemv(X, W, quant_state, out = None):
|
365 |
+
if quant_state is None: return torch.matmul(X, W, out = out)
|
366 |
+
# For fast X @ W where seq_len == 1
|
367 |
+
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
|
368 |
+
_, q_len, hd = X.shape
|
369 |
+
# assert(q_len == 1)
|
370 |
+
|
371 |
+
if type(quant_state) is not list:
|
372 |
+
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
|
373 |
+
absmax = quant_state.absmax
|
374 |
+
shape = quant_state.shape
|
375 |
+
dtype = quant_state.dtype
|
376 |
+
blocksize = quant_state.blocksize
|
377 |
+
stats = quant_state.code
|
378 |
+
offset = quant_state.offset
|
379 |
+
state2 = quant_state.state2
|
380 |
+
absmax2 = state2.absmax
|
381 |
+
code2 = state2.code
|
382 |
+
blocksize2 = state2.blocksize
|
383 |
+
else:
|
384 |
+
absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
|
385 |
+
offset, state2 = compressed_stats
|
386 |
+
absmax2, code2, blocksize2, _, _, _, _ = state2
|
387 |
+
pass
|
388 |
+
# assert(dtype == X.dtype)
|
389 |
+
bout = shape[0]
|
390 |
+
device = W.device
|
391 |
+
|
392 |
+
if out is None:
|
393 |
+
out = torch_empty((1, 1, bout,), dtype = dtype, device = device)
|
394 |
+
# else:
|
395 |
+
# assert(out.shape == (1, 1, bout,))
|
396 |
+
# pass
|
397 |
+
|
398 |
+
n = 1
|
399 |
+
m = shape[0]
|
400 |
+
k = shape[1]
|
401 |
+
lda = shape[0]
|
402 |
+
ldc = shape[0]
|
403 |
+
ldb = (hd+1)//2
|
404 |
+
m = ctypes_c_int32(m)
|
405 |
+
n = ctypes_c_int32(n)
|
406 |
+
k = ctypes_c_int32(k)
|
407 |
+
lda = ctypes_c_int32(lda)
|
408 |
+
ldb = ctypes_c_int32(ldb)
|
409 |
+
ldc = ctypes_c_int32(ldc)
|
410 |
+
|
411 |
+
df = torch_empty(absmax.shape, dtype = torch.float32, device = device)
|
412 |
+
cdequantize_blockwise_fp32(
|
413 |
+
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
|
414 |
+
ctypes_c_int(blocksize2), ctypes_c_int(df.numel()),
|
415 |
+
)
|
416 |
+
df += offset
|
417 |
+
absmax = df
|
418 |
+
|
419 |
+
fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
|
420 |
+
cgemm_4bit_inference_naive_bf16
|
421 |
+
|
422 |
+
blocksize = ctypes_c_int32(blocksize)
|
423 |
+
fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
|
424 |
+
lda, ldb, ldc, blocksize,)
|
425 |
+
|
426 |
+
return out
|
427 |
+
pass
|
428 |
+
pass
|
429 |
+
|
430 |
+
|
431 |
+
def fast_linear_forward(proj, X, temp_lora = None, out = None):
|
432 |
+
|
433 |
+
W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
|
434 |
+
bsz, q_len, in_dim = X.shape
|
435 |
+
if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
|
436 |
+
|
437 |
+
if W_quant is None:
|
438 |
+
out = torch_matmul(X, W.t(), out = out)
|
439 |
+
elif bsz == 1 and q_len == 1:
|
440 |
+
out = fast_gemv(X, W, W_quant, out = out)
|
441 |
+
else:
|
442 |
+
W = fast_dequantize(W.t(), W_quant, use_global_buffer = True)
|
443 |
+
out = torch_matmul(X, W, out = out)
|
444 |
+
pass
|
445 |
+
|
446 |
+
# Add in LoRA weights
|
447 |
+
if lora_A is not None:
|
448 |
+
out_dim = out.shape[2]
|
449 |
+
dtype = X.dtype
|
450 |
+
|
451 |
+
if not hasattr(lora_A, "_fast_lora"):
|
452 |
+
lora_A._fast_lora = lora_A.to(dtype)
|
453 |
+
lora_B._fast_lora = lora_B.to(dtype)
|
454 |
+
pass
|
455 |
+
|
456 |
+
if bsz == 1:
|
457 |
+
out = out.view(out_dim)
|
458 |
+
temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
|
459 |
+
out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
|
460 |
+
else:
|
461 |
+
out = out.view(bsz, out_dim)
|
462 |
+
temp_lora = torch_mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora)
|
463 |
+
out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
|
464 |
+
pass
|
465 |
+
out = out.view(bsz, 1, out_dim)
|
466 |
+
pass
|
467 |
+
|
468 |
+
if bias is not None: out += bias
|
469 |
+
|
470 |
+
return out
|
471 |
+
pass
|
472 |
+
|
473 |
+
|
474 |
+
def matmul_lora(X, W, W_quant, A, B, s, out = None):
|
475 |
+
dtype = X.dtype
|
476 |
+
W = fast_dequantize(W.t(), W_quant, use_global_buffer = True)
|
477 |
+
|
478 |
+
if X.dim() == 3:
|
479 |
+
batch, seq_len, d = X.shape
|
480 |
+
X = X.view(-1, X.shape[-1])
|
481 |
+
reshape = True
|
482 |
+
else:
|
483 |
+
reshape = False
|
484 |
+
pass
|
485 |
+
out = torch_matmul(X, W, out = out)
|
486 |
+
if W_quant is not None: del W
|
487 |
+
|
488 |
+
if A is not None:
|
489 |
+
# LoRA is enabled
|
490 |
+
A, B = A.t(), B.t()
|
491 |
+
XA = torch_matmul(X, A.to(dtype))
|
492 |
+
out.addmm_(XA, B.to(dtype), alpha = s)
|
493 |
+
# out += (X @ A.to(dtype)) @ (s * B.to(dtype))
|
494 |
+
pass
|
495 |
+
|
496 |
+
return out.view(batch, seq_len, -1) if reshape else out
|
497 |
+
pass
|