tori29umai commited on
Commit
77e9bab
·
verified ·
1 Parent(s): e80905e

Upload 2 files

Browse files
Files changed (2) hide show
  1. utils/fp8_optimization_utils.py +277 -0
  2. utils/lora_utils.py +234 -0
utils/fp8_optimization_utils.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from tqdm import tqdm
6
+
7
+
8
+ def calculate_fp8_maxval(exp_bits=4, mantissa_bits=3, sign_bits=1):
9
+ """
10
+ Calculate the maximum representable value in FP8 format.
11
+ Default is E4M3 format (4-bit exponent, 3-bit mantissa, 1-bit sign).
12
+
13
+ Args:
14
+ exp_bits (int): Number of exponent bits
15
+ mantissa_bits (int): Number of mantissa bits
16
+ sign_bits (int): Number of sign bits (0 or 1)
17
+
18
+ Returns:
19
+ float: Maximum value representable in FP8 format
20
+ """
21
+ assert exp_bits + mantissa_bits + sign_bits == 8, "Total bits must be 8"
22
+
23
+ # Calculate exponent bias
24
+ bias = 2 ** (exp_bits - 1) - 1
25
+
26
+ # Calculate maximum mantissa value
27
+ mantissa_max = 1.0
28
+ for i in range(mantissa_bits - 1):
29
+ mantissa_max += 2 ** -(i + 1)
30
+
31
+ # Calculate maximum value
32
+ max_value = mantissa_max * (2 ** (2**exp_bits - 1 - bias))
33
+
34
+ return max_value
35
+
36
+
37
+ def quantize_tensor_to_fp8(tensor, scale, exp_bits=4, mantissa_bits=3, sign_bits=1, max_value=None, min_value=None):
38
+ """
39
+ Quantize a tensor to FP8 format.
40
+
41
+ Args:
42
+ tensor (torch.Tensor): Tensor to quantize
43
+ scale (float or torch.Tensor): Scale factor
44
+ exp_bits (int): Number of exponent bits
45
+ mantissa_bits (int): Number of mantissa bits
46
+ sign_bits (int): Number of sign bits
47
+
48
+ Returns:
49
+ tuple: (quantized_tensor, scale_factor)
50
+ """
51
+ # Create scaled tensor
52
+ scaled_tensor = tensor / scale
53
+
54
+ # Calculate FP8 parameters
55
+ bias = 2 ** (exp_bits - 1) - 1
56
+
57
+ if max_value is None:
58
+ # Calculate max and min values
59
+ max_value = calculate_fp8_maxval(exp_bits, mantissa_bits, sign_bits)
60
+ min_value = -max_value if sign_bits > 0 else 0.0
61
+
62
+ # Clamp tensor to range
63
+ clamped_tensor = torch.clamp(scaled_tensor, min_value, max_value)
64
+
65
+ # Quantization process
66
+ abs_values = torch.abs(clamped_tensor)
67
+ nonzero_mask = abs_values > 0
68
+
69
+ # Calculate logF scales (only for non-zero elements)
70
+ log_scales = torch.zeros_like(clamped_tensor)
71
+ if nonzero_mask.any():
72
+ log_scales[nonzero_mask] = torch.floor(torch.log2(abs_values[nonzero_mask]) + bias).detach()
73
+
74
+ # Limit log scales and calculate quantization factor
75
+ log_scales = torch.clamp(log_scales, min=1.0)
76
+ quant_factor = 2.0 ** (log_scales - mantissa_bits - bias)
77
+
78
+ # Quantize and dequantize
79
+ quantized = torch.round(clamped_tensor / quant_factor) * quant_factor
80
+
81
+ return quantized, scale
82
+
83
+
84
+ def optimize_state_dict_with_fp8(
85
+ state_dict, calc_device, target_layer_keys=None, exclude_layer_keys=None, exp_bits=4, mantissa_bits=3, move_to_device=False
86
+ ):
87
+ """
88
+ Optimize Linear layer weights in a model's state dict to FP8 format.
89
+
90
+ Args:
91
+ state_dict (dict): State dict to optimize, replaced in-place
92
+ calc_device (str): Device to quantize tensors on
93
+ target_layer_keys (list, optional): Layer key patterns to target (None for all Linear layers)
94
+ exclude_layer_keys (list, optional): Layer key patterns to exclude
95
+ exp_bits (int): Number of exponent bits
96
+ mantissa_bits (int): Number of mantissa bits
97
+ move_to_device (bool): Move optimized tensors to the calculating device
98
+
99
+ Returns:
100
+ dict: FP8 optimized state dict
101
+ """
102
+ if exp_bits == 4 and mantissa_bits == 3:
103
+ fp8_dtype = torch.float8_e4m3fn
104
+ elif exp_bits == 5 and mantissa_bits == 2:
105
+ fp8_dtype = torch.float8_e5m2
106
+ else:
107
+ raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits}")
108
+
109
+ # Calculate FP8 max value
110
+ max_value = calculate_fp8_maxval(exp_bits, mantissa_bits)
111
+ min_value = -max_value # this function supports only signed FP8
112
+
113
+ # Create optimized state dict
114
+ optimized_count = 0
115
+
116
+ # Enumerate tarket keys
117
+ target_state_dict_keys = []
118
+ for key in state_dict.keys():
119
+ # Check if it's a weight key and matches target patterns
120
+ is_target = (target_layer_keys is None or any(pattern in key for pattern in target_layer_keys)) and key.endswith(".weight")
121
+ is_excluded = exclude_layer_keys is not None and any(pattern in key for pattern in exclude_layer_keys)
122
+ is_target = is_target and not is_excluded
123
+
124
+ if is_target and isinstance(state_dict[key], torch.Tensor):
125
+ target_state_dict_keys.append(key)
126
+
127
+ # Process each key
128
+ for key in tqdm(target_state_dict_keys):
129
+ value = state_dict[key]
130
+
131
+ # Save original device and dtype
132
+ original_device = value.device
133
+ original_dtype = value.dtype
134
+
135
+ # Move to calculation device
136
+ if calc_device is not None:
137
+ value = value.to(calc_device)
138
+
139
+ # Calculate scale factor
140
+ scale = torch.max(torch.abs(value.flatten())) / max_value
141
+ # print(f"Optimizing {key} with scale: {scale}")
142
+
143
+ # Quantize weight to FP8
144
+ quantized_weight, _ = quantize_tensor_to_fp8(value, scale, exp_bits, mantissa_bits, 1, max_value, min_value)
145
+
146
+ # Add to state dict using original key for weight and new key for scale
147
+ fp8_key = key # Maintain original key
148
+ scale_key = key.replace(".weight", ".scale_weight")
149
+
150
+ quantized_weight = quantized_weight.to(fp8_dtype)
151
+
152
+ if not move_to_device:
153
+ quantized_weight = quantized_weight.to(original_device)
154
+
155
+ scale_tensor = torch.tensor([scale], dtype=original_dtype, device=quantized_weight.device)
156
+
157
+ state_dict[fp8_key] = quantized_weight
158
+ state_dict[scale_key] = scale_tensor
159
+
160
+ optimized_count += 1
161
+
162
+ if calc_device is not None: # optimized_count % 10 == 0 and
163
+ # free memory on calculation device
164
+ torch.cuda.empty_cache() # TODO check device typ
165
+
166
+ print(f"Number of optimized Linear layers: {optimized_count}")
167
+ return state_dict
168
+
169
+
170
+ def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=None):
171
+ """
172
+ Patched forward method for Linear layers with FP8 weights.
173
+
174
+ Args:
175
+ self: Linear layer instance
176
+ x (torch.Tensor): Input tensor
177
+ use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series)
178
+ max_value (float): Maximum value for FP8 quantization. If None, no quantization is applied for input tensor.
179
+
180
+ Returns:
181
+ torch.Tensor: Result of linear transformation
182
+ """
183
+ if use_scaled_mm:
184
+ input_dtype = x.dtype
185
+ original_weight_dtype = self.scale_weight.dtype
186
+ weight_dtype = self.weight.dtype
187
+ target_dtype = torch.float8_e5m2
188
+ assert weight_dtype == torch.float8_e4m3fn, "Only FP8 E4M3FN format is supported"
189
+ assert x.ndim == 3, "Input tensor must be 3D (batch_size, seq_len, hidden_dim)"
190
+
191
+ if max_value is None:
192
+ # no input quantization
193
+ scale_x = torch.tensor(1.0, dtype=torch.float32, device=x.device)
194
+ else:
195
+ # calculate scale factor for input tensor
196
+ scale_x = (torch.max(torch.abs(x.flatten())) / max_value).to(torch.float32)
197
+
198
+ # quantize input tensor to FP8: this seems to consume a lot of memory
199
+ x, _ = quantize_tensor_to_fp8(x, scale_x, 5, 2, 1, max_value, -max_value)
200
+
201
+ original_shape = x.shape
202
+ x = x.reshape(-1, x.shape[2]).to(target_dtype)
203
+
204
+ weight = self.weight.t()
205
+ scale_weight = self.scale_weight.to(torch.float32)
206
+
207
+ if self.bias is not None:
208
+ # float32 is not supported with bias in scaled_mm
209
+ o = torch._scaled_mm(x, weight, out_dtype=original_weight_dtype, bias=self.bias, scale_a=scale_x, scale_b=scale_weight)
210
+ else:
211
+ o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight)
212
+
213
+ return o.reshape(original_shape[0], original_shape[1], -1).to(input_dtype)
214
+
215
+ else:
216
+ # Dequantize the weight
217
+ original_dtype = self.scale_weight.dtype
218
+ dequantized_weight = self.weight.to(original_dtype) * self.scale_weight
219
+
220
+ # Perform linear transformation
221
+ if self.bias is not None:
222
+ output = F.linear(x, dequantized_weight, self.bias)
223
+ else:
224
+ output = F.linear(x, dequantized_weight)
225
+
226
+ return output
227
+
228
+
229
+ def apply_fp8_monkey_patch(model, optimized_state_dict, use_scaled_mm=False):
230
+ """
231
+ Apply monkey patching to a model using FP8 optimized state dict.
232
+
233
+ Args:
234
+ model (nn.Module): Model instance to patch
235
+ optimized_state_dict (dict): FP8 optimized state dict
236
+ use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series)
237
+
238
+ Returns:
239
+ nn.Module: The patched model (same instance, modified in-place)
240
+ """
241
+ # # Calculate FP8 float8_e5m2 max value
242
+ # max_value = calculate_fp8_maxval(5, 2)
243
+ max_value = None # do not quantize input tensor
244
+
245
+ # Find all scale keys to identify FP8-optimized layers
246
+ scale_keys = [k for k in optimized_state_dict.keys() if k.endswith(".scale_weight")]
247
+
248
+ # Enumerate patched layers
249
+ patched_module_paths = set()
250
+ for scale_key in scale_keys:
251
+ # Extract module path from scale key (remove .scale_weight)
252
+ module_path = scale_key.rsplit(".scale_weight", 1)[0]
253
+ patched_module_paths.add(module_path)
254
+
255
+ patched_count = 0
256
+
257
+ # Apply monkey patch to each layer with FP8 weights
258
+ for name, module in model.named_modules():
259
+ # Check if this module has a corresponding scale_weight
260
+ has_scale = name in patched_module_paths
261
+
262
+ # Apply patch if it's a Linear layer with FP8 scale
263
+ if isinstance(module, nn.Linear) and has_scale:
264
+ # register the scale_weight as a buffer to load the state_dict
265
+ module.register_buffer("scale_weight", torch.tensor(1.0, dtype=module.weight.dtype))
266
+
267
+ # Create a new forward method with the patched version.
268
+ def new_forward(self, x):
269
+ return fp8_linear_forward_patch(self, x, use_scaled_mm, max_value)
270
+
271
+ # Bind method to module
272
+ module.forward = new_forward.__get__(module, type(module))
273
+
274
+ patched_count += 1
275
+
276
+ print(f"Number of monkey-patched Linear layers: {patched_count}")
277
+ return model
utils/lora_utils.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from safetensors.torch import load_file
4
+ from tqdm import tqdm
5
+
6
+
7
+ def merge_lora_to_state_dict(
8
+ state_dict: dict[str, torch.Tensor], lora_file: str, multiplier: float, device: torch.device
9
+ ) -> dict[str, torch.Tensor]:
10
+ """
11
+ Merge LoRA weights into the state dict of a model.
12
+ """
13
+ lora_sd = load_file(lora_file)
14
+
15
+ # Check the format of the LoRA file
16
+ keys = list(lora_sd.keys())
17
+ if keys[0].startswith("lora_unet_"):
18
+ print(f"Musubi Tuner LoRA detected")
19
+ return merge_musubi_tuner(lora_sd, state_dict, multiplier, device)
20
+
21
+ transformer_prefixes = ["diffusion_model", "transformer"] # to ignore Text Encoder modules
22
+ lora_suffix = None
23
+ prefix = None
24
+ for key in keys:
25
+ if lora_suffix is None and "lora_A" in key:
26
+ lora_suffix = "lora_A"
27
+ if prefix is None:
28
+ pfx = key.split(".")[0]
29
+ if pfx in transformer_prefixes:
30
+ prefix = pfx
31
+ if lora_suffix is not None and prefix is not None:
32
+ break
33
+
34
+ if lora_suffix == "lora_A" and prefix is not None:
35
+ print(f"Diffusion-pipe (?) LoRA detected")
36
+ return merge_diffusion_pipe_or_something(lora_sd, state_dict, "lora_unet_", multiplier, device)
37
+
38
+ print(f"LoRA file format not recognized: {os.path.basename(lora_file)}")
39
+ return state_dict
40
+
41
+
42
+ def merge_diffusion_pipe_or_something(
43
+ lora_sd: dict[str, torch.Tensor], state_dict: dict[str, torch.Tensor], prefix: str, multiplier: float, device: torch.device
44
+ ) -> dict[str, torch.Tensor]:
45
+ """
46
+ Convert LoRA weights to the format used by the diffusion pipeline to Musubi Tuner.
47
+ Copy from Musubi Tuner repo.
48
+ """
49
+ # convert from diffusers(?) to default LoRA
50
+ # Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...}
51
+ # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
52
+
53
+ # note: Diffusers has no alpha, so alpha is set to rank
54
+ new_weights_sd = {}
55
+ lora_dims = {}
56
+ for key, weight in lora_sd.items():
57
+ diffusers_prefix, key_body = key.split(".", 1)
58
+ if diffusers_prefix != "diffusion_model" and diffusers_prefix != "transformer":
59
+ print(f"unexpected key: {key} in diffusers format")
60
+ continue
61
+
62
+ new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
63
+ new_weights_sd[new_key] = weight
64
+
65
+ lora_name = new_key.split(".")[0] # before first dot
66
+ if lora_name not in lora_dims and "lora_down" in new_key:
67
+ lora_dims[lora_name] = weight.shape[0]
68
+
69
+ # add alpha with rank
70
+ for lora_name, dim in lora_dims.items():
71
+ new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim)
72
+
73
+ return merge_musubi_tuner(new_weights_sd, state_dict, multiplier, device)
74
+
75
+
76
+ def merge_musubi_tuner(
77
+ lora_sd: dict[str, torch.Tensor], state_dict: dict[str, torch.Tensor], multiplier: float, device: torch.device
78
+ ) -> dict[str, torch.Tensor]:
79
+ """
80
+ Merge LoRA weights into the state dict of a model.
81
+ """
82
+ # Check LoRA is for FramePack or for HunyuanVideo
83
+ is_hunyuan = False
84
+ for key in lora_sd.keys():
85
+ if "double_blocks" in key or "single_blocks" in key:
86
+ is_hunyuan = True
87
+ break
88
+ if is_hunyuan:
89
+ print("HunyuanVideo LoRA detected, converting to FramePack format")
90
+ lora_sd = convert_hunyuan_to_framepack(lora_sd)
91
+
92
+ # Merge LoRA weights into the state dict
93
+ print(f"Merging LoRA weights into state dict. multiplier: {multiplier}")
94
+
95
+ # Create module map
96
+ name_to_original_key = {}
97
+ for key in state_dict.keys():
98
+ if key.endswith(".weight"):
99
+ lora_name = key.rsplit(".", 1)[0] # remove trailing ".weight"
100
+ lora_name = "lora_unet_" + lora_name.replace(".", "_")
101
+ if lora_name not in name_to_original_key:
102
+ name_to_original_key[lora_name] = key
103
+
104
+ # Merge LoRA weights
105
+ keys = list([k for k in lora_sd.keys() if "lora_down" in k])
106
+ for key in tqdm(keys, desc="Merging LoRA weights"):
107
+ up_key = key.replace("lora_down", "lora_up")
108
+ alpha_key = key[: key.index("lora_down")] + "alpha"
109
+
110
+ # find original key for this lora
111
+ module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
112
+ if module_name not in name_to_original_key:
113
+ print(f"No module found for LoRA weight: {key}")
114
+ continue
115
+
116
+ original_key = name_to_original_key[module_name]
117
+
118
+ down_weight = lora_sd[key]
119
+ up_weight = lora_sd[up_key]
120
+
121
+ dim = down_weight.size()[0]
122
+ alpha = lora_sd.get(alpha_key, dim)
123
+ scale = alpha / dim
124
+
125
+ weight = state_dict[original_key]
126
+ original_device = weight.device
127
+ if original_device != device:
128
+ weight = weight.to(device) # to make calculation faster
129
+
130
+ down_weight = down_weight.to(device)
131
+ up_weight = up_weight.to(device)
132
+
133
+ # W <- W + U * D
134
+ if len(weight.size()) == 2:
135
+ # linear
136
+ if len(up_weight.size()) == 4: # use linear projection mismatch
137
+ up_weight = up_weight.squeeze(3).squeeze(2)
138
+ down_weight = down_weight.squeeze(3).squeeze(2)
139
+ weight = weight + multiplier * (up_weight @ down_weight) * scale
140
+ elif down_weight.size()[2:4] == (1, 1):
141
+ # conv2d 1x1
142
+ weight = (
143
+ weight
144
+ + multiplier
145
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
146
+ * scale
147
+ )
148
+ else:
149
+ # conv2d 3x3
150
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
151
+ # logger.info(conved.size(), weight.size(), module.stride, module.padding)
152
+ weight = weight + multiplier * conved * scale
153
+
154
+ weight = weight.to(original_device) # move back to original device
155
+ state_dict[original_key] = weight
156
+
157
+ return state_dict
158
+
159
+
160
+ def convert_hunyuan_to_framepack(lora_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
161
+ """
162
+ Convert HunyuanVideo LoRA weights to FramePack format.
163
+ """
164
+ new_lora_sd = {}
165
+ for key, weight in lora_sd.items():
166
+ if "double_blocks" in key:
167
+ key = key.replace("double_blocks", "transformer_blocks")
168
+ key = key.replace("img_mod_linear", "norm1_linear")
169
+ key = key.replace("img_attn_qkv", "attn_to_QKV") # split later
170
+ key = key.replace("img_attn_proj", "attn_to_out_0")
171
+ key = key.replace("img_mlp_fc1", "ff_net_0_proj")
172
+ key = key.replace("img_mlp_fc2", "ff_net_2")
173
+ key = key.replace("txt_mod_linear", "norm1_context_linear")
174
+ key = key.replace("txt_attn_qkv", "attn_add_QKV_proj") # split later
175
+ key = key.replace("txt_attn_proj", "attn_to_add_out")
176
+ key = key.replace("txt_mlp_fc1", "ff_context_net_0_proj")
177
+ key = key.replace("txt_mlp_fc2", "ff_context_net_2")
178
+ elif "single_blocks" in key:
179
+ key = key.replace("single_blocks", "single_transformer_blocks")
180
+ key = key.replace("linear1", "attn_to_QKVM") # split later
181
+ key = key.replace("linear2", "proj_out")
182
+ key = key.replace("modulation_linear", "norm_linear")
183
+ else:
184
+ print(f"Unsupported module name: {key}, only double_blocks and single_blocks are supported")
185
+ continue
186
+
187
+ if "QKVM" in key:
188
+ # split QKVM into Q, K, V, M
189
+ key_q = key.replace("QKVM", "q")
190
+ key_k = key.replace("QKVM", "k")
191
+ key_v = key.replace("QKVM", "v")
192
+ key_m = key.replace("attn_to_QKVM", "proj_mlp")
193
+ if "_down" in key or "alpha" in key:
194
+ # copy QKVM weight or alpha to Q, K, V, M
195
+ assert "alpha" in key or weight.size(1) == 3072, f"QKVM weight size mismatch: {key}. {weight.size()}"
196
+ new_lora_sd[key_q] = weight
197
+ new_lora_sd[key_k] = weight
198
+ new_lora_sd[key_v] = weight
199
+ new_lora_sd[key_m] = weight
200
+ elif "_up" in key:
201
+ # split QKVM weight into Q, K, V, M
202
+ assert weight.size(0) == 21504, f"QKVM weight size mismatch: {key}. {weight.size()}"
203
+ new_lora_sd[key_q] = weight[:3072]
204
+ new_lora_sd[key_k] = weight[3072 : 3072 * 2]
205
+ new_lora_sd[key_v] = weight[3072 * 2 : 3072 * 3]
206
+ new_lora_sd[key_m] = weight[3072 * 3 :] # 21504 - 3072 * 3 = 12288
207
+ else:
208
+ print(f"Unsupported module name: {key}")
209
+ continue
210
+ elif "QKV" in key:
211
+ # split QKV into Q, K, V
212
+ key_q = key.replace("QKV", "q")
213
+ key_k = key.replace("QKV", "k")
214
+ key_v = key.replace("QKV", "v")
215
+ if "_down" in key or "alpha" in key:
216
+ # copy QKV weight or alpha to Q, K, V
217
+ assert "alpha" in key or weight.size(1) == 3072, f"QKV weight size mismatch: {key}. {weight.size()}"
218
+ new_lora_sd[key_q] = weight
219
+ new_lora_sd[key_k] = weight
220
+ new_lora_sd[key_v] = weight
221
+ elif "_up" in key:
222
+ # split QKV weight into Q, K, V
223
+ assert weight.size(0) == 3072 * 3, f"QKV weight size mismatch: {key}. {weight.size()}"
224
+ new_lora_sd[key_q] = weight[:3072]
225
+ new_lora_sd[key_k] = weight[3072 : 3072 * 2]
226
+ new_lora_sd[key_v] = weight[3072 * 2 :]
227
+ else:
228
+ print(f"Unsupported module name: {key}")
229
+ continue
230
+ else:
231
+ # no split needed
232
+ new_lora_sd[key] = weight
233
+
234
+ return new_lora_sd