Upload kohya_lora_loader.py
kohya_lora_loader.py +259 -0
kohya_lora_loader.py
ADDED
@@ -0,0 +1,259 @@
import math
import safetensors.torch  # provides safetensors.torch.load_file used below
import torch
from diffusers import DiffusionPipeline

"""
Kohya's LoRA format loader for Diffusers.

Usage:
```py

# A typical Diffusers setup
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained('...',
    torch_dtype=torch.float16).to('cuda')

# Import this module
import kohya_lora_loader

# Install the LoRA hook. This adds apply_lora and remove_lora methods to the pipe.
kohya_lora_loader.install_lora_hook(pipe)

# Load the 'lora1.safetensors' file and apply it
lora1 = pipe.apply_lora('lora1.safetensors', 1.0)

# You can change alpha
lora1.alpha = 0.5

# Load the 'lora2.safetensors' file and apply it
lora2 = pipe.apply_lora('lora2.safetensors', 1.0)

# Generate an image with lora1 and lora2 applied
pipe(...).images[0]

# Remove lora2
pipe.remove_lora(lora2)

# Generate an image with only lora1 applied
pipe(...).images[0]

# Uninstall the LoRA hook
kohya_lora_loader.uninstall_lora_hook(pipe)

# Generate an image with no LoRA applied
pipe(...).images[0]

```
"""


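# Note on the computation below: each LoRAModule wraps one original Linear or
# Conv2d layer and produces only the low-rank residual. For a Linear layer the
# effective weight is roughly W + multiplier * (alpha / lora_dim) * (up @ down),
# where `lora_down` projects to rank `lora_dim` and `lora_up` projects back.
# `lora_up` is zero-initialized, so a freshly created module is a no-op until
# trained weights are loaded into it.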
# modified from https://github.com/kohya-ss/sd-scripts/blob/ad5f318d066c52e5b27306b399bc87e41f2eef2b/networks/lora.py#L17
class LoRAModule(torch.nn.Module):
    def __init__(
        self, org_module: torch.nn.Module, lora_dim=4, alpha=1.0, multiplier=1.0
    ):
        """If alpha == 0 or None, alpha is the rank (no scaling)."""
        super().__init__()

        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        self.lora_dim = lora_dim

        if org_module.__class__.__name__ == "Conv2d":
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(
                in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
            )
            self.lora_up = torch.nn.Conv2d(
                self.lora_dim, out_dim, (1, 1), (1, 1), bias=False
            )
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if alpha is None or alpha == 0:
            self.alpha = self.lora_dim
        else:
            if isinstance(alpha, torch.Tensor):
                alpha = alpha.detach().float().numpy()  # without casting, bf16 causes an error
            self.register_buffer("alpha", torch.tensor(alpha))  # treatable as a constant

        # same initialization as Microsoft's LoRA implementation
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier

    def forward(self, x):
        scale = self.alpha / self.lora_dim
        return self.multiplier * scale * self.lora_up(self.lora_down(x))


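# LoRAModuleContainer below holds every LoRAModule created from one .safetensors
# file. It builds one module per "<lora_name>.lora_down.weight" key in the
# Kohya-format state dict, infers the rank from that weight's first dimension,
# reads "<lora_name>.alpha" when present, and attaches each module to the
# LoRAHook with the matching name. Its `alpha` property rescales all of its
# modules at once.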
class LoRAModuleContainer(torch.nn.Module):
    def __init__(self, hooks, state_dict, multiplier):
        super().__init__()
        self.multiplier = multiplier

        # Create LoRAModule from state_dict information
        for key, value in state_dict.items():
            if "lora_down" in key:
                lora_name = key.split(".")[0]
                lora_dim = value.size()[0]
                lora_name_alpha = key.split(".")[0] + ".alpha"
                alpha = None
                if lora_name_alpha in state_dict:
                    alpha = state_dict[lora_name_alpha].item()
                hook = hooks[lora_name]
                lora_module = LoRAModule(
                    hook.orig_module, lora_dim=lora_dim, alpha=alpha, multiplier=multiplier
                )
                self.register_module(lora_name, lora_module)

        # Load whole LoRA weights
        self.load_state_dict(state_dict)

        # Register LoRAModule to LoRAHook
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                hook = hooks[name]
                hook.append_lora(module)

    @property
    def alpha(self):
        return self.multiplier

    @alpha.setter
    def alpha(self, multiplier):
        self.multiplier = multiplier
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                module.multiplier = multiplier

    def remove_from_hooks(self, hooks):
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                hook = hooks[name]
                hook.remove_lora(module)
                del module


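# Design note: LoRAHook monkey-patches the `forward` of a single Linear/Conv2d
# layer rather than swapping the layer out. The hooked forward returns
# orig_forward(x) + sum(lora(x) for each attached LoRAModule), so several LoRA
# files can be stacked on one layer and removed independently without touching
# the original weights.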
class LoRAHook(torch.nn.Module):
    """
    Replaces the forward method of the original Linear/Conv2d module,
    instead of replacing the original module itself.
    """

    def __init__(self):
        super().__init__()
        self.lora_modules = []

    def install(self, orig_module):
        assert not hasattr(self, "orig_module")
        self.orig_module = orig_module
        self.orig_forward = self.orig_module.forward
        self.orig_module.forward = self.forward

    def uninstall(self):
        assert hasattr(self, "orig_module")
        self.orig_module.forward = self.orig_forward
        del self.orig_forward
        del self.orig_module

    def append_lora(self, lora_module):
        self.lora_modules.append(lora_module)

    def remove_lora(self, lora_module):
        self.lora_modules.remove(lora_module)

    def forward(self, x):
        if len(self.lora_modules) == 0:
            return self.orig_forward(x)
        lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0)
        return self.orig_forward(x) + lora


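# LoRAHookInjector below walks the text encoder (name prefix "lora_te") and the
# UNet (name prefix "lora_unet") and installs a LoRAHook on every Linear/Conv2d
# child of the targeted attention/MLP blocks. Hook names are built by joining
# the prefix with the module path and replacing "." with "_", which is intended
# to line up with the key prefixes used in Kohya-format LoRA files.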
class LoRAHookInjector(object):
    def __init__(self):
        super().__init__()
        self.hooks = {}
        self.device = None
        self.dtype = None

    def _get_target_modules(self, root_module, prefix, target_replace_modules):
        target_modules = []
        for name, module in root_module.named_modules():
            if (
                module.__class__.__name__ in target_replace_modules
                and "transformer_blocks" not in name
            ):  # to adapt to the latest diffusers
                for child_name, child_module in module.named_modules():
                    is_linear = child_module.__class__.__name__ == "Linear"
                    is_conv2d = child_module.__class__.__name__ == "Conv2d"
                    if is_linear or is_conv2d:
                        lora_name = prefix + "." + name + "." + child_name
                        lora_name = lora_name.replace(".", "_")
                        target_modules.append((lora_name, child_module))
        return target_modules

    def install_hooks(self, pipe):
        """Install LoRAHooks into the pipe."""
        assert len(self.hooks) == 0
        text_encoder_targets = self._get_target_modules(
            pipe.text_encoder, "lora_te", ["CLIPAttention", "CLIPMLP"]
        )
        unet_targets = self._get_target_modules(
            pipe.unet, "lora_unet", ["Transformer2DModel", "Attention"]
        )
        for name, target_module in text_encoder_targets + unet_targets:
            hook = LoRAHook()
            hook.install(target_module)
            self.hooks[name] = hook

        self.device = pipe.device
        self.dtype = pipe.unet.dtype

    def uninstall_hooks(self):
        """Uninstall the LoRAHooks from the pipe."""
        for hook in self.hooks.values():
            hook.uninstall()
        self.hooks = {}

    def apply_lora(self, filename, alpha=1.0):
        """Load LoRA weights and apply them to the pipe."""
        assert len(self.hooks) != 0
        state_dict = safetensors.torch.load_file(filename)
        container = LoRAModuleContainer(self.hooks, state_dict, alpha)
        container.to(self.device, self.dtype)
        return container

    def remove_lora(self, container):
        """Remove an individual LoRA from the pipe."""
        container.remove_from_hooks(self.hooks)


def install_lora_hook(pipe: DiffusionPipeline):
    """Install LoRAHook to the pipe."""
    assert not hasattr(pipe, "lora_injector")
    assert not hasattr(pipe, "apply_lora")
    assert not hasattr(pipe, "remove_lora")
    injector = LoRAHookInjector()
    injector.install_hooks(pipe)
    pipe.lora_injector = injector
    pipe.apply_lora = injector.apply_lora
    pipe.remove_lora = injector.remove_lora


def uninstall_lora_hook(pipe: DiffusionPipeline):
    """Uninstall LoRAHook from the pipe."""
    pipe.lora_injector.uninstall_hooks()
    del pipe.lora_injector
    del pipe.apply_lora
    del pipe.remove_lora
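

# A minimal end-to-end sketch mirroring the usage shown in the module docstring.
# The checkpoint id, prompt, output path, and LoRA filenames are placeholders;
# swap in your own model and .safetensors files before running.
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # Patch apply_lora/remove_lora onto the pipeline.
    install_lora_hook(pipe)

    # Stack two LoRAs with different strengths.
    lora1 = pipe.apply_lora("lora1.safetensors", 1.0)
    lora2 = pipe.apply_lora("lora2.safetensors", 0.6)

    image = pipe("a photo of an astronaut riding a horse").images[0]
    image.save("lora_sample.png")

    # Drop one LoRA, then restore the unpatched pipeline.
    pipe.remove_lora(lora2)
    uninstall_lora_hook(pipe)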