Spaces:
Running
on
Zero
Running
on
Zero
import logging | |
from typing import Optional | |
import torch | |
import comfy.model_management | |
from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape | |
class LoRAAdapter(WeightAdapterBase):
    """Weight adapter that applies a low-rank ``up @ down`` delta to a weight.

    An instance holds, for one layer, the tensors found by :meth:`load` packed as
    ``(up, down, alpha, mid, dora_scale, reshape)`` together with the set of
    state-dict keys that were consumed to build it.
    """

    name = "lora"

    # (up, down, mid) key templates, tried in this order; the first template
    # whose "up" key is present in the state dict wins. ``None`` means that
    # layout has no mid tensor. Labels mirror the naming used by the original
    # lookup code.
    _KEY_TEMPLATES = (
        ("{}.lora_up.weight", "{}.lora_down.weight", "{}.lora_mid.weight"),  # regular
        ("{}_lora.up.weight", "{}_lora.down.weight", None),  # diffusers
        ("{}.lora_B.weight", "{}.lora_A.weight", None),  # diffusers2
        ("{}.lora.up.weight", "{}.lora.down.weight", None),  # diffusers3
        ("{}.lora_B", "{}.lora_A", None),  # mochi
        (
            "{}.lora_linear_layer.up.weight",
            "{}.lora_linear_layer.down.weight",
            None,
        ),  # transformers
    )

    def __init__(self, loaded_keys, weights):
        """Store the consumed keys and the packed weight tuple.

        Args:
            loaded_keys: Set of state-dict keys used to build this adapter.
            weights: Tuple ``(up, down, alpha, mid, dora_scale, reshape)`` as
                assembled by :meth:`load`.
        """
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: Optional[torch.Tensor],
        loaded_keys: Optional[set[str]] = None,
    ) -> Optional["LoRAAdapter"]:
        """Build an adapter for layer prefix ``x`` from the LoRA state dict ``lora``.

        Key layouts in ``_KEY_TEMPLATES`` are tried in order; the first one whose
        "up" key exists is used. An optional mid tensor and an optional
        ``"<x>.reshape_weight"`` entry are picked up when present.

        Args:
            x: Key prefix identifying the target layer.
            lora: State dict containing the LoRA tensors.
            alpha: Scale value stored as-is; may be ``None`` (then
                :meth:`calculate_weight` uses a scale of 1.0).
            dora_scale: Optional DoRA magnitude tensor; ``None`` disables DoRA.
            loaded_keys: Set that consumed keys are added to (mutated in place);
                a fresh set is created when ``None``.

        Returns:
            A new ``LoRAAdapter``, or ``None`` when no known "up" key for ``x``
            exists in ``lora``.

        Raises:
            KeyError: if an "up" key is found but its matching "down" key is
                missing from ``lora``.
        """
        if loaded_keys is None:
            loaded_keys = set()

        up_name = down_name = mid_name = None
        for up_tmpl, down_tmpl, mid_tmpl in cls._KEY_TEMPLATES:
            candidate = up_tmpl.format(x)
            if candidate in lora:
                up_name = candidate
                down_name = down_tmpl.format(x)
                mid_name = mid_tmpl.format(x) if mid_tmpl is not None else None
                break

        if up_name is None:
            return None

        mid = None
        if mid_name is not None and mid_name in lora:
            mid = lora[mid_name]
            loaded_keys.add(mid_name)

        reshape = None
        reshape_name = "{}.reshape_weight".format(x)
        if reshape_name in lora:
            # Best-effort: an unreadable reshape entry is skipped (as before),
            # but no longer silently and no longer swallowing KeyboardInterrupt.
            try:
                reshape = lora[reshape_name].tolist()
                loaded_keys.add(reshape_name)
            except Exception as e:
                logging.warning(
                    "Ignoring unreadable reshape key %s: %s", reshape_name, e
                )

        weights = (lora[up_name], lora[down_name], alpha, mid, dora_scale, reshape)
        loaded_keys.add(up_name)
        loaded_keys.add(down_name)
        return cls(loaded_keys, weights)

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        """Apply this adapter's low-rank delta to ``weight`` and return the result.

        The delta is ``up.flatten(1) @ down.flatten(1)`` reshaped to
        ``weight.shape``, scaled by ``strength * alpha`` and passed through
        ``function`` before being added in place. When a DoRA scale is present
        the combination is delegated to ``weight_decompose`` instead. If a
        reshape target was stored, ``weight`` is first padded to that shape.

        Args:
            weight: Tensor to patch; may be modified in place via ``+=``.
            key: Identifier used only in the error log message.
            strength: Multiplier applied to the delta.
            strength_model: Unused here; accepted for interface compatibility.
            offset: Unused here; accepted for interface compatibility.
            function: Callable applied to the scaled delta before it is added.
            intermediate_dtype: Dtype the LoRA tensors are cast to for the matmul.
            original_weight: Unused here; accepted for interface compatibility.

        Returns:
            The patched weight. If the delta computation raises, the error is
            logged and the (possibly padded) weight is returned without the delta.
        """
        v = self.weights
        mat1 = comfy.model_management.cast_to_device(
            v[0], weight.device, intermediate_dtype
        )
        mat2 = comfy.model_management.cast_to_device(
            v[1], weight.device, intermediate_dtype
        )
        dora_scale = v[4]
        reshape = v[5]

        if reshape is not None:
            weight = pad_tensor_to_shape(weight, reshape)

        # alpha is divided by the down matrix's leading dimension
        # (presumably the LoRA rank); without alpha the scale is 1.0.
        if v[2] is not None:
            alpha = v[2] / mat2.shape[0]
        else:
            alpha = 1.0

        if v[3] is not None:
            # Fold the optional mid tensor into the down matrix. mat3 is indexed
            # up to dim 3, so it is expected to be 4-D (presumably a conv
            # kernel). The original author noted this path was not properly
            # tested.
            mat3 = comfy.model_management.cast_to_device(
                v[3], weight.device, intermediate_dtype
            )
            final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
            mat2 = (
                torch.mm(
                    mat2.transpose(0, 1).flatten(start_dim=1),
                    mat3.transpose(0, 1).flatten(start_dim=1),
                )
                .reshape(final_shape)
                .transpose(0, 1)
            )

        try:
            lora_diff = torch.mm(
                mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
            ).reshape(weight.shape)
            if dora_scale is not None:
                weight = weight_decompose(
                    dora_scale,
                    weight,
                    lora_diff,
                    alpha,
                    strength,
                    intermediate_dtype,
                    function,
                )
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            # Shape mismatches and similar failures skip this patch rather than
            # aborting the whole model load.
            logging.error("ERROR %s %s %s", self.name, key, e)
        return weight