ehartford committed · Commit 9ff632c · verified · 1 parent: 328c03d

Create stage1_v2.py

Files changed (1)
  1. stage1_v2.py +706 -0
stage1_v2.py ADDED
#!/usr/bin/env python
"""
Stage 1 v2 Sharted Edition 💩: Fast Multi-GPU Interpolation from Qwen3-32B to Qwen3-72B
Optimized for 8x MI300X GPUs with parallel processing and sharted weight loading
FIXED: Correct o_proj dimensions
"""

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import os
import json
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from accelerate import init_empty_weights
import numpy as np
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import gc
from safetensors.torch import load_file, save_file
import shutil

# --- Configuration ---
# Source (32B) dimensions
SRC_HIDDEN_SIZE = 5120
SRC_INTERMEDIATE_SIZE = 25600
SRC_NUM_HEADS = 40
SRC_NUM_LAYERS = 64

# IMPORTANT: Qwen3-32B already has asymmetric attention!
# Q heads: 64 (for q_proj output and o_proj input)
# KV heads: 8
SRC_Q_HEADS = 64  # This gives us 8192 dims for Q
SRC_KV_HEADS = 8  # This gives us 1024 dims for K,V

# Target (72B) dimensions
TGT_HIDDEN_SIZE = 8192
TGT_INTERMEDIATE_SIZE = 29568
TGT_NUM_HEADS = 64

# Target also has asymmetric attention
TGT_Q_HEADS = 64
TGT_KV_HEADS = 8
HEAD_DIM = 128

# Deltas for interpolation
DELTA_HIDDEN = TGT_HIDDEN_SIZE - SRC_HIDDEN_SIZE
DELTA_INTERMEDIATE = TGT_INTERMEDIATE_SIZE - SRC_INTERMEDIATE_SIZE
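
# Worked values, for reference:
#   DELTA_HIDDEN       = 8192 - 5120   = 3072
#   DELTA_INTERMEDIATE = 29568 - 25600 = 3968
#   Q rows  = SRC_Q_HEADS  * HEAD_DIM = 64 * 128 = 8192
#   KV rows = SRC_KV_HEADS * HEAD_DIM =  8 * 128 = 1024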

OUTPUT_DIR = "./Qwen3-32B-to-72B-Stage1-v2-sharted"

# GPU configuration
NUM_GPUS = 8
BATCH_SIZE = 16  # Process multiple tensors at once

def get_layer_info(name):
    """Extract layer number and component type from parameter name."""
    if "model.layers." in name:
        parts = name.split(".")
        try:
            layer_idx = int(parts[2])
            return layer_idx, ".".join(parts[3:])
        except (ValueError, IndexError):
            return None, name
    return None, name
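
# Examples:
#   get_layer_info("model.layers.12.mlp.up_proj.weight") -> (12, "mlp.up_proj.weight")
#   get_layer_info("model.embed_tokens.weight")          -> (None, "model.embed_tokens.weight")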

def get_interpolation_weight(layer_idx, num_layers=SRC_NUM_LAYERS):
    """Get interpolation weight based on layer depth."""
    if layer_idx is None:
        return 0.5

    relative_pos = layer_idx / (num_layers - 1)

    if relative_pos < 0.25:
        return 0.3
    elif relative_pos < 0.75:
        return 0.5
    else:
        return 0.7
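
# With 64 layers this maps to: layers 0-15 -> 0.3, layers 16-47 -> 0.5,
# layers 48-63 -> 0.7 (deeper layers lean more on the tail block).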

@torch.jit.script
def add_structured_noise_jit(tensor: torch.Tensor, noise_scale: float = 0.01) -> torch.Tensor:
    """JIT-compiled structured noise addition."""
    noise = torch.randn_like(tensor) * noise_scale * tensor.std()

    if tensor.ndim == 2 and tensor.shape[0] > 100 and tensor.shape[1] > 100:
        h, w = noise.shape
        center_mask = torch.ones_like(noise)
        center_mask[h//4:3*h//4, w//4:3*w//4] *= 0.5
        noise *= center_mask

    return noise

@torch.jit.script
def preserve_norm_jit(original: torch.Tensor, interpolated: torch.Tensor) -> torch.Tensor:
    """JIT-compiled norm preservation."""
    original_norm = original.norm()
    interpolated_norm = interpolated.norm()

    if interpolated_norm > 0:
        scale_factor = original_norm / interpolated_norm
        return interpolated * scale_factor
    return interpolated
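
# preserve_norm_jit rescales the grown tensor so its overall (Frobenius) norm
# matches the source tensor's, which keeps activation magnitudes roughly
# comparable after the dimension change.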

def structure_aware_interpolation_gpu(block1, block2, weight=0.5, add_noise=True, device='cuda'):
    """GPU-accelerated interpolation."""
    # Move to GPU if not already
    if block1.device.type != 'cuda':
        block1 = block1.to(device)
    if block2.device.type != 'cuda':
        block2 = block2.to(device)

    # Basic interpolation
    interpolated = (1 - weight) * block1 + weight * block2

    # Add noise on GPU
    if add_noise:
        noise = add_structured_noise_jit(interpolated, 0.005)
        interpolated = interpolated + noise

    return interpolated
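
# Quick shape check (illustrative; also runs on CPU by passing device="cpu"):
#   a = torch.randn(3072, 8192); b = torch.randn(3072, 8192)
#   structure_aware_interpolation_gpu(a, b, weight=0.5, device="cpu").shape
#   -> torch.Size([3072, 8192]); output shape always matches the input blocks.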

def upscale_tensor_gpu(tensor: torch.Tensor, name: str, device='cuda') -> torch.Tensor:
    """GPU-accelerated tensor upscaling with FIXED o_proj dimensions."""
    # Move tensor to GPU
    tensor = tensor.to(device)

    layer_idx, component = get_layer_info(name)
    interp_weight = get_interpolation_weight(layer_idx)

    # Debug print for ANY o_proj to catch the first one
    if "o_proj.weight" in name:
        print(f"\n[DEBUG] Processing {name}: input shape = {tensor.shape}")

    # Handle 1D tensors
    if tensor.ndim == 1:
        if tensor.shape[0] == SRC_HIDDEN_SIZE:
            block1, block2 = tensor[:DELTA_HIDDEN], tensor[-DELTA_HIDDEN:]
            interpolated = structure_aware_interpolation_gpu(block1, block2, weight=interp_weight, device=device)
            result = torch.cat([tensor, interpolated], dim=0)
            if "layernorm" in name:
                result = preserve_norm_jit(tensor, result)
            return result
        elif "k_norm" in name or "q_norm" in name:
            return tensor

    # Handle 2D tensors
    elif tensor.ndim == 2:
        # Embeddings and LM head
        if "embed_tokens" in name or "lm_head" in name:
            if tensor.shape[1] == SRC_HIDDEN_SIZE:
                block1, block2 = tensor[:, :DELTA_HIDDEN], tensor[:, -DELTA_HIDDEN:]
                interpolated = structure_aware_interpolation_gpu(block1, block2, weight=0.3, device=device)
                return torch.cat([tensor, interpolated], dim=1)

        # Attention projections
        elif "self_attn" in name:
            if "q_proj.weight" in name:
                # Q projection: [8192, 5120] -> [8192, 8192]
                # Already has 64 heads in output, just need to expand input
                # Only scale input dimension (columns)
                block1, block2 = tensor[:, :DELTA_HIDDEN], tensor[:, -DELTA_HIDDEN:]
                interpolated = structure_aware_interpolation_gpu(block1, block2, weight=interp_weight, device=device)
                result = torch.cat([tensor, interpolated], dim=1)

                return preserve_norm_jit(tensor, result)

            elif "k_proj.weight" in name or "v_proj.weight" in name:
                # K,V projections: [1024, 5120] -> [1024, 8192]
                # Only scale input dimension, keep 8 KV heads
                block1, block2 = tensor[:, :DELTA_HIDDEN], tensor[:, -DELTA_HIDDEN:]
                interpolated = structure_aware_interpolation_gpu(block1, block2, weight=interp_weight, device=device)
                result = torch.cat([tensor, interpolated], dim=1)
                return preserve_norm_jit(tensor, result)

            elif "o_proj.weight" in name:
                # O projection: [5120, 8192] -> [8192, 8192]
                # Input already has 64 heads (8192), only expand output

                # Debug the input
                print(f"\n[DEBUG] Processing {name}: input shape = {tensor.shape}")
                print(f"[DEBUG] Expected input: [5120, 8192], Expected output: [8192, 8192]")

                # Only need to expand rows (output dim) from 5120 to 8192
                row_block1 = tensor[:DELTA_HIDDEN, :]   # [3072, 8192]
                row_block2 = tensor[-DELTA_HIDDEN:, :]  # [3072, 8192]
                row_interp = structure_aware_interpolation_gpu(row_block1, row_block2, weight=interp_weight, device=device)

                print(f"[DEBUG] row interpolation: block1={row_block1.shape}, block2={row_block2.shape}, interp={row_interp.shape}")

                result = torch.cat([tensor, row_interp], dim=0)  # [5120+3072, 8192] = [8192, 8192]

                print(f"[DEBUG] Final result: {result.shape}")

                assert result.shape == (TGT_HIDDEN_SIZE, TGT_HIDDEN_SIZE), f"o_proj shape error: got {result.shape}"

                return preserve_norm_jit(tensor, result)

        # MLP projections
        elif "mlp" in name:
            if "gate_proj.weight" in name or "up_proj.weight" in name:
                # [25600, 5120] -> [29568, 8192]
                mlp_weight = min(interp_weight + 0.1, 0.8)

                # Expand rows first
                row_block1, row_block2 = tensor[:DELTA_INTERMEDIATE, :], tensor[-DELTA_INTERMEDIATE:, :]
                upscaled_rows = torch.cat([tensor, structure_aware_interpolation_gpu(row_block1, row_block2, weight=mlp_weight, device=device)], dim=0)

                # Then expand columns
                col_block1, col_block2 = upscaled_rows[:, :DELTA_HIDDEN], upscaled_rows[:, -DELTA_HIDDEN:]
                result = torch.cat([upscaled_rows, structure_aware_interpolation_gpu(col_block1, col_block2, weight=mlp_weight, device=device)], dim=1)

                result = preserve_norm_jit(tensor, result)
                return result

            elif "down_proj.weight" in name:
                # [5120, 25600] -> [8192, 29568]
                mlp_weight = interp_weight

                # Expand rows first
                row_block1, row_block2 = tensor[:DELTA_HIDDEN, :], tensor[-DELTA_HIDDEN:, :]
                upscaled_rows = torch.cat([tensor, structure_aware_interpolation_gpu(row_block1, row_block2, weight=mlp_weight, device=device)], dim=0)

                # Then expand columns
                col_block1, col_block2 = upscaled_rows[:, :DELTA_INTERMEDIATE], upscaled_rows[:, -DELTA_INTERMEDIATE:]
                result = torch.cat([upscaled_rows, structure_aware_interpolation_gpu(col_block1, col_block2, weight=mlp_weight, device=device)], dim=1)

                return result

    return tensor
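
# Shape map applied above (per tensor, source -> target):
#   embed_tokens / lm_head  [151936, 5120] -> [151936, 8192]  (columns)
#   q_proj                  [8192, 5120]   -> [8192, 8192]    (columns; Q heads stay 64)
#   k_proj / v_proj         [1024, 5120]   -> [1024, 8192]    (columns; KV heads stay 8)
#   o_proj                  [5120, 8192]   -> [8192, 8192]    (rows)
#   gate_proj / up_proj     [25600, 5120]  -> [29568, 8192]   (rows, then columns)
#   down_proj               [5120, 25600]  -> [8192, 29568]   (rows, then columns)
#   layernorms              [5120]         -> [8192]          (norm-preserved)
#   q_norm / k_norm         [128]          -> unchanged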

def process_layer_batch(layer_tensors, device):
    """Process a batch of tensors from the same layer on a specific GPU."""
    processed = {}

    with torch.cuda.device(device):
        for name, tensor in layer_tensors:
            processed_tensor = upscale_tensor_gpu(tensor, name, device=device)
            # Move back to CPU to save GPU memory
            processed[name] = processed_tensor.cpu()

    return processed

def load_model_sharted(model_id):
    """Load model weights from sharted safetensors files. 💩"""
    print("\n💩 Loading sharted weights...")

    model_path = os.path.join(model_id, "model.safetensors.index.json")

    if os.path.exists(model_path):
        # Load from local path with sharted files
        with open(model_path, 'r') as f:
            index = json.load(f)

        weight_map = index['weight_map']
        unique_files = set(weight_map.values())

        all_weights = {}
        for file in tqdm(unique_files, desc="Loading sharts"):
            file_path = os.path.join(model_id, file)
            weights = load_file(file_path)
            all_weights.update(weights)

        return all_weights
    else:
        # Try loading from HuggingFace
        from huggingface_hub import snapshot_download

        print(f"Downloading model from HuggingFace: {model_id}")
        local_dir = snapshot_download(model_id)
        return load_model_sharted(local_dir)
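
# Usage sketch (assumes the repo is local or downloadable from the Hub; the
# key and source shape below follow Qwen3-32B's checkpoint layout):
#   weights = load_model_sharted("Qwen/Qwen3-32B")
#   weights["model.embed_tokens.weight"].shape  # torch.Size([151936, 5120])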

def save_model_sharted(state_dict, output_dir, max_shart_size="5GB"):
    """Save model in sharted safetensors format. 💩"""
    print("\n💩 Sharting model weights...")

    os.makedirs(output_dir, exist_ok=True)

    # Convert max_shart_size to bytes (fall back to 5 GB if no known unit is given)
    max_bytes = int(5e9)
    size_map = {'GB': 1e9, 'MB': 1e6}
    for unit, multiplier in size_map.items():
        if unit in max_shart_size:
            max_bytes = int(float(max_shart_size.replace(unit, '')) * multiplier)
            break

    # Group weights into sharts
    sharts = []
    current_shart = {}
    current_size = 0

    for name, tensor in state_dict.items():
        tensor_size = tensor.numel() * tensor.element_size()

        if current_size + tensor_size > max_bytes and current_shart:
            sharts.append(current_shart)
            current_shart = {}
            current_size = 0

        current_shart[name] = tensor
        current_size += tensor_size

    if current_shart:
        sharts.append(current_shart)

    # Save sharts
    weight_map = {}
    for i, shart in enumerate(tqdm(sharts, desc="Saving sharts")):
        shart_name = f"model-{i+1:05d}-of-{len(sharts):05d}.safetensors"
        save_file(shart, os.path.join(output_dir, shart_name))

        for name in shart:
            weight_map[name] = shart_name

    # Save index
    index = {
        "metadata": {"total_size": sum(t.numel() * t.element_size() for t in state_dict.values())},
        "weight_map": weight_map
    }

    with open(os.path.join(output_dir, "model.safetensors.index.json"), 'w') as f:
        json.dump(index, f, indent=2)

    print(f"💩 Successfully sharted into {len(sharts)} files!")
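
# The grouping above is a greedy, first-fit pack in iteration order: a new
# shard starts whenever the next tensor would push the current one past the
# size budget, similar in spirit to save_pretrained(max_shard_size=...).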

def verify_architecture(model_path):
    """Verify the model architecture matches expected dimensions."""
    print("\n" + "="*60)
    print("ARCHITECTURE VERIFICATION")
    print("="*60)

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
        trust_remote_code=True
    )

    expected = {
        "lm_head.weight": (151936, 8192),
        "model.embed_tokens.weight": (151936, 8192),
        "model.layers.0.input_layernorm.weight": (8192,),
        "model.layers.0.mlp.down_proj.weight": (8192, 29568),
        "model.layers.0.mlp.gate_proj.weight": (29568, 8192),
        "model.layers.0.mlp.up_proj.weight": (29568, 8192),
        "model.layers.0.post_attention_layernorm.weight": (8192,),
        "model.layers.0.self_attn.k_norm.weight": (128,),
        "model.layers.0.self_attn.k_proj.weight": (1024, 8192),
        "model.layers.0.self_attn.o_proj.weight": (8192, 8192),
        "model.layers.0.self_attn.q_norm.weight": (128,),
        "model.layers.0.self_attn.q_proj.weight": (8192, 8192),
        "model.layers.0.self_attn.v_proj.weight": (1024, 8192),
        "model.norm.weight": (8192,),
    }

    all_correct = True
    param_dict = dict(model.named_parameters())

    for name, expected_shape in expected.items():
        if name in param_dict:
            actual_shape = tuple(param_dict[name].shape)
            if actual_shape == expected_shape:
                print(f"✓ {name}: {actual_shape}")
            else:
                print(f"✗ {name}: {actual_shape} (expected {expected_shape})")
                all_correct = False
        else:
            print(f"✗ {name}: NOT FOUND")
            all_correct = False

    num_layers = model.config.num_hidden_layers
    print(f"\nNumber of layers: {num_layers} (Stage 1 should have 64)")

    if all_correct and num_layers == 64:
        print("\n✅ Architecture verification PASSED!")
    else:
        print("\n❌ Architecture verification FAILED!")

    del model
    return all_correct

def run_diagnostics(model_path):
    """Run comprehensive diagnostics on the upscaled model."""
    print("\n" + "="*60)
    print("COMPREHENSIVE DIAGNOSTICS")
    print("="*60)

    # Load model and tokenizer
    print("\nLoading model for diagnostics...")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Test generation quality
    print("\n🧪 Generation Quality Tests:")
    test_cases = [
        ("The capital of France is", ["Paris"]),
        ("2 + 2 =", ["4", "four"]),
        ("The quick brown fox", ["jumps", "jumped", "lazy", "dog"]),
        ("Hello, my name is", None),
        ("Water boils at", ["100", "212", "degrees"]),
        ("The Earth orbits the", ["Sun", "solar"]),
        ("Machine learning is a type of", ["artificial intelligence", "AI"]),
        ("Python is a", ["programming", "language", "snake"]),
        ("The largest planet is", ["Jupiter"]),
        ("DNA stands for", ["deoxyribonucleic", "acid"]),
    ]

    device = model.device
    coherent_count = 0
    total_tests = len(test_cases)

    for prompt, expected in test_cases:
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=20,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                pad_token_id=tokenizer.pad_token_id,
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_only = generated_text[len(prompt):].strip()

        print(f"\n  Prompt: '{prompt}'")
        print(f"  Generated: '{generated_only}'")

        # Check coherence
        is_coherent = True

        # Check for repetition
        words = generated_only.split()
        if len(words) > 3:
            if len(set(words)) < len(words) / 2:
                print("  ⚠️ High repetition detected")
                is_coherent = False

        # Check for expected content
        if expected and len(generated_only) > 0:
            found = any(kw.lower() in generated_only.lower() for kw in expected)
            if found:
                print("  ✓ Contains expected content")
            else:
                print("  ⚠️ Missing expected keywords")
                is_coherent = False

        if is_coherent and len(generated_only.split()) >= 2:
            coherent_count += 1

    coherence_rate = (coherent_count / total_tests) * 100
    print(f"\n📊 Overall coherence rate: {coherence_rate:.1f}%")

    # Quick perplexity test
    print("\n📈 Perplexity Test:")
    test_text = "The quick brown fox jumps over the lazy dog."
    inputs = tokenizer(test_text, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs, labels=inputs["input_ids"])
        perplexity = torch.exp(outputs.loss).item()
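
    # exp(mean token cross-entropy) over one short sentence: a smoke test for
    # gross breakage, not a benchmark.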

    print(f"  Perplexity: {perplexity:.2f}")

    if perplexity > 100:
        print("  ⚠️ Very high perplexity")
    elif perplexity > 50:
        print("  ⚠️ Moderately high perplexity")
    else:
        print("  ✓ Reasonable perplexity")

    # Weight statistics check
    print("\n🔍 Weight Statistics (checking for anomalies):")
    anomalies = 0

    for name, param in model.named_parameters():
        if torch.isnan(param).any():
            print(f"  ⚠️ {name}: Contains NaN!")
            anomalies += 1
        elif torch.isinf(param).any():
            print(f"  ⚠️ {name}: Contains Inf!")
            anomalies += 1
        elif param.std() < 1e-8:
            print(f"  ⚠️ {name}: Zero variance!")
            anomalies += 1

    if anomalies == 0:
        print("  ✓ No anomalies detected in weights")

    # Final summary
    success = coherence_rate >= 70 and perplexity < 100 and anomalies == 0

    print("\n" + "="*60)
    print("DIAGNOSTIC SUMMARY")
    print("="*60)

    if success:
        print("✅ Model passed all basic diagnostics!")
        print("  - Good coherence rate")
        print("  - Reasonable perplexity")
        print("  - No weight anomalies")
    else:
        print("⚠️ Some issues detected:")
        if coherence_rate < 70:
            print(f"  - Low coherence rate: {coherence_rate:.1f}%")
        if perplexity >= 100:
            print(f"  - High perplexity: {perplexity:.2f}")
        if anomalies > 0:
            print(f"  - Weight anomalies: {anomalies}")

    return success

def main():
    print("="*60)
    print("Stage 1 v2 SHARTED 💩: Multi-GPU Accelerated Interpolation")
    print("Qwen3-32B → 72B Dimensions")
    print(f"Using {NUM_GPUS} GPUs for parallel processing")
    print("FIXED: Correct o_proj dimensions")
    print("="*60)

    source_model_id = "Qwen/Qwen3-32B"

    # Set up multi-GPU environment
    if torch.cuda.is_available():
        torch.cuda.set_device(0)
        print(f"\n🚀 CUDA available: {torch.cuda.device_count()} devices")
        for i in range(min(NUM_GPUS, torch.cuda.device_count())):
            print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")

    # Load tokenizer
    print(f"\n📚 Loading tokenizer from: {source_model_id}")
    tokenizer = AutoTokenizer.from_pretrained(source_model_id, trust_remote_code=True)

    # Load weights directly (faster than loading full model)
    print(f"\n⚡ Loading model weights using fast sharted loading...")
    source_weights = load_model_sharted(source_model_id)

    print(f"\n📊 Loaded {len(source_weights)} tensors from sharts")

    # Group tensors by layer for efficient GPU processing
    layer_groups = {}
    other_tensors = []

    for name, tensor in source_weights.items():
        layer_idx, _ = get_layer_info(name)
        if layer_idx is not None:
            if layer_idx not in layer_groups:
                layer_groups[layer_idx] = []
            layer_groups[layer_idx].append((name, tensor))
        else:
            other_tensors.append((name, tensor))

    print(f"\n🔧 Processing tensors across {NUM_GPUS} GPUs...")
    print("  - Parallel layer processing")
    print("  - JIT-compiled operations")
    print("  - Efficient memory management")
    print("  - Sharted weight I/O 💩")

    new_state_dict = {}

    # Process layer groups across the GPUs, round-robin: each batch of
    # NUM_GPUS layers is dispatched one layer per device in turn
    with tqdm(total=len(source_weights), desc="Upscaling tensors") as pbar:
        layer_indices = sorted(layer_groups.keys())

        for i in range(0, len(layer_indices), NUM_GPUS):
            # Assign each layer in this batch to a GPU
            for j, layer_idx in enumerate(layer_indices[i:i+NUM_GPUS]):
                gpu_id = j % NUM_GPUS
                device = f'cuda:{gpu_id}'

                # Process this layer on the assigned GPU
                layer_tensors = layer_groups[layer_idx]
                processed = process_layer_batch(layer_tensors, device)
                new_state_dict.update(processed)
                pbar.update(len(layer_tensors))

                # Clear GPU cache periodically
                if j % 4 == 0:
                    torch.cuda.empty_cache()

        # Process non-layer tensors
        for name, tensor in other_tensors:
            device = 'cuda:0'
            new_tensor = upscale_tensor_gpu(tensor, name, device=device).cpu()
            new_state_dict[name] = new_tensor
            pbar.update(1)

    # Free source weights
    del source_weights
    gc.collect()
    torch.cuda.empty_cache()

    # Create config
    print("\n📝 Creating target model configuration...")
    config = AutoConfig.from_pretrained(source_model_id, trust_remote_code=True)
    config.hidden_size = TGT_HIDDEN_SIZE
    config.intermediate_size = TGT_INTERMEDIATE_SIZE
    config.num_attention_heads = TGT_NUM_HEADS
    config.torch_dtype = torch.bfloat16
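
    # Note: num_key_value_heads (8) and head_dim (128) are inherited unchanged
    # from the source config, which keeps the GQA layout consistent with the
    # upscaled k/v_proj shapes above.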

    # Quick verification
    print("\n🔍 Quick verification of tensor dimensions BEFORE saving:")

    # Check critical dimensions
    critical_checks = [
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.0.self_attn.k_proj.weight",
        "model.layers.0.self_attn.v_proj.weight",
        "model.layers.0.self_attn.o_proj.weight",
        "model.layers.0.mlp.gate_proj.weight"
    ]

    for check_name in critical_checks:
        for name, tensor in new_state_dict.items():
            if check_name in name:
                print(f"  {name}: {tensor.shape}")
                break

    # Specifically verify o_proj dimensions
    print("\n🎯 Verifying ALL o_proj dimensions:")
    o_proj_issue = False
    for name, tensor in new_state_dict.items():
        if "o_proj.weight" in name:
            if tensor.shape != (TGT_HIDDEN_SIZE, TGT_HIDDEN_SIZE):
                print(f"  ❌ {name}: {tensor.shape} - INCORRECT!")
                o_proj_issue = True
            else:
                if "layers.0." in name or "layers.63." in name:  # Show first and last layer
                    print(f"  ✓ {name}: {tensor.shape}")

    if o_proj_issue:
        print("\n❌ ERROR: o_proj dimensions are incorrect! Not saving model.")
        return False

    # Save model and config
    print(f"\n💾 Saving model to: {OUTPUT_DIR}")
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Save config
    config.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)

    # Save weights in sharted format
    save_model_sharted(new_state_dict, OUTPUT_DIR)

    # Copy model configuration files
    for file in ['generation_config.json', 'tokenizer_config.json', 'special_tokens_map.json']:
        src = os.path.join(source_model_id, file)
        dst = os.path.join(OUTPUT_DIR, file)
        if os.path.exists(src):
            shutil.copy(src, dst)

    # Save metadata
    metadata = {
        "stage": "1-v2-sharted",
        "source_model": source_model_id,
        "method": "gpu_accelerated_structure_aware_interpolation_sharted",
        "num_gpus_used": NUM_GPUS,
        "fixes": [
            "Corrected o_proj dimensions to 8192x8192",
            "Proper handling of GQA architecture"
        ],
        "optimizations": [
            "Multi-GPU parallel processing",
            "JIT-compiled operations",
            "Sharted weight loading/saving 💩",
            "Efficient memory management"
        ],
        "sharting_info": {
            "format": "safetensors",
            "max_shart_size": "5GB",
            "poop_emoji": "💩"
        }
    }

    with open(os.path.join(OUTPUT_DIR, "stage1_v2_metadata.json"), "w") as f:
        json.dump(metadata, f, indent=2)

    print("\n✅ Stage 1 v2 SHARTED interpolation complete! 💩")
    print(f"📁 Model saved to: {OUTPUT_DIR}")

    # Run verifications
    arch_ok = verify_architecture(OUTPUT_DIR)
    diag_ok = run_diagnostics(OUTPUT_DIR)

    if arch_ok and diag_ok:
        print("\n🎉 SUCCESS! Enhanced sharted interpolation completed successfully. 💩")
        print(f"📁 Model saved to: {OUTPUT_DIR}")
        print("\n🚀 Ready for Stage 2: Layer duplication (64→80 layers)")
    else:
        print("\n⚠️ Some issues detected. Review the diagnostics above.")

    return arch_ok and diag_ok

if __name__ == "__main__":
    success = main()
    raise SystemExit(0 if success else 1)