jasonisme99 committed · Commit 8909507 · verified · 1 Parent(s): b22ed7b

Upload kohya_lora_loader.py

Files changed (1): kohya_lora_loader.py +259 -0
kohya_lora_loader.py ADDED
@@ -0,0 +1,259 @@
import math
import safetensors.torch
import torch
from diffusers import DiffusionPipeline

"""
Kohya's LoRA format loader for Diffusers.

Usage:
```py
# The usual Diffusers setup
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained('...',
                                               torch_dtype=torch.float16).to('cuda')

# Import this module
import kohya_lora_loader

# Install the LoRA hook. This adds apply_lora and remove_lora methods to the pipe.
kohya_lora_loader.install_lora_hook(pipe)

# Load the 'lora1.safetensors' file and apply it
lora1 = pipe.apply_lora('lora1.safetensors', 1.0)

# You can change alpha
lora1.alpha = 0.5

# Load the 'lora2.safetensors' file and apply it
lora2 = pipe.apply_lora('lora2.safetensors', 1.0)

# Generate an image with lora1 and lora2 applied
pipe(...).images[0]

# Remove lora2
pipe.remove_lora(lora2)

# Generate an image with lora1 applied
pipe(...).images[0]

# Uninstall the LoRA hook
kohya_lora_loader.uninstall_lora_hook(pipe)

# Generate an image with no LoRA applied
pipe(...).images[0]
```
"""

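# Note: this loader expects Kohya-style weight keys such as
# "<lora_name>.lora_down.weight", "<lora_name>.lora_up.weight" and an optional
# "<lora_name>.alpha" scale, where <lora_name> starts with "lora_te_" or "lora_unet_"
# (see LoRAModuleContainer and LoRAHookInjector below).
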
# modified from https://github.com/kohya-ss/sd-scripts/blob/ad5f318d066c52e5b27306b399bc87e41f2eef2b/networks/lora.py#L17
class LoRAModule(torch.nn.Module):
    def __init__(
        self, org_module: torch.nn.Module, lora_dim=4, alpha=1.0, multiplier=1.0
    ):
        """If alpha == 0 or None, alpha is the rank (no scaling)."""
        super().__init__()

        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        self.lora_dim = lora_dim

        if org_module.__class__.__name__ == "Conv2d":
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(
                in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
            )
            self.lora_up = torch.nn.Conv2d(
                self.lora_dim, out_dim, (1, 1), (1, 1), bias=False
            )
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if alpha is None or alpha == 0:
            self.alpha = self.lora_dim
        else:
            if isinstance(alpha, torch.Tensor):
                alpha = alpha.detach().float().numpy()  # without casting, bf16 causes an error
            self.register_buffer("alpha", torch.tensor(alpha))  # treatable as a constant

        # same initialization as Microsoft's LoRA implementation
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier

    def forward(self, x):
        # LoRA delta: multiplier * (alpha / rank) * up(down(x))
        scale = self.alpha / self.lora_dim
        return self.multiplier * scale * self.lora_up(self.lora_down(x))

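# Holds every LoRAModule built from one LoRA .safetensors state dict.
# Setting `container.alpha` rescales the multiplier of all contained modules at once.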
class LoRAModuleContainer(torch.nn.Module):
    def __init__(self, hooks, state_dict, multiplier):
        super().__init__()
        self.multiplier = multiplier

        # Create LoRAModule from state_dict information
        for key, value in state_dict.items():
            if "lora_down" in key:
                lora_name = key.split(".")[0]
                lora_dim = value.size()[0]
                lora_name_alpha = key.split(".")[0] + ".alpha"
                alpha = None
                if lora_name_alpha in state_dict:
                    alpha = state_dict[lora_name_alpha].item()
                hook = hooks[lora_name]
                lora_module = LoRAModule(
                    hook.orig_module, lora_dim=lora_dim, alpha=alpha, multiplier=multiplier
                )
                self.register_module(lora_name, lora_module)

        # Load whole LoRA weights
        self.load_state_dict(state_dict)

        # Register LoRAModule to LoRAHook
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                hook = hooks[name]
                hook.append_lora(module)

    @property
    def alpha(self):
        return self.multiplier

    @alpha.setter
    def alpha(self, multiplier):
        self.multiplier = multiplier
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                module.multiplier = multiplier

    def remove_from_hooks(self, hooks):
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                hook = hooks[name]
                hook.remove_lora(module)
                del module

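# Note: install() monkey-patches orig_module.forward with the hook's forward;
# uninstall() restores the original method.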
class LoRAHook(torch.nn.Module):
    """
    Replaces the forward method of the original Linear/Conv2d module,
    instead of replacing the module itself.
    """

    def __init__(self):
        super().__init__()
        self.lora_modules = []

    def install(self, orig_module):
        assert not hasattr(self, "orig_module")
        self.orig_module = orig_module
        self.orig_forward = self.orig_module.forward
        self.orig_module.forward = self.forward

    def uninstall(self):
        assert hasattr(self, "orig_module")
        self.orig_module.forward = self.orig_forward
        del self.orig_forward
        del self.orig_module

    def append_lora(self, lora_module):
        self.lora_modules.append(lora_module)

    def remove_lora(self, lora_module):
        self.lora_modules.remove(lora_module)

    def forward(self, x):
        if len(self.lora_modules) == 0:
            return self.orig_forward(x)
        # Sum the outputs of all attached LoRA modules and add them to the original output
        lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0)
        return self.orig_forward(x) + lora

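# Finds the Linear/Conv2d targets inside the text encoder (CLIPAttention, CLIPMLP)
# and the UNet (Transformer2DModel, Attention), installs a LoRAHook on each, and
# derives each hook's key from the module path with "." replaced by "_" so it
# matches Kohya's weight naming.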
class LoRAHookInjector(object):
    def __init__(self):
        super().__init__()
        self.hooks = {}
        self.device = None
        self.dtype = None

    def _get_target_modules(self, root_module, prefix, target_replace_modules):
        target_modules = []
        for name, module in root_module.named_modules():
            if (
                module.__class__.__name__ in target_replace_modules
                and "transformer_blocks" not in name
            ):  # skip modules nested under transformer_blocks to adapt to the latest diffusers
                for child_name, child_module in module.named_modules():
                    is_linear = child_module.__class__.__name__ == "Linear"
                    is_conv2d = child_module.__class__.__name__ == "Conv2d"
                    if is_linear or is_conv2d:
                        lora_name = prefix + "." + name + "." + child_name
                        lora_name = lora_name.replace(".", "_")
                        target_modules.append((lora_name, child_module))
        return target_modules

    def install_hooks(self, pipe):
        """Install LoRAHook to the pipe."""
        assert len(self.hooks) == 0
        text_encoder_targets = self._get_target_modules(
            pipe.text_encoder, "lora_te", ["CLIPAttention", "CLIPMLP"]
        )
        unet_targets = self._get_target_modules(
            pipe.unet, "lora_unet", ["Transformer2DModel", "Attention"]
        )
        for name, target_module in text_encoder_targets + unet_targets:
            hook = LoRAHook()
            hook.install(target_module)
            self.hooks[name] = hook

        self.device = pipe.device
        self.dtype = pipe.unet.dtype

    def uninstall_hooks(self):
        """Uninstall LoRAHook from the pipe."""
        for hook in self.hooks.values():
            hook.uninstall()
        self.hooks = {}

    def apply_lora(self, filename, alpha=1.0):
        """Load LoRA weights and apply LoRA to the pipe."""
        assert len(self.hooks) != 0
        state_dict = safetensors.torch.load_file(filename)
        container = LoRAModuleContainer(self.hooks, state_dict, alpha)
        container.to(self.device, self.dtype)
        return container

    def remove_lora(self, container):
        """Remove the individual LoRA from the pipe."""
        container.remove_from_hooks(self.hooks)

def install_lora_hook(pipe: DiffusionPipeline):
    """Install LoRAHook to the pipe."""
    assert not hasattr(pipe, "lora_injector")
    assert not hasattr(pipe, "apply_lora")
    assert not hasattr(pipe, "remove_lora")
    injector = LoRAHookInjector()
    injector.install_hooks(pipe)
    pipe.lora_injector = injector
    pipe.apply_lora = injector.apply_lora
    pipe.remove_lora = injector.remove_lora

def uninstall_lora_hook(pipe: DiffusionPipeline):
    """Uninstall LoRAHook from the pipe."""
    pipe.lora_injector.uninstall_hooks()
    del pipe.lora_injector
    del pipe.apply_lora
    del pipe.remove_lora