Spaces: Running on Zero
import logging | |
from typing import Optional | |
import torch | |
import comfy.model_management | |
from .base import WeightAdapterBase, weight_decompose | |
class GLoRAAdapter(WeightAdapterBase):
    """Weight adapter for GLoRA patches (identified by ``name = "glora"``).

    Stores four factor tensors (a1, a2, b1, b2) together with an optional
    alpha and an optional DoRA scale, and applies them to a base weight in
    ``calculate_weight``. Two factor layouts are handled: the older LyCORIS
    GLoRA layout and a newer one. Which layout applies is inferred from the
    factor shapes.
    """

    name = "glora"

    def __init__(self, loaded_keys, weights):
        # Set of state-dict keys consumed when this adapter was loaded.
        self.loaded_keys = loaded_keys
        # Tuple (a1, a2, b1, b2, alpha, dora_scale), as built by ``load``.
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: Optional[float],
        dora_scale: Optional[torch.Tensor],
        loaded_keys: Optional[set[str]] = None,
    ) -> Optional["GLoRAAdapter"]:
        """Build an adapter from the ``{x}.{a1,a2,b1,b2}.weight`` entries of ``lora``.

        Args:
            x: key prefix of the patched layer inside ``lora``.
            lora: state dict holding the GLoRA factor tensors.
            alpha: optional alpha. It is divided by the rank when the patch is applied.
            dora_scale: optional DoRA scale. When given, the patch is applied
                through ``weight_decompose``.
            loaded_keys: set that receives the consumed keys. It is mutated in
                place; a new set is created when omitted.

        Returns:
            A new adapter, or ``None`` when ``{x}.a1.weight`` is absent.

        Raises:
            KeyError: if ``a1`` is present but any of ``a2``/``b1``/``b2`` is missing.
        """
        if loaded_keys is None:
            loaded_keys = set()
        names = ["{}.{}.weight".format(x, part) for part in ("a1", "a2", "b1", "b2")]
        if names[0] not in lora:
            return None
        # Read every factor before recording any key, so a missing key leaves
        # loaded_keys untouched.
        weights = tuple(lora[n] for n in names) + (alpha, dora_scale)
        loaded_keys.update(names)
        return cls(loaded_keys, weights)

    @staticmethod
    def _detect_layout(v, weight):
        """Classify the stored factors as old- or new-style GLoRA.

        Returns ``(old_glora, rank)``. ``rank`` is ``None`` when the factor
        shapes match neither layout; in that case ``old_glora`` is ``False``.
        """
        old_shape_match = v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]
        new_shape_match = v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]
        if new_shape_match:
            # Both layouts can match at once. Keep the old interpretation only
            # when the weight is square and a2's leading dim equals it.
            if old_shape_match and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
                return True, v[0].shape[0]
            return False, v[1].shape[0]
        if old_shape_match:
            return True, v[0].shape[0]
        return False, None

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        """Apply this GLoRA patch to ``weight`` and return the result.

        Args:
            weight: base weight tensor to patch. On the non-DoRA path it is
                updated in place (``weight += ...``).
            key: identifier used only in error log messages.
            strength: scale applied to the patch difference.
            strength_model, offset, original_weight: accepted for interface
                compatibility; not used here.
            function: callable applied to the scaled difference before it is
                added to ``weight``. It is also forwarded to ``weight_decompose``.
            intermediate_dtype: dtype that the factors and the weight are cast
                to for the low-rank products.

        Returns:
            The patched weight. If the patch cannot be applied, an error is
            logged and ``weight`` is returned as-is. This happens when an alpha
            is set but the factor shapes match no known layout, or when the
            computation raises.
        """
        v = self.weights
        dora_scale = v[5]
        old_glora, rank = self._detect_layout(v, weight)

        if v[4] is not None:
            if rank is None:
                # No recognised layout means no rank to normalise alpha by.
                # Previously this raised UnboundLocalError outside the try.
                logging.error(
                    "ERROR %s %s unrecognized factor shapes %s",
                    self.name,
                    key,
                    [tuple(t.shape) for t in v[:4]],
                )
                return weight
            alpha = v[4] / rank
        else:
            alpha = 1.0

        device = weight.device
        a1, a2, b1, b2 = (
            comfy.model_management.cast_to_device(t.flatten(start_dim=1), device, intermediate_dtype)
            for t in v[:4]
        )

        try:
            if old_glora:
                # Old LyCORIS GLoRA: diff = b2 @ b1 + (W @ a2) @ a1, with W flattened to 2-D.
                w2d = weight.flatten(start_dim=1).to(dtype=intermediate_dtype)
                lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(w2d, a2), a1)).reshape(weight.shape)
            else:
                # New GLoRA: diff = (W @ a1) @ a2 + b1 @ b2.
                w = weight.to(dtype=intermediate_dtype)
                if weight.dim() > 2:
                    # Contract W's second dim with each factor, keeping any trailing dims.
                    wa1 = torch.einsum("o i ..., i j -> o j ...", w, a1)
                    lora_diff = torch.einsum("o i ..., i j -> o j ...", wa1, a2).reshape(weight.shape)
                else:
                    lora_diff = torch.mm(torch.mm(w, a1), a2).reshape(weight.shape)
                lora_diff += torch.mm(b1, b2).reshape(weight.shape)

            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            # Best effort: a patch that fails to apply is logged and skipped.
            logging.error("ERROR %s %s %s", self.name, key, e)
        return weight