Update patch_comfyui_nunchaku_lora.py

#1
Files changed (1)
  1. patch_comfyui_nunchaku_lora.py +128 -116
patch_comfyui_nunchaku_lora.py CHANGED
@@ -1,116 +1,128 @@
- import safetensors.torch
- from safetensors import safe_open
- import torch
-
- def patch_final_layer_adaLN(state_dict, prefix="lora_unet_final_layer", verbose=True):
-     """
-     Add dummy adaLN weights if missing, using final_layer_linear shapes as reference.
-     Args:
-         state_dict (dict): keys -> tensors
-         prefix (str): base name for final_layer keys
-         verbose (bool): print debug info
-     Returns:
-         dict: patched state_dict
-     """
-     final_layer_linear_down = None
-     final_layer_linear_up = None
-
-     adaLN_down_key = f"{prefix}_adaLN_modulation_1.lora_down.weight"
-     adaLN_up_key = f"{prefix}_adaLN_modulation_1.lora_up.weight"
-     linear_down_key = f"{prefix}_linear.lora_down.weight"
-     linear_up_key = f"{prefix}_linear.lora_up.weight"
-
-     if verbose:
-         print(f"\n🔍 Checking for final_layer keys with prefix: '{prefix}'")
-         print(f" Linear down: {linear_down_key}")
-         print(f" Linear up: {linear_up_key}")
-
-     if linear_down_key in state_dict:
-         final_layer_linear_down = state_dict[linear_down_key]
-     if linear_up_key in state_dict:
-         final_layer_linear_up = state_dict[linear_up_key]
-
-     has_adaLN = adaLN_down_key in state_dict and adaLN_up_key in state_dict
-     has_linear = final_layer_linear_down is not None and final_layer_linear_up is not None
-
-     if verbose:
-         print(f" ✅ Has final_layer.linear: {has_linear}")
-         print(f" ✅ Has final_layer.adaLN_modulation_1: {has_adaLN}")
-
-     if has_linear and not has_adaLN:
-         dummy_down = torch.zeros_like(final_layer_linear_down)
-         dummy_up = torch.zeros_like(final_layer_linear_up)
-         state_dict[adaLN_down_key] = dummy_down
-         state_dict[adaLN_up_key] = dummy_up
-
-         if verbose:
-             print(f"✅ Added dummy adaLN weights:")
-             print(f" {adaLN_down_key} (shape: {dummy_down.shape})")
-             print(f" {adaLN_up_key} (shape: {dummy_up.shape})")
-     else:
-         if verbose:
-             print("✅ No patch needed — adaLN weights already present or no final_layer.linear found.")
-
-     return state_dict
-
-
- def main():
-     print("🔄 Universal final_layer.adaLN LoRA patcher (.safetensors)")
-     input_path = input("Enter path to input LoRA .safetensors file: ").strip()
-     output_path = input("Enter path to save patched LoRA .safetensors file: ").strip()
-
-     # Load
-     state_dict = {}
-     with safe_open(input_path, framework="pt", device="cpu") as f:
-         for k in f.keys():
-             state_dict[k] = f.get_tensor(k)
-
-     print(f"\n✅ Loaded {len(state_dict)} tensors from: {input_path}")
-
-     # Show all keys that mention 'final_layer' for debug
-     final_keys = [k for k in state_dict if "final_layer" in k]
-     if final_keys:
-         print("\n🔑 Found these final_layer-related keys:")
-         for k in final_keys:
-             print(f" {k}")
-     else:
-         print("\n⚠️ No keys with 'final_layer' found — will try patch anyway.")
-
-     # Try common prefixes in order
-     prefixes = [
-         "lora_unet_final_layer",
-         "final_layer",
-         "base_model.model.final_layer"
-     ]
-     patched = False
-
-     for prefix in prefixes:
-         before = len(state_dict)
-         state_dict = patch_final_layer_adaLN(state_dict, prefix=prefix)
-         after = len(state_dict)
-         if after > before:
-             patched = True
-             break  # Stop after the first successful patch
-
-     if not patched:
-         print("\nℹ️ No patch applied — either adaLN already exists or no final_layer.linear found.")
-
-     # Save
-     safetensors.torch.save_file(state_dict, output_path)
-     print(f"\n✅ Patched file saved to: {output_path}")
-     print(f" Total tensors now: {len(state_dict)}")
-
-     # Verify
-     print("\n🔍 Verifying patched keys:")
-     with safe_open(output_path, framework="pt", device="cpu") as f:
-         keys = list(f.keys())
-         for k in keys:
-             if "final_layer" in k:
-                 print(f" {k}")
-
-     has_adaLN_after = any("adaLN_modulation_1" in k for k in keys)
-     print(f"✅ Contains adaLN after patch: {has_adaLN_after}")
-
-
- if __name__ == "__main__":
-     main()
+ import safetensors.torch
+ from safetensors import safe_open
+ import torch
+ import os
+ import tkinter as tk
+ from tkinter import filedialog
+
+ def patch_final_layer_adaLN(state_dict, prefix="lora_unet_final_layer", verbose=True):
+     final_layer_linear_down = None
+     final_layer_linear_up = None
+
+     adaLN_down_key = f"{prefix}_adaLN_modulation_1.lora_down.weight"
+     adaLN_up_key = f"{prefix}_adaLN_modulation_1.lora_up.weight"
+     linear_down_key = f"{prefix}_linear.lora_down.weight"
+     linear_up_key = f"{prefix}_linear.lora_up.weight"
+
+     if verbose:
+         print(f"\n🔍 Checking for final_layer keys with prefix: '{prefix}'")
+         print(f" Linear down: {linear_down_key}")
+         print(f" Linear up: {linear_up_key}")
+
+     if linear_down_key in state_dict:
+         final_layer_linear_down = state_dict[linear_down_key]
+     if linear_up_key in state_dict:
+         final_layer_linear_up = state_dict[linear_up_key]
+
+     has_adaLN = adaLN_down_key in state_dict and adaLN_up_key in state_dict
+     has_linear = final_layer_linear_down is not None and final_layer_linear_up is not None
+
+     if verbose:
+         print(f" ✅ Has final_layer.linear: {has_linear}")
+         print(f" ✅ Has final_layer.adaLN_modulation_1: {has_adaLN}")
+
+     if has_linear and not has_adaLN:
+         dummy_down = torch.zeros_like(final_layer_linear_down)
+         dummy_up = torch.zeros_like(final_layer_linear_up)
+         state_dict[adaLN_down_key] = dummy_down
+         state_dict[adaLN_up_key] = dummy_up
+
+         if verbose:
+             print(f"✅ Added dummy adaLN weights:")
+             print(f" {adaLN_down_key} (shape: {dummy_down.shape})")
+             print(f" {adaLN_up_key} (shape: {dummy_up.shape})")
+     else:
+         if verbose:
+             print("✅ No patch needed — adaLN weights already present or no final_layer.linear found.")
+
+     return state_dict
+
+ def main():
+     print("🔄 Universal final_layer.adaLN LoRA patcher (.safetensors)")
+
+     # GUI for file/folder selection
+     root = tk.Tk()
+     root.withdraw()
+
+     input_path = filedialog.askopenfilename(
+         title="Select LoRA .safetensors file",
+         filetypes=[("Safetensors files", "*.safetensors")]
+     )
+     if not input_path:
+         print("❌ No file selected. Exiting.")
+         return
+
+     output_dir = filedialog.askdirectory(
+         title="Select folder to save patched file"
+     )
+     if not output_dir:
+         print("❌ No folder selected. Exiting.")
+         return
+
+     # Generate output filename
+     base_name = os.path.basename(input_path)
+     name, ext = os.path.splitext(base_name)
+     output_filename = f"{name}-Patched{ext}"
+     output_path = os.path.join(output_dir, output_filename)
+
+     # Load
+     state_dict = {}
+     with safe_open(input_path, framework="pt", device="cpu") as f:
+         for k in f.keys():
+             state_dict[k] = f.get_tensor(k)
+
+     print(f"\n✅ Loaded {len(state_dict)} tensors from: {input_path}")
+
+     final_keys = [k for k in state_dict if "final_layer" in k]
+     if final_keys:
+         print("\n🔑 Found these final_layer-related keys:")
+         for k in final_keys:
+             print(f" {k}")
+     else:
+         print("\n⚠️ No keys with 'final_layer' found — will try patch anyway.")
+
+     prefixes = [
+         "lora_unet_final_layer",
+         "final_layer",
+         "base_model.model.final_layer"
+     ]
+     patched = False
+
+     for prefix in prefixes:
+         before = len(state_dict)
+         state_dict = patch_final_layer_adaLN(state_dict, prefix=prefix)
+         after = len(state_dict)
+         if after > before:
+             patched = True
+             break
+
+     if not patched:
+         print("\nℹ️ No patch applied — either adaLN already exists or no final_layer.linear found.")
+
+     # Save
+     safetensors.torch.save_file(state_dict, output_path)
+     print(f"\n✅ Patched file saved to: {output_path}")
+     print(f" Total tensors now: {len(state_dict)}")
+
+     # Verify
+     print("\n🔍 Verifying patched keys:")
+     with safe_open(output_path, framework="pt", device="cpu") as f:
+         keys = list(f.keys())
+         for k in keys:
+             if "final_layer" in k:
+                 print(f" {k}")
+     has_adaLN_after = any("adaLN_modulation_1" in k for k in keys)
+     print(f"✅ Contains adaLN after patch: {has_adaLN_after}")
+
+ if __name__ == "__main__":
+     main()
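
For reference, a minimal sketch of exercising the patch function directly on a synthetic state dict (the tensor shapes below are illustrative, not taken from a real LoRA; the key names are built from the same prefix and f-string patterns the script uses):

import torch
from patch_comfyui_nunchaku_lora import patch_final_layer_adaLN

# Synthetic LoRA state dict containing only the final_layer.linear pair.
sd = {
    "lora_unet_final_layer_linear.lora_down.weight": torch.zeros(16, 3072),
    "lora_unet_final_layer_linear.lora_up.weight": torch.zeros(64, 16),
}

# The patch should add zero-filled adaLN_modulation_1 tensors with matching shapes.
sd = patch_final_layer_adaLN(sd, prefix="lora_unet_final_layer")
assert "lora_unet_final_layer_adaLN_modulation_1.lora_down.weight" in sd
assert "lora_unet_final_layer_adaLN_modulation_1.lora_up.weight" in sd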