sujitvasanth committed
Commit 5fd42e9 · verified · 1 Parent(s): bd74dcb

Create quatisation_files/architecture.py

Files changed (1)
  1. quatisation_files/architecture.py +1020 -0
quatisation_files/architecture.py ADDED
@@ -0,0 +1,1020 @@
1
+ from dataclasses import dataclass, field
2
+ from enum import IntEnum
3
+
4
+ # Common keys
5
+
6
+ layer_keys_llama_norms = [["input_layernorm"],
7
+ ["post_attention_layernorm"]]
8
+ layer_keys_cohere_norms = [["input_layernorm"]]
9
+ layer_keys_gpt2_norms = [["ln_1"],
10
+ ["ln_2"]]
11
+ layer_keys_yi_norms = [["ln1", "input_layernorm"],
12
+ ["ln2", "post_attention_layernorm"]]
13
+ layer_keys_gemma2_norms = [["input_layernorm"],
14
+ ["post_attention_layernorm"],
15
+ ["pre_feedforward_layernorm"],
16
+ ["post_feedforward_layernorm"]]
17
+ layer_keys_internlm2_norms = [["attention_norm"],
18
+ ["ffn_norm"]]
19
+ layer_keys_glm4_norms = [["input_layernorm"],
20
+ ["post_self_attn_layernorm"],
21
+ ["post_attention_layernorm"],
22
+ ["post_mlp_layernorm"]]
23
+ layer_keys_llama_attn = [["self_attn.q_proj"],
24
+ ["self_attn.k_proj"],
25
+ ["self_attn.v_proj"],
26
+ ["self_attn.o_proj"]]
27
+ layer_keys_gpt2_attn = [["self_attn.c_attn", "self_attn.q_proj"],
28
+ ["self_attn.c_attn", "self_attn.k_proj"],
29
+ ["self_attn.c_attn", "self_attn.v_proj"],
30
+ ["self_attn.o_proj"]]
31
+ layer_keys_internlm2_attn = [["self_attn.wqkv", "self_attn.q_proj"],
32
+ ["self_attn.wqkv", "self_attn.k_proj"],
33
+ ["self_attn.wqkv", "self_attn.v_proj"],
34
+ ["self_attn.o_proj"]]
35
+ layer_keys_dbrx_attn = [["self_attn.Wqkv", "self_attn.q_proj"],
36
+ ["self_attn.Wqkv", "self_attn.k_proj"],
37
+ ["self_attn.Wqkv", "self_attn.v_proj"],
38
+ ["self_attn.o_proj"]]
39
+ layer_keys_phi3_attn = [["self_attn.qkv_proj", "self_attn.q_proj"],
40
+ ["self_attn.qkv_proj", "self_attn.k_proj"],
41
+ ["self_attn.qkv_proj", "self_attn.v_proj"],
42
+ ["self_attn.o_proj"]]
43
+ layer_keys_llama_mlp = [["mlp.down_proj"],
44
+ ["mlp.gate_proj"],
45
+ ["mlp.up_proj"]]
46
+ layer_keys_internlm2_mlp = [["feed_forward.w1"],
47
+ ["feed_forward.w2"],
48
+ ["feed_forward.w3"]]
49
+ layer_keys_phi3_mlp = [["mlp.down_proj"],
50
+ ["mlp.gate_up_proj", "mlp.gate_proj"],
51
+ ["mlp.gate_up_proj", "mlp.up_proj"]]
52
+ layer_keys_mixtral_mlp = [["block_sparse_moe.experts.*.w1"],
53
+ ["block_sparse_moe.experts.*.w2"],
54
+ ["block_sparse_moe.experts.*.w3"],
55
+ ["block_sparse_moe.gate"]]
56
+ layer_keys_qwen3moe_mlp = [["mlp.experts.*.gate_proj"],
57
+ ["mlp.experts.*.up_proj"],
58
+ ["mlp.experts.*.down_proj"],
59
+ ["mlp.gate"]]
60
+ layer_keys_dbrx_mlp = [["block_sparse_moe.experts.*.v1", "block_sparse_moe.experts.v1"],
61
+ ["block_sparse_moe.experts.*.w1", "block_sparse_moe.experts.w1"],
62
+ ["block_sparse_moe.experts.*.w2", "block_sparse_moe.experts.w2"],
63
+ ["block_sparse_moe.gate"]]
64
+ layer_keys_llama_mlp_swiglu = [["mlp.swiglu.w12"],
65
+ ["mlp.swiglu.w3"]]
66
+ layer_keys_starcoder2_mlp = [["mlp.c_fc"],
67
+ ["mlp.c_proj"]]
68
+ layer_keys_gpt2_mlp = [["mlp.c_fc"],
69
+ ["mlp.c_proj"]]
70
+ expect_keys_llama = [["lm_head"],
71
+ ["model.norm"],
72
+ ["model.embed_tokens"]]
73
+ expect_keys_gemma = [["model.norm"],
74
+ ["model.embed_tokens"]]
75
+ expect_keys_starcoder2 = [["model.norm"],
76
+ ["model.embed_tokens"]]
77
+ expect_keys_gpt2 = [["model.embed_tokens"]]
78
+
79
+ dbrx_keymap = [("transformer.", "model."),
80
+ (".blocks.", ".layers."),
81
+ (".ffn.experts.mlp.", ".block_sparse_moe.experts."),
82
+ (".ffn.router.layer.", ".block_sparse_moe.gate."),
83
+ (".norm_attn_norm.norm_1.", ".input_layernorm."),
84
+ (".norm_attn_norm.norm_2.", ".post_attention_layernorm."),
85
+ (".norm_attn_norm.attn.", ".self_attn."),
86
+ (".out_proj.", ".o_proj."),
87
+ (".norm_f.", ".norm."),
88
+ (".wte.", ".embed_tokens.")]
89
+ bigcode_keymap = [("transformer.ln_f", "model.norm"),
90
+ ("transformer.", "model."),
91
+ (".attn.c_proj.", ".self_attn.o_proj."),
92
+ (".attn.", ".self_attn."),
93
+ (".h.", ".layers."),
94
+ (".wte.", ".embed_tokens.")]
95
+ gpt2_keymap = [("$ln_f.", "model.norm."),
96
+ (".attn.c_proj.", ".self_attn.o_proj."),
97
+ (".attn.", ".self_attn."),
98
+ ("$h.", "model.layers."),
99
+ ("$wte.", "model.embed_tokens."),
100
+ ("$wpe.", "model.wpe.")]
101
+ internlm2_keymap = [("$output.", "lm_head."),
102
+ ("$model.tok_embeddings.", "model.embed_tokens."),
103
+ (".attention.", ".self_attn."),
104
+ (".wo.", ".o_proj.")]
105
+ google_keymap = [("mm_input_projection_weight", "mm_input_projection.weight")]
106
+
107
+ no_default = object()
108
+
109
+ class RopeStyle(IntEnum):
110
+ NONE = 0
111
+ GPTJ = 1
112
+ NEOX = 2
113
+
114
+ class ExLlamaV2ArchParams:
115
+
116
+ def __init__(self, arch_string: str, read_config: dict):
117
+ """
118
+ Get architecture definition from model config. If the architecture isn't recognized, defaults to Llama
119
+ architecture.
120
+
121
+ :param arch_string:
122
+ Architecture string from config.json
123
+
124
+ :param read_config:
125
+ config.json as Python dict
126
+ """
127
+
128
+ self.arch_string = arch_string
129
+ arch_recognized = False
130
+
131
+ self.keymap = None
132
+ self.compile_fix_keymap = None
133
+
134
+ @dataclass
135
+ class Params:
136
+ keys: dict = field(default_factory = lambda: {
137
+ "norm_eps": "rms_norm_eps",
138
+ "norm_1": ".input_layernorm",
139
+ "norm_1_post": None,
140
+ "fused_qkv": None,
141
+ "mlp_gate": ".mlp.gate_proj",
142
+ "mlp_up": ".mlp.up_proj",
143
+ "mlp_down": ".mlp.down_proj",
144
+ "lm_head": "lm_head",
145
+ "norm_2": ".post_attention_layernorm",
146
+ "norm_2_post": None,
147
+ "fused_mlp_12": None,
148
+ "fused_mlp_3": None,
149
+ "learned_pos_emb": None,
150
+ "attn_q": ".self_attn.q_proj",
151
+ "attn_k": ".self_attn.k_proj",
152
+ "attn_v": ".self_attn.v_proj",
153
+ "attn_o": ".self_attn.o_proj",
154
+ "layers": "layers",
155
+ "patch_conv": "patch_conv",
156
+ })
157
+
158
+ # Compute logit scale from `dim_model_base` key in config.json (MiniCPM quirk)
159
+ logit_scale_basedim = False
160
+
161
+ # Clamp hidden states to FP16 range
162
+ clamp_hidden_states = False
163
+
164
+ # Upcast hidden state to FP32 before adding to residual stream
165
+ residual_stream_fp32 = False
166
+
167
+ # Normalize embeddings (Gemma quirk)
168
+ normalize_embeddings = False
169
+
170
+ # Constant bias for layernorm (Gemma quirk)
171
+ norm_constant_bias = 0
172
+
173
+ # Alternate packing scheme for fused QKV tensor (InternLM2 quirk)
174
+ fused_qkv_altpack = False
175
+
176
+ # SWA required by architecture
177
+ swa = False
178
+ alternating_swa = False
179
+ sliding_rope_theta = None
180
+ sliding_rope_scale = None
181
+ pos_id_index = 0
182
+
183
+ # Model only works with eager attention
184
+ eager_attn_only = False
185
+
186
+ # Expect bias for linear layers
187
+ attention_bias_qkv = False
188
+ attention_bias_o = False
189
+ mlp_bias = False
190
+
191
+ # Default multiplier for MLP inner dim (GPT2 quirk)
192
+ default_inner_dim_mult = None
193
+
194
+ # Use gated MLP
195
+ mlp_gate = True
196
+
197
+ # Use block-sparse MLP
198
+ is_moe = False
199
+
200
+ # Use parallel decoder blocks (Cohere quirk)
201
+ parallel_decoder_blocks = False
202
+
203
+ # Use MQA, effectively num_key_value_heads = 1 (GPTBigCode quirk)
204
+ mqa = False
205
+
206
+ # Model is incoherent without BOS at the start of the context
207
+ requires_bos = False
208
+
209
+ # Scale attn weights (GPT2 quirk, not important for inference)
210
+ scale_attn_weights = False
211
+
212
+ # Model implementation works in tensor-parallel mode
213
+ supports_tp = False
214
+
215
+ # Activation function
216
+ mlp_act_func = "silu"
217
+
218
+ # Layer norm type
219
+ norm = "rmsnorm"
220
+ headnorm = "layernorm"
221
+
222
+ # RoPE style
223
+ rope_style = RopeStyle.NEOX
224
+
225
+ # Expected keys
226
+ expect_keys: list[str] = field(default_factory = lambda: [])
227
+ layer_keys: list[str] = field(default_factory = lambda: [])
228
+
229
+ # Defaults (Gemma3 configs may omit these keys)
230
+ default_vocab_size = no_default
231
+ default_rms_norm_eps = no_default
232
+ default_head_dim = no_default
233
+ default_num_attention_heads = no_default
234
+ default_num_key_value_heads = no_default
235
+ default_use_qk_norm = False
236
+ default_sliding_window_pattern = 1
237
+ default_rope_theta = 10000
238
+
239
+ # Vision stuff
240
+ patch_conv_bias: bool = False
241
+ is_vision: bool = False
242
+ vision_input_norm: bool = True
243
+ vision_conv3d: bool = False
244
+ mrope: bool = False
245
+ rope_freq_half: bool = False
246
+ learned_emb: bool = False
247
+ output_norm: bool = False
248
+ mlp_merger: bool = False
249
+ mlp_patch_merger: bool = False
250
+
251
+ # Component models
252
+ self.lm_prefix = ""
253
+ self.vt_prefix = ""
254
+ self.mmp_prefix = ""
255
+ self.lm = Params()
256
+ self.mmp = Params()
257
+ self.vt = Params()
258
+
259
+ self.mmp.keys.update({
260
+ "norm_1": None,
261
+ "norm_1_post": None,
262
+ "norm_2": None,
263
+ "norm_2_post": None,
264
+ "fused_mlp_12": None,
265
+ "fused_mlp_3": None,
266
+ })
267
+ self.mmp.rope_style = RopeStyle.NONE
268
+
269
+ self.vt.is_vision = True
270
+
271
+ # Tensors are transposed in original model weights
272
+ self.orig_weights_transposed = False
273
+
274
+ # Add noise rows to calibration while quantizing
275
+ self.standard_calib_noise = None
276
+
277
+ # Mistral
278
+
279
+ if arch_string == "MistralForCausalLM":
280
+ arch_recognized = True
281
+ self.lm.layer_keys += \
282
+ layer_keys_llama_norms + \
283
+ layer_keys_llama_attn + \
284
+ layer_keys_llama_mlp
285
+ self.lm.expect_keys += \
286
+ expect_keys_llama
287
+ self.lm.supports_tp = True
288
+
289
+ # Mixtral
290
+
291
+ if arch_string == "MixtralForCausalLM":
292
+ arch_recognized = True
293
+ self.lm.layer_keys += \
294
+ layer_keys_llama_norms + \
295
+ layer_keys_llama_attn + \
296
+ layer_keys_mixtral_mlp
297
+ self.lm.expect_keys += \
298
+ expect_keys_llama
299
+ self.lm.keys.update({
300
+ "mlp_gate": ".block_sparse_moe.experts.*.w1",
301
+ "mlp_up": ".block_sparse_moe.experts.*.w3",
302
+ "mlp_down": ".block_sparse_moe.experts.*.w2",
303
+ "mlp_expert_gate": ".block_sparse_moe.gate"
304
+ })
305
+ self.lm.is_moe = True
306
+
307
+ # Pixtral
308
+
309
+ if (
310
+ arch_string == "LlavaForConditionalGeneration" and
311
+ "vision_config" in read_config and
312
+ read_config["vision_config"].get("model_type") == "pixtral"
313
+ ):
314
+ arch_recognized = True
315
+ self.lm_prefix = "language_model."
316
+ self.lm.layer_keys += \
317
+ layer_keys_llama_norms + \
318
+ layer_keys_llama_attn + \
319
+ layer_keys_llama_mlp
320
+ self.lm.expect_keys += \
321
+ expect_keys_llama
322
+
323
+ self.vt_prefix = "vision_tower."
324
+ self.vt.keys.update({
325
+ "attn_q": ".attention.q_proj",
326
+ "attn_k": ".attention.k_proj",
327
+ "attn_v": ".attention.v_proj",
328
+ "attn_o": ".attention.o_proj",
329
+ "mlp_gate": ".feed_forward.gate_proj",
330
+ "mlp_up": ".feed_forward.up_proj",
331
+ "mlp_down": ".feed_forward.down_proj",
332
+ "norm_1": ".attention_norm",
333
+ "norm_2": ".ffn_norm",
334
+ "layers": "transformer.layers",
335
+ "ln_pre": "ln_pre",
336
+ })
337
+ self.vt.mlp_merger = True
338
+
339
+ self.mmp_prefix = "multi_modal_projector."
340
+ self.mmp.keys.update({
341
+ "mlp_gate": None,
342
+ "mlp_up": "linear_1",
343
+ "mlp_down": "linear_2",
344
+ })
345
+ self.mmp.mlp_gate = False
346
+ self.mmp.mlp_act_func = "gelu"
347
+ self.mmp.mlp_bias = bool(read_config.get("multimodal_projector_bias", True))
348
+
349
+ # Mistral 3 multimodal
350
+
351
+ if (
352
+ arch_string == "Mistral3ForConditionalGeneration" and
353
+ "vision_config" in read_config and
354
+ read_config["vision_config"].get("model_type") == "pixtral"
355
+ ):
356
+ arch_recognized = True
357
+ self.lm_prefix = "language_model."
358
+ self.lm.layer_keys += \
359
+ layer_keys_llama_norms + \
360
+ layer_keys_llama_attn + \
361
+ layer_keys_llama_mlp
362
+ self.lm.expect_keys += \
363
+ expect_keys_llama
364
+
365
+ self.vt_prefix = "vision_tower."
366
+ self.vt.keys.update({
367
+ "attn_q": ".attention.q_proj",
368
+ "attn_k": ".attention.k_proj",
369
+ "attn_v": ".attention.v_proj",
370
+ "attn_o": ".attention.o_proj",
371
+ "mlp_gate": ".feed_forward.gate_proj",
372
+ "mlp_up": ".feed_forward.up_proj",
373
+ "mlp_down": ".feed_forward.down_proj",
374
+ "norm_1": ".attention_norm",
375
+ "norm_2": ".ffn_norm",
376
+ "layers": "transformer.layers",
377
+ "ln_pre": "ln_pre",
378
+ })
379
+ self.vt.mlp_merger = True
380
+ self.vt.mlp_patch_merger = True
381
+
382
+ self.mmp_prefix = "multi_modal_projector."
383
+ self.mmp.keys.update({
384
+ "norm_2": "norm",
385
+ "mlp_gate": None,
386
+ "mlp_up": "linear_1",
387
+ "mlp_down": "linear_2",
388
+ "patch_merger": "patch_merger.merging_layer",
389
+ })
390
+ self.mmp.mlp_patch_merger = True
391
+ self.mmp.mlp_gate = False
392
+ self.mmp.mlp_act_func = "gelu"
393
+ self.mmp.mlp_bias = bool(read_config.get("multimodal_projector_bias", True))
394
+
395
+ # Yi
396
+
397
+ if arch_string == "YiForCausalLM":
398
+ arch_recognized = True
399
+ self.lm.layer_keys += \
400
+ layer_keys_yi_norms + \
401
+ layer_keys_llama_attn + \
402
+ layer_keys_llama_mlp
403
+ self.lm.expect_keys += \
404
+ expect_keys_llama
405
+ self.lm.keys.update({
406
+ "norm_1": ".ln1",
407
+ "norm_2": ".ln2",
408
+ })
409
+
410
+ # Orion
411
+
412
+ if arch_string == "OrionForCausalLM":
413
+ arch_recognized = True
414
+ self.lm.layer_keys += \
415
+ layer_keys_llama_norms + \
416
+ layer_keys_llama_attn + \
417
+ layer_keys_llama_mlp
418
+ self.lm.expect_keys += \
419
+ expect_keys_llama
420
+ self.lm.norm = "layernorm"
421
+
422
+ # Qwen2 (1.5, 2, 2.5)
423
+
424
+ if arch_string == "Qwen2ForCausalLM":
425
+ arch_recognized = True
426
+ self.lm.layer_keys += \
427
+ layer_keys_llama_norms + \
428
+ layer_keys_llama_attn + \
429
+ layer_keys_llama_mlp
430
+ self.lm.expect_keys += \
431
+ expect_keys_llama
432
+ self.lm.attention_bias_qkv = True
433
+ self.lm.supports_tp = True
434
+
435
+ # Qwen3
436
+
437
+ if arch_string == "Qwen3ForCausalLM":
438
+ arch_recognized = True
439
+ self.lm.layer_keys += \
440
+ layer_keys_llama_norms + \
441
+ layer_keys_llama_attn + \
442
+ layer_keys_llama_mlp
443
+ self.lm.expect_keys += \
444
+ expect_keys_llama
445
+ self.lm.supports_tp = True
446
+ self.lm.default_use_qk_norm = True
447
+
448
+ # Qwen3MoE
449
+
450
+ if arch_string == "Qwen3MoeForCausalLM":
451
+ arch_recognized = True
452
+ self.lm.layer_keys += \
453
+ layer_keys_llama_norms + \
454
+ layer_keys_llama_attn + \
455
+ layer_keys_qwen3moe_mlp
456
+ self.lm.expect_keys += \
457
+ expect_keys_llama
458
+ self.lm.supports_tp = True
459
+ self.lm.default_use_qk_norm = True
460
+ self.lm.keys.update({
461
+ "mlp_gate": ".mlp.experts.*.gate_proj",
462
+ "mlp_up": ".mlp.experts.*.up_proj",
463
+ "mlp_down": ".mlp.experts.*.down_proj",
464
+ "mlp_expert_gate": ".mlp.gate"
465
+ })
466
+ self.lm.is_moe = True
467
+
468
+ # Qwen2-VL (2, 2.5)
469
+
470
+ if arch_string in ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]:
471
+ arch_recognized = True
472
+ self.lm.layer_keys += \
473
+ layer_keys_llama_norms + \
474
+ layer_keys_llama_attn + \
475
+ layer_keys_llama_mlp
476
+ self.lm.expect_keys += \
477
+ expect_keys_llama
478
+ self.lm.attention_bias_qkv = True
479
+ self.lm.mrope = True
480
+ self.lm.rope_freq_half = True
481
+
482
+ self.vt_prefix = "visual."
483
+ if arch_string == "Qwen2VLForConditionalGeneration":
484
+ read_config["vision_config"].update({"model_type": "qwen2"})
485
+ self.vt.keys.update({
486
+ "fused_qkv": ".attn.qkv",
487
+ "attn_o": ".attn.proj",
488
+ "mlp_gate": None,
489
+ "mlp_up": ".mlp.fc1",
490
+ "mlp_down": ".mlp.fc2",
491
+ "norm_1": ".norm1",
492
+ "norm_2": ".norm2",
493
+ "layers": "blocks",
494
+ "patch_conv": "patch_embed.proj",
495
+ })
496
+ self.vt.mlp_gate = False
497
+ self.vt.mlp_act_func = "quickgelu"
498
+ self.vt.norm = "layernorm"
499
+ elif arch_string == "Qwen2_5_VLForConditionalGeneration":
500
+ read_config["vision_config"].update({"model_type": "qwen2.5"})
501
+ self.vt.keys.update({
502
+ "fused_qkv": ".attn.qkv",
503
+ "attn_o": ".attn.proj",
504
+ "mlp_gate": ".mlp.gate_proj",
505
+ "mlp_up": ".mlp.up_proj",
506
+ "mlp_down": ".mlp.down_proj",
507
+ "norm_1": ".norm1",
508
+ "norm_2": ".norm2",
509
+ "layers": "blocks",
510
+ "patch_conv": "patch_embed.proj",
511
+ })
512
+ self.vt.mlp_gate = True
513
+ self.vt.mlp_act_func = "silu"
514
+ self.vt.norm = "rmsnorm"
515
+ self.vt.mlp_bias = True
516
+ self.vt.attention_bias_qkv = True
517
+ self.vt.attention_bias_o = True
518
+ self.vt.vision_input_norm = False
519
+ self.vt.vision_conv3d = True
520
+ self.vt.mlp_merger = True
521
+
522
+ self.mmp_prefix = "visual.merger."
523
+ self.mmp.keys.update({
524
+ "mlp_gate": None,
525
+ "mlp_up": "mlp.0",
526
+ "mlp_down": "mlp.2",
527
+ "norm_2": "ln_q",
528
+ })
529
+ self.mmp.mlp_gate = False
530
+ self.mmp.mlp_act_func = "gelu"
531
+ self.mmp.mlp_bias = True
532
+ self.mmp.norm = "layernorm"
533
+
534
+ self.standard_calib_noise = (5, 30)
535
+
536
+ # OpenCUA (Custom)
537
+ if arch_string == "OpenCUAForConditionalGeneration":
538
+ arch_recognized = True
539
+
540
+ # --- Language Model settings ---
541
+ self.lm_prefix = "language_model."
542
+ self.lm.layer_keys += \
543
+ layer_keys_llama_norms + \
544
+ layer_keys_llama_attn + \
545
+ layer_keys_llama_mlp
546
+ self.lm.expect_keys += \
547
+ expect_keys_llama
548
+ self.lm.attention_bias_qkv = True
549
+ self.lm.supports_tp = True
550
+
551
+ # --- Vision Tower settings ---
552
+ self.vt_prefix = "vision_tower."
553
+ read_config["vision_config"].update({"model_type": "qwen2.5"})
554
+ self.vt.keys.update({
555
+ "fused_qkv": ".attn.qkv",
556
+ "attn_o": ".attn.proj",
557
+ "mlp_gate": ".mlp.gate_proj",
558
+ "mlp_up": ".mlp.up_proj",
559
+ "mlp_down": ".mlp.down_proj",
560
+ "norm_1": ".norm1",
561
+ "norm_2": ".norm2",
562
+ "layers": "blocks",
563
+ "patch_conv": "patch_embed.proj",
564
+ })
565
+ self.vt.mlp_gate = True
566
+ self.vt.mlp_act_func = "silu"
567
+ self.vt.norm = "rmsnorm"
568
+ self.vt.mlp_bias = True
569
+ self.vt.attention_bias_qkv = True
570
+ self.vt.attention_bias_o = True
571
+ self.vt.vision_input_norm = False
572
+ self.vt.vision_conv3d = True
573
+ self.vt.rope_style = RopeStyle.NONE
574
+ self.vt.mlp_merger = True
575
+
576
+ # --- Multi-Modal Projector settings ---
577
+ self.mmp_prefix = "multi_modal_projector."
578
+ self.mmp.keys.update({
579
+ "mlp_gate": None,
580
+ "mlp_up": "linear_1",
581
+ "mlp_down": "linear_2",
582
+ })
583
+ self.mmp.mlp_gate = False
584
+ self.mmp.mlp_act_func = "gelu"
585
+ self.mmp.mlp_bias = True
586
+ # The projector has no norm layer, so self.mmp.norm is intentionally left at its default:
587
+ # self.mmp.norm = "rmsnorm"
588
+
589
+
590
+ # Gemma
591
+
592
+ if arch_string == "GemmaForCausalLM":
593
+ arch_recognized = True
594
+ self.lm.layer_keys += \
595
+ layer_keys_llama_norms + \
596
+ layer_keys_llama_attn + \
597
+ layer_keys_llama_mlp
598
+ self.lm.expect_keys += \
599
+ expect_keys_gemma
600
+ self.lm.keys.update({
601
+ "lm_head": "model.embed_tokens",
602
+ })
603
+ self.lm.mlp_act_func = "gelu"
604
+ self.lm.normalize_embeddings = True
605
+ self.lm.norm_constant_bias = 1
606
+ self.lm.requires_bos = True
607
+
608
+ # Gemma2
609
+
610
+ if arch_string == "Gemma2ForCausalLM":
611
+ arch_recognized = True
612
+ self.lm.layer_keys += \
613
+ layer_keys_gemma2_norms + \
614
+ layer_keys_llama_attn + \
615
+ layer_keys_llama_mlp
616
+ self.lm.expect_keys += \
617
+ expect_keys_gemma
618
+ self.lm.keys.update({
619
+ "lm_head": "model.embed_tokens",
620
+ "norm_1": ".input_layernorm",
621
+ "norm_1_post": ".post_attention_layernorm",
622
+ "norm_2": ".pre_feedforward_layernorm",
623
+ "norm_2_post": ".post_feedforward_layernorm",
624
+ })
625
+ self.lm.mlp_act_func = "gelu"
626
+ self.lm.normalize_embeddings = True
627
+ self.lm.norm_constant_bias = 1
628
+ self.lm.requires_bos = True
629
+ self.lm.alternating_swa = True
630
+ self.lm.residual_stream_fp32 = True
631
+
632
+ # Gemma3
633
+
634
+ if arch_string == "Gemma3ForConditionalGeneration":
635
+ arch_recognized = True
636
+ self.lm.layer_keys += \
637
+ layer_keys_gemma2_norms + \
638
+ layer_keys_llama_attn + \
639
+ layer_keys_llama_mlp
640
+ self.lm.expect_keys += \
641
+ expect_keys_gemma
642
+ self.lm.keys.update({
643
+ "lm_head": "model.embed_tokens",
644
+ "norm_1": ".input_layernorm",
645
+ "norm_1_post": ".post_attention_layernorm",
646
+ "norm_2": ".pre_feedforward_layernorm",
647
+ "norm_2_post": ".post_feedforward_layernorm",
648
+ })
649
+ self.lm_prefix = "language_model."
650
+ self.lm.mlp_act_func = "gelu"
651
+ self.lm.normalize_embeddings = True
652
+ self.lm.norm_constant_bias = 1
653
+ self.lm.requires_bos = True
654
+ self.lm.alternating_swa = True
655
+ self.lm.residual_stream_fp32 = True
656
+ self.lm.sliding_rope_theta = 10000
657
+ self.lm.sliding_rope_scale = 1
658
+ self.lm.default_vocab_size = 262208
659
+ self.lm.default_rms_norm_eps = 1e-06
660
+ self.lm.default_head_dim = 256
661
+ self.lm.default_num_attention_heads = 8
662
+ self.lm.default_num_key_value_heads = 4
663
+ self.lm.default_use_qk_norm = True
664
+ self.lm.default_sliding_window_pattern = 6
665
+ self.lm.default_rope_theta = 1e6
666
+ self.lm.pos_id_index = 1
667
+ self.lm.headnorm = "rmsnorm"
668
+
669
+ self.vt_prefix = "vision_tower.vision_model."
670
+ self.vt.keys.update({
671
+ "attn_q": ".self_attn.q_proj",
672
+ "attn_k": ".self_attn.k_proj",
673
+ "attn_v": ".self_attn.v_proj",
674
+ "attn_o": ".self_attn.out_proj",
675
+ "norm_1": ".layer_norm1",
676
+ "norm_2": ".layer_norm2",
677
+ "mlp_gate": None,
678
+ "mlp_up": ".mlp.fc1",
679
+ "mlp_down": ".mlp.fc2",
680
+ "layers": "encoder.layers",
681
+ "patch_conv": "embeddings.patch_embedding",
682
+ "position_embedding": "embeddings.position_embedding",
683
+ "output_norm": "post_layernorm",
684
+ })
685
+ self.vt.norm = "rmsnorm"
686
+ self.vt.patch_conv_bias = True
687
+ self.vt.mlp_gate = False
688
+ self.vt.mlp_bias = True
689
+ self.vt.attention_bias_qkv = True
690
+ self.vt.attention_bias_o = True
691
+ self.vt.vision_input_norm = False
692
+ self.vt.mlp_merger = False
693
+ self.vt.norm = "layernorm"
694
+ self.vt.learned_emb = True
695
+ self.vt.rope_style = RopeStyle.NONE
696
+ self.vt.mlp_act_func = "gelu"
697
+ self.vt.output_norm = True
698
+
699
+ self.keymap = google_keymap
700
+ self.compile_fix_keymap = google_keymap
701
+ self.mmp_prefix = "multi_modal_projector."
702
+ self.mmp.keys.update({
703
+ "input_projection": "mm_input_projection",
704
+ "input_projection_norm": "mm_soft_emb_norm",
705
+ })
706
+ self.mmp.norm_constant_bias = 1
707
+
708
+ # StarCoder2
709
+
710
+ if arch_string == "Starcoder2ForCausalLM":
711
+ arch_recognized = True
712
+ self.lm.layer_keys += \
713
+ layer_keys_llama_norms + \
714
+ layer_keys_llama_attn + \
715
+ layer_keys_starcoder2_mlp
716
+ self.lm.expect_keys += \
717
+ expect_keys_starcoder2
718
+ self.lm.keys.update({
719
+ "mlp_gate": None,
720
+ "mlp_up": ".mlp.c_fc",
721
+ "mlp_down": ".mlp.c_proj",
722
+ "lm_head": "model.embed_tokens",
723
+ "norm_eps": "layer_norm_epsilon",
724
+ })
725
+ self.lm.mlp_act_func = "gelu"
726
+ self.lm.norm = "layernorm"
727
+ self.lm.attention_bias_qkv = True
728
+ self.lm.attention_bias_o = True
729
+ self.lm.mlp_bias = True
730
+ self.lm.mlp_gate = False
731
+
732
+ # GemMoE
733
+
734
+ if arch_string == "GemmoeForCausalLM":
735
+ arch_recognized = True
736
+ print(f" !! Warning, Gemmoe support is experimental and has not been fully tested")
737
+ self.lm.layer_keys += \
738
+ layer_keys_llama_norms + \
739
+ layer_keys_llama_attn + \
740
+ layer_keys_mixtral_mlp
741
+ self.lm.expect_keys += \
742
+ expect_keys_gemma
743
+ self.lm.keys.update({
744
+ "mlp_gate": ".block_sparse_moe.experts.*.w1",
745
+ "mlp_up": ".block_sparse_moe.experts.*.w3",
746
+ "mlp_down": ".block_sparse_moe.experts.*.w2",
747
+ "mlp_expert_gate": ".block_sparse_moe.gate",
748
+ "lm_head": "model.embed_tokens",
749
+ })
750
+ self.lm.mlp_act_func = "gelu"
751
+ self.lm.normalize_embeddings = True
752
+ self.lm.norm_constant_bias = 1
753
+ self.lm.is_moe = True
754
+ self.lm.requires_bos = True
755
+
756
+ # Cohere
757
+
758
+ if arch_string == "CohereForCausalLM":
759
+ arch_recognized = True
760
+ self.lm.layer_keys += \
761
+ layer_keys_cohere_norms + \
762
+ layer_keys_llama_attn + \
763
+ layer_keys_llama_mlp
764
+ self.lm.expect_keys += \
765
+ expect_keys_gemma
766
+ self.lm.keys.update({
767
+ "norm_eps": "layer_norm_eps",
768
+ "lm_head": "model.embed_tokens",
769
+ "norm_1": ".input_layernorm",
770
+ "norm_2": None,
771
+ })
772
+ self.lm.norm = "layernorm"
773
+ self.lm.rope_style = RopeStyle.GPTJ
774
+ self.lm.parallel_decoder_blocks = True
775
+ self.lm.requires_bos = True
776
+
777
+ # Cohere 2
778
+
779
+ if arch_string == "Cohere2ForCausalLM":
780
+ arch_recognized = True
781
+ self.lm.layer_keys += \
782
+ layer_keys_cohere_norms + \
783
+ layer_keys_llama_attn + \
784
+ layer_keys_llama_mlp
785
+ self.lm.expect_keys += \
786
+ expect_keys_gemma
787
+ self.lm.keys.update({
788
+ "norm_eps": "layer_norm_eps",
789
+ "lm_head": "model.embed_tokens",
790
+ "norm_1": ".input_layernorm",
791
+ "norm_2": None,
792
+ })
793
+ self.lm.norm = "layernorm"
794
+ self.lm.rope_style = RopeStyle.GPTJ
795
+ self.lm.parallel_decoder_blocks = True
796
+ self.lm.requires_bos = True
797
+ self.lm.alternating_swa = True
798
+
799
+ # DBRX
800
+
801
+ if arch_string == "DbrxForCausalLM":
802
+ arch_recognized = True
803
+ self.keymap = dbrx_keymap
804
+ self.lm.layer_keys += \
805
+ layer_keys_llama_norms + \
806
+ layer_keys_dbrx_attn + \
807
+ layer_keys_dbrx_mlp
808
+ self.lm.expect_keys += \
809
+ expect_keys_llama
810
+ self.lm.keys.update({
811
+ "norm_eps": None,
812
+ "mlp_gate": ".block_sparse_moe.experts.*.w1",
813
+ "mlp_up": ".block_sparse_moe.experts.*.v1",
814
+ "mlp_down": ".block_sparse_moe.experts.*.w2",
815
+ "mlp_expert_gate": ".block_sparse_moe.gate",
816
+ "fused_qkv": ".self_attn.Wqkv",
817
+ })
818
+ self.lm.norm = "layernorm"
819
+ self.lm.is_moe = True
820
+
821
+ # Phi3
822
+
823
+ if arch_string == "Phi3ForCausalLM":
824
+ arch_recognized = True
825
+ self.lm.layer_keys += \
826
+ layer_keys_llama_norms + \
827
+ layer_keys_phi3_attn + \
828
+ layer_keys_phi3_mlp
829
+ self.lm.expect_keys += \
830
+ expect_keys_llama
831
+ self.lm.keys.update({
832
+ "fused_qkv": ".self_attn.qkv_proj",
833
+ "fused_mlp_12": "gate_up_proj",
834
+ })
835
+
836
+ # GPTBigCode
837
+
838
+ if arch_string == "GPTBigCodeForCausalLM":
839
+ arch_recognized = True
840
+ self.keymap = bigcode_keymap
841
+ self.lm.layer_keys += \
842
+ layer_keys_gpt2_norms + \
843
+ layer_keys_gpt2_attn + \
844
+ layer_keys_gpt2_mlp
845
+ self.lm.expect_keys += \
846
+ expect_keys_gpt2
847
+ self.lm.keys.update({
848
+ "norm_eps": "layer_norm_epsilon",
849
+ "mlp_gate": None,
850
+ "mlp_up": ".mlp.c_fc",
851
+ "mlp_down": ".mlp.c_proj",
852
+ "lm_head": "model.embed_tokens",
853
+ "norm_1": ".ln_1",
854
+ "norm_2": ".ln_2",
855
+ "fused_qkv": ".self_attn.c_attn",
856
+ "learned_pos_emb": "model.wpe",
857
+ })
858
+ self.lm.mlp_act_func = "gelu"
859
+ self.lm.norm = "layernorm"
860
+ self.lm.rope_style = RopeStyle.NONE
861
+ self.lm.mqa = True
862
+ self.lm.attention_bias_qkv = True
863
+ self.lm.attention_bias_o = True
864
+ self.lm.mlp_bias = True
865
+ self.lm.mlp_gate = False
866
+
867
+ # GPT2
868
+
869
+ if arch_string == "GPT2LMHeadModel":
870
+ arch_recognized = True
871
+ self.keymap = gpt2_keymap
872
+ self.lm.layer_keys += \
873
+ layer_keys_gpt2_norms + \
874
+ layer_keys_gpt2_attn + \
875
+ layer_keys_gpt2_mlp
876
+ self.lm.expect_keys += \
877
+ expect_keys_gpt2
878
+ self.lm.keys.update({
879
+ "norm_eps": "layer_norm_epsilon",
880
+ "mlp_gate": None,
881
+ "mlp_up": ".mlp.c_fc",
882
+ "mlp_down": ".mlp.c_proj",
883
+ "lm_head": "model.embed_tokens",
884
+ "norm_1": ".ln_1",
885
+ "norm_2": ".ln_2",
886
+ "fused_qkv": ".self_attn.c_attn",
887
+ "learned_pos_emb": "model.wpe",
888
+ })
889
+ self.lm.mlp_act_func = "gelu"
890
+ self.lm.norm = "layernorm"
891
+ self.lm.rope_style = RopeStyle.NONE
892
+ self.lm.default_inner_dim_mult = 4
893
+ self.lm.attention_bias_qkv = True
894
+ self.lm.attention_bias_o = True
895
+ self.lm.mlp_bias = True
896
+ self.lm.mlp_gate = False
897
+ self.orig_weights_transposed = True
898
+
899
+ # MiniCPM
900
+
901
+ if arch_string == "MiniCPMForCausalLM":
902
+ arch_recognized = True
903
+ self.lm.layer_keys += \
904
+ layer_keys_llama_norms + \
905
+ layer_keys_llama_attn + \
906
+ layer_keys_llama_mlp
907
+ self.lm.expect_keys += \
908
+ expect_keys_llama
909
+ self.lm.logit_scale_basedim = True
910
+
911
+ # InternLM2
912
+
913
+ if arch_string == "InternLM2ForCausalLM":
914
+ arch_recognized = True
915
+ self.keymap = internlm2_keymap
916
+ self.lm.layer_keys += \
917
+ layer_keys_internlm2_norms + \
918
+ layer_keys_internlm2_attn + \
919
+ layer_keys_internlm2_mlp
920
+ self.lm.expect_keys += \
921
+ expect_keys_llama
922
+ self.lm.keys.update({
923
+ "mlp_gate": ".feed_forward.w1",
924
+ "mlp_up": ".feed_forward.w3",
925
+ "mlp_down": ".feed_forward.w2",
926
+ "norm_1": ".attention_norm",
927
+ "norm_2": ".ffn_norm",
928
+ "fused_qkv": ".self_attn.wqkv",
929
+ })
930
+ self.lm.fused_qkv_altpack = True
931
+
932
+ # Index
933
+
934
+ if arch_string == "IndexForCausalLM":
935
+ arch_recognized = True
936
+ self.lm.layer_keys += \
937
+ layer_keys_llama_norms + \
938
+ layer_keys_llama_attn + \
939
+ layer_keys_llama_mlp
940
+ self.lm.expect_keys += \
941
+ expect_keys_llama
942
+
943
+ # Granite (v3)
944
+
945
+ if arch_string == "GraniteForCausalLM":
946
+ arch_recognized = True
947
+ self.lm.layer_keys += \
948
+ layer_keys_llama_norms + \
949
+ layer_keys_llama_attn + \
950
+ layer_keys_llama_mlp
951
+ self.lm.expect_keys += \
952
+ expect_keys_llama
953
+
954
+ # GLM4
955
+
956
+ if arch_string == "Glm4ForCausalLM":
957
+ arch_recognized = True
958
+ self.lm.layer_keys += \
959
+ layer_keys_glm4_norms + \
960
+ layer_keys_llama_attn + \
961
+ layer_keys_phi3_mlp
962
+ self.lm.expect_keys += \
963
+ expect_keys_llama
964
+ self.lm.supports_tp = True
965
+ self.lm.rope_style = RopeStyle.GPTJ
966
+ self.lm.keys.update({
967
+ "fused_mlp_12": "gate_up_proj",
968
+ "lm_head": "model.embed_tokens",
969
+ "norm_1": ".input_layernorm",
970
+ "norm_1_post": ".post_self_attn_layernorm",
971
+ "norm_2": ".post_attention_layernorm",
972
+ "norm_2_post": ".post_mlp_layernorm",
973
+ })
974
+ self.lm.attention_bias_qkv = read_config.get("attention_bias", False)
975
+
976
+ # Llama (default + fallback)
977
+
978
+ if arch_string != "LlamaForCausalLM" and not arch_recognized:
979
+ print(f" !! Warning, unknown architecture: {arch_string}")
980
+ print(f" !! Loading as LlamaForCausalLM")
981
+ self.arch_string = "LlamaForCausalLM"
982
+ if not arch_recognized:
983
+ self.lm.layer_keys += \
984
+ layer_keys_llama_norms + \
985
+ layer_keys_llama_attn + \
986
+ layer_keys_llama_mlp
987
+ self.lm.expect_keys += \
988
+ expect_keys_llama
989
+ self.lm.supports_tp = True
990
+
991
+ # Arch overrides
992
+
993
+ if read_config.get("attention_bias", False) and not (self.lm.attention_bias_qkv or self.lm.attention_bias_o):
994
+ self.lm.attention_bias_qkv = True
995
+ self.lm.attention_bias_o = True
996
+
997
+ if read_config.get("mlp_bias", False):
998
+ self.lm.mlp_bias = True
999
+
1000
+ if read_config.get("tie_word_embeddings", False):
1001
+ if ["lm_head"] in self.lm.expect_keys:
1002
+ self.lm.expect_keys.remove(["lm_head"])
1003
+ self.lm.keys.update({
1004
+ "lm_head": "model.embed_tokens",
1005
+ })
1006
+
1007
+ # Sanity checks
1008
+
1009
+ if self.lm.residual_stream_fp32:
1010
+ assert self.lm.keys["norm_1_post"] and self.lm.keys["norm_2_post"], \
1011
+ "FP32 residual stream only implemented for arch with post layernorms"
1012
+
1013
+ def make_fused_mlp(self):
1014
+
1015
+ for x in layer_keys_llama_mlp: self.lm.layer_keys.remove(x)
1016
+ self.lm.layer_keys += layer_keys_llama_mlp_swiglu
1017
+ self.lm.keys.update({
1018
+ "fused_mlp_12": layer_keys_llama_mlp_swiglu[0][0],
1019
+ "fused_mlp_3": layer_keys_llama_mlp_swiglu[1][0],
1020
+ })
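
For reference, a minimal sketch (not part of this commit) of how ExLlamaV2ArchParams could be exercised directly against a model's config.json. The file path and the printed values are illustrative assumptions; in practice the class is driven by the surrounding quantisation and model-loading code.

import json

with open("config.json") as f:                    # hypothetical path to a model's config.json
    cfg = json.load(f)

arch = ExLlamaV2ArchParams(cfg["architectures"][0], cfg)

print(arch.arch_string)                           # falls back to "LlamaForCausalLM" if unrecognized
print(arch.lm_prefix, arch.vt_prefix)             # tensor-name prefixes for the component models
print(arch.lm.keys["attn_q"])                     # ".self_attn.q_proj" for Llama-style attention
print(arch.lm.rope_style, arch.lm.is_moe, arch.lm.supports_tp)

The lm, vt and mmp Params objects carry the per-component tensor key names and architecture quirks that the rest of the pipeline presumably reads when matching and quantising weights.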
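
The keymap lists near the top of the file (dbrx_keymap, bigcode_keymap, gpt2_keymap, internlm2_keymap, google_keymap) pair a source-name pattern with a replacement, and entries beginning with "$" appear to anchor the match to the start of the tensor name. Below is a hypothetical helper, written only to illustrate that convention; the actual remapping logic lives elsewhere in the loader and may differ.

def remap_key(name: str, keymap: list[tuple[str, str]]) -> str:
    # Apply each (pattern, replacement) pair in order; "$" marks a start-anchored match.
    for pattern, replacement in keymap:
        if pattern.startswith("$"):
            if name.startswith(pattern[1:]):
                name = replacement + name[len(pattern) - 1:]
        else:
            name = name.replace(pattern, replacement)
    return name

# Illustrative GPT-2-style tensor names:
print(remap_key("ln_f.weight", gpt2_keymap))            # -> "model.norm.weight"
print(remap_key("h.0.attn.c_proj.bias", gpt2_keymap))   # -> "model.layers.0.self_attn.o_proj.bias"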