Create stage1_v2.py
stage1_v2.py  ADDED  (+706 -0)
@@ -0,0 +1,706 @@
#!/usr/bin/env python
"""
Stage 1 v2 Sharted Edition 💩: Fast Multi-GPU Interpolation from Qwen3-32B to Qwen3-72B
Optimized for 8x MI300X GPUs with parallel processing and sharted weight loading
FIXED: Correct o_proj dimensions
"""

import torch
import os
import json
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import gc
from safetensors.torch import load_file, save_file
import shutil

# --- Configuration ---
# Source (32B) dimensions
SRC_HIDDEN_SIZE = 5120
SRC_INTERMEDIATE_SIZE = 25600
SRC_NUM_HEADS = 40  # hidden_size / head_dim; superseded by SRC_Q_HEADS / SRC_KV_HEADS below
SRC_NUM_LAYERS = 64

# IMPORTANT: Qwen3-32B already has asymmetric (grouped-query) attention!
# Q heads: 64 (for q_proj output and o_proj input)
# KV heads: 8
SRC_Q_HEADS = 64   # This gives us 8192 dims for Q
SRC_KV_HEADS = 8   # This gives us 1024 dims for K,V

# Target (72B) dimensions
TGT_HIDDEN_SIZE = 8192
TGT_INTERMEDIATE_SIZE = 29568
TGT_NUM_HEADS = 64

# Target also has asymmetric attention
TGT_Q_HEADS = 64
TGT_KV_HEADS = 8
HEAD_DIM = 128

# Deltas for interpolation
DELTA_HIDDEN = TGT_HIDDEN_SIZE - SRC_HIDDEN_SIZE
DELTA_INTERMEDIATE = TGT_INTERMEDIATE_SIZE - SRC_INTERMEDIATE_SIZE

OUTPUT_DIR = "./Qwen3-32B-to-72B-Stage1-v2-sharted"

# GPU configuration
NUM_GPUS = 8
BATCH_SIZE = 16  # Reserved for batched tensor processing (unused in this version)

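# Sanity arithmetic for the dimensions above (worked out by hand, not at runtime):
#   DELTA_HIDDEN       = 8192  - 5120  = 3072
#   DELTA_INTERMEDIATE = 29568 - 25600 = 3968
#   Q rows  = 64 heads * 128 head_dim = 8192
#   KV rows =  8 heads * 128 head_dim = 1024
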
def get_layer_info(name):
    """Extract layer number and component type from parameter name."""
    if "model.layers." in name:
        parts = name.split(".")
        try:
            layer_idx = int(parts[2])
            return layer_idx, ".".join(parts[3:])
        except (IndexError, ValueError):
            return None, name
    return None, name

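# Example (illustrative):
#   get_layer_info("model.layers.12.self_attn.q_proj.weight") -> (12, "self_attn.q_proj.weight")
#   get_layer_info("model.norm.weight")                       -> (None, "model.norm.weight")
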
def get_interpolation_weight(layer_idx, num_layers=SRC_NUM_LAYERS):
    """Get interpolation weight based on layer depth."""
    if layer_idx is None:
        return 0.5

    relative_pos = layer_idx / (num_layers - 1)

    if relative_pos < 0.25:
        return 0.3
    elif relative_pos < 0.75:
        return 0.5
    else:
        return 0.7

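# Worked examples with 64 layers: layer 0 -> 0/63 = 0.00 -> 0.3 (early layers
# interpolate least), layer 32 -> 32/63 ~ 0.51 -> 0.5, layer 63 -> 1.00 -> 0.7.
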
@torch.jit.script
def add_structured_noise_jit(tensor: torch.Tensor, noise_scale: float = 0.01) -> torch.Tensor:
    """JIT-compiled structured noise addition."""
    noise = torch.randn_like(tensor) * noise_scale * tensor.std()

    # Damp the noise in the center of large matrices so the edges vary more
    if tensor.ndim == 2 and tensor.shape[0] > 100 and tensor.shape[1] > 100:
        h = noise.shape[0]
        w = noise.shape[1]
        center_mask = torch.ones_like(noise)
        center_mask[h // 4:3 * h // 4, w // 4:3 * w // 4] *= 0.5
        noise *= center_mask

    return noise

@torch.jit.script
def preserve_norm_jit(original: torch.Tensor, interpolated: torch.Tensor) -> torch.Tensor:
    """JIT-compiled norm preservation."""
    original_norm = original.norm()
    interpolated_norm = interpolated.norm()

    if interpolated_norm > 0:
        scale_factor = original_norm / interpolated_norm
        return interpolated * scale_factor
    return interpolated

def structure_aware_interpolation_gpu(block1, block2, weight=0.5, add_noise=True, device='cuda'):
    """GPU-accelerated interpolation."""
    # Move to GPU if not already
    if block1.device.type != 'cuda':
        block1 = block1.to(device)
    if block2.device.type != 'cuda':
        block2 = block2.to(device)

    # Basic interpolation
    interpolated = (1 - weight) * block1 + weight * block2

    # Add noise on GPU
    if add_noise:
        noise = add_structured_noise_jit(interpolated, 0.005)
        interpolated = interpolated + noise

    return interpolated

def upscale_tensor_gpu(tensor: torch.Tensor, name: str, device='cuda') -> torch.Tensor:
    """GPU-accelerated tensor upscaling with FIXED o_proj dimensions."""
    # Move tensor to GPU
    tensor = tensor.to(device)

    layer_idx, component = get_layer_info(name)
    interp_weight = get_interpolation_weight(layer_idx)

    # Debug print for ANY o_proj to catch the first one
    if "o_proj.weight" in name:
        print(f"\n[DEBUG] Processing {name}: input shape = {tensor.shape}")

    # Handle 1D tensors (norms and biases)
    if tensor.ndim == 1:
        if tensor.shape[0] == SRC_HIDDEN_SIZE:
            block1, block2 = tensor[:DELTA_HIDDEN], tensor[-DELTA_HIDDEN:]
            interpolated = structure_aware_interpolation_gpu(block1, block2, weight=interp_weight, device=device)
            result = torch.cat([tensor, interpolated], dim=0)
            if "layernorm" in name:
                result = preserve_norm_jit(tensor, result)
            return result
        elif "k_norm" in name or "q_norm" in name:
            # Per-head norms stay at head_dim (128); no change needed
            return tensor

    # Handle 2D tensors
    elif tensor.ndim == 2:
        # Embeddings and LM head
        if "embed_tokens" in name or "lm_head" in name:
            if tensor.shape[1] == SRC_HIDDEN_SIZE:
                block1, block2 = tensor[:, :DELTA_HIDDEN], tensor[:, -DELTA_HIDDEN:]
                interpolated = structure_aware_interpolation_gpu(block1, block2, weight=0.3, device=device)
                return torch.cat([tensor, interpolated], dim=1)

        # Attention projections
        elif "self_attn" in name:
            if "q_proj.weight" in name:
                # Q projection: [8192, 5120] -> [8192, 8192]
                # Output already has 64 heads; only expand the input dimension (columns)
                block1, block2 = tensor[:, :DELTA_HIDDEN], tensor[:, -DELTA_HIDDEN:]
                interpolated = structure_aware_interpolation_gpu(block1, block2, weight=interp_weight, device=device)
                result = torch.cat([tensor, interpolated], dim=1)
                return preserve_norm_jit(tensor, result)

            elif "k_proj.weight" in name or "v_proj.weight" in name:
                # K,V projections: [1024, 5120] -> [1024, 8192]
                # Only scale input dimension, keep 8 KV heads
                block1, block2 = tensor[:, :DELTA_HIDDEN], tensor[:, -DELTA_HIDDEN:]
                interpolated = structure_aware_interpolation_gpu(block1, block2, weight=interp_weight, device=device)
                result = torch.cat([tensor, interpolated], dim=1)
                return preserve_norm_jit(tensor, result)

            elif "o_proj.weight" in name:
                # O projection: [5120, 8192] -> [8192, 8192]
                # Input already has 64 heads (8192); only expand the output (rows)
                print(f"[DEBUG] Expected input: [5120, 8192], Expected output: [8192, 8192]")

                row_block1 = tensor[:DELTA_HIDDEN, :]   # [3072, 8192]
                row_block2 = tensor[-DELTA_HIDDEN:, :]  # [3072, 8192]
                row_interp = structure_aware_interpolation_gpu(row_block1, row_block2, weight=interp_weight, device=device)

                print(f"[DEBUG] row interpolation: block1={row_block1.shape}, block2={row_block2.shape}, interp={row_interp.shape}")

                result = torch.cat([tensor, row_interp], dim=0)  # [5120+3072, 8192] = [8192, 8192]

                print(f"[DEBUG] Final result: {result.shape}")

                assert result.shape == (TGT_HIDDEN_SIZE, TGT_HIDDEN_SIZE), f"o_proj shape error: got {result.shape}"

                return preserve_norm_jit(tensor, result)

        # MLP projections
        elif "mlp" in name:
            if "gate_proj.weight" in name or "up_proj.weight" in name:
                # [25600, 5120] -> [29568, 8192]
                mlp_weight = min(interp_weight + 0.1, 0.8)

                # Expand rows first
                row_block1, row_block2 = tensor[:DELTA_INTERMEDIATE, :], tensor[-DELTA_INTERMEDIATE:, :]
                upscaled_rows = torch.cat([tensor, structure_aware_interpolation_gpu(row_block1, row_block2, weight=mlp_weight, device=device)], dim=0)

                # Then expand columns
                col_block1, col_block2 = upscaled_rows[:, :DELTA_HIDDEN], upscaled_rows[:, -DELTA_HIDDEN:]
                result = torch.cat([upscaled_rows, structure_aware_interpolation_gpu(col_block1, col_block2, weight=mlp_weight, device=device)], dim=1)

                return preserve_norm_jit(tensor, result)

            elif "down_proj.weight" in name:
                # [5120, 25600] -> [8192, 29568]
                mlp_weight = interp_weight

                # Expand rows first
                row_block1, row_block2 = tensor[:DELTA_HIDDEN, :], tensor[-DELTA_HIDDEN:, :]
                upscaled_rows = torch.cat([tensor, structure_aware_interpolation_gpu(row_block1, row_block2, weight=mlp_weight, device=device)], dim=0)

                # Then expand columns
                col_block1, col_block2 = upscaled_rows[:, :DELTA_INTERMEDIATE], upscaled_rows[:, -DELTA_INTERMEDIATE:]
                result = torch.cat([upscaled_rows, structure_aware_interpolation_gpu(col_block1, col_block2, weight=mlp_weight, device=device)], dim=1)

                return result

    return tensor

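# Minimal shape self-test (an illustrative sketch added for this write-up; the
# helper `_shape_self_test` is not part of the pipeline and is never called by
# it). It runs a random q_proj-shaped tensor through the upscaler and checks
# the target shape. Guarded, since the upscaler assumes a CUDA device.
def _shape_self_test():
    if not torch.cuda.is_available():
        return
    fake_q = torch.randn(SRC_Q_HEADS * HEAD_DIM, SRC_HIDDEN_SIZE)  # [8192, 5120]
    out = upscale_tensor_gpu(fake_q, "model.layers.0.self_attn.q_proj.weight", device='cuda:0')
    assert out.shape == (TGT_Q_HEADS * HEAD_DIM, TGT_HIDDEN_SIZE), out.shape
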
def process_layer_batch(layer_tensors, device):
    """Process a batch of tensors from the same layer on a specific GPU."""
    processed = {}

    with torch.cuda.device(device):
        for name, tensor in layer_tensors:
            processed_tensor = upscale_tensor_gpu(tensor, name, device=device)
            # Move back to CPU to save GPU memory
            processed[name] = processed_tensor.cpu()

    return processed

def load_model_sharted(model_id):
    """Load model weights from sharted safetensors files. 💩"""
    print("\n💩 Loading sharted weights...")

    index_path = os.path.join(model_id, "model.safetensors.index.json")

    if os.path.exists(index_path):
        # Load from local path with sharted files
        with open(index_path, 'r') as f:
            index = json.load(f)

        weight_map = index['weight_map']
        unique_files = set(weight_map.values())

        all_weights = {}
        for file in tqdm(unique_files, desc="Loading sharts"):
            file_path = os.path.join(model_id, file)
            weights = load_file(file_path)
            all_weights.update(weights)

        return all_weights
    else:
        # Fall back to downloading from the HuggingFace Hub, then retry locally
        from huggingface_hub import snapshot_download

        print(f"Downloading model from HuggingFace: {model_id}")
        local_dir = snapshot_download(model_id)
        return load_model_sharted(local_dir)

def save_model_sharted(state_dict, output_dir, max_shart_size="5GB"):
    """Save model in sharted safetensors format. 💩"""
    print("\n💩 Sharting model weights...")

    os.makedirs(output_dir, exist_ok=True)

    # Convert max_shart_size to bytes (default to 5GB if no unit matches)
    max_bytes = int(5e9)
    size_map = {'GB': 1e9, 'MB': 1e6}
    for unit, multiplier in size_map.items():
        if unit in max_shart_size:
            max_bytes = int(float(max_shart_size.replace(unit, '')) * multiplier)
            break

    # Group weights into sharts
    sharts = []
    current_shart = {}
    current_size = 0

    for name, tensor in state_dict.items():
        tensor_size = tensor.numel() * tensor.element_size()

        if current_size + tensor_size > max_bytes and current_shart:
            sharts.append(current_shart)
            current_shart = {}
            current_size = 0

        current_shart[name] = tensor
        current_size += tensor_size

    if current_shart:
        sharts.append(current_shart)

    # Save sharts
    weight_map = {}
    for i, shart in enumerate(tqdm(sharts, desc="Saving sharts")):
        shart_name = f"model-{i+1:05d}-of-{len(sharts):05d}.safetensors"
        save_file(shart, os.path.join(output_dir, shart_name))

        for name in shart:
            weight_map[name] = shart_name

    # Save index
    index = {
        "metadata": {"total_size": sum(t.numel() * t.element_size() for t in state_dict.values())},
        "weight_map": weight_map
    }

    with open(os.path.join(output_dir, "model.safetensors.index.json"), 'w') as f:
        json.dump(index, f, indent=2)

    print(f"💩 Successfully sharted into {len(sharts)} files!")

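# The index written above follows the standard safetensors index layout, e.g.
# (shard names and sizes illustrative):
#   {"metadata": {"total_size": <bytes>},
#    "weight_map": {"model.embed_tokens.weight": "model-00001-of-00029.safetensors", ...}}
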
def verify_architecture(model_path):
    """Verify the model architecture matches expected dimensions."""
    print("\n" + "="*60)
    print("ARCHITECTURE VERIFICATION")
    print("="*60)

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
        trust_remote_code=True
    )

    expected = {
        "lm_head.weight": (151936, 8192),
        "model.embed_tokens.weight": (151936, 8192),
        "model.layers.0.input_layernorm.weight": (8192,),
        "model.layers.0.mlp.down_proj.weight": (8192, 29568),
        "model.layers.0.mlp.gate_proj.weight": (29568, 8192),
        "model.layers.0.mlp.up_proj.weight": (29568, 8192),
        "model.layers.0.post_attention_layernorm.weight": (8192,),
        "model.layers.0.self_attn.k_norm.weight": (128,),
        "model.layers.0.self_attn.k_proj.weight": (1024, 8192),
        "model.layers.0.self_attn.o_proj.weight": (8192, 8192),
        "model.layers.0.self_attn.q_norm.weight": (128,),
        "model.layers.0.self_attn.q_proj.weight": (8192, 8192),
        "model.layers.0.self_attn.v_proj.weight": (1024, 8192),
        "model.norm.weight": (8192,),
    }

    all_correct = True
    param_dict = dict(model.named_parameters())  # build once, not per check

    for name, expected_shape in expected.items():
        if name in param_dict:
            actual_shape = tuple(param_dict[name].shape)
            if actual_shape == expected_shape:
                print(f"✅ {name}: {actual_shape}")
            else:
                print(f"❌ {name}: {actual_shape} (expected {expected_shape})")
                all_correct = False
        else:
            print(f"❌ {name}: NOT FOUND")
            all_correct = False

    num_layers = model.config.num_hidden_layers
    print(f"\nNumber of layers: {num_layers} (Stage 1 should have 64)")

    if all_correct and num_layers == 64:
        print("\n✅ Architecture verification PASSED!")
    else:
        print("\n❌ Architecture verification FAILED!")

    del model
    return all_correct

def run_diagnostics(model_path):
    """Run comprehensive diagnostics on the upscaled model."""
    print("\n" + "="*60)
    print("COMPREHENSIVE DIAGNOSTICS")
    print("="*60)

    # Load model and tokenizer
    print("\nLoading model for diagnostics...")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Test generation quality
    print("\n🧪 Generation Quality Tests:")
    test_cases = [
        ("The capital of France is", ["Paris"]),
        ("2 + 2 =", ["4", "four"]),
        ("The quick brown fox", ["jumps", "jumped", "lazy", "dog"]),
        ("Hello, my name is", None),
        ("Water boils at", ["100", "212", "degrees"]),
        ("The Earth orbits the", ["Sun", "solar"]),
        ("Machine learning is a type of", ["artificial intelligence", "AI"]),
        ("Python is a", ["programming", "language", "snake"]),
        ("The largest planet is", ["Jupiter"]),
        ("DNA stands for", ["deoxyribonucleic", "acid"]),
    ]

    device = model.device
    coherent_count = 0
    total_tests = len(test_cases)

    for prompt, expected in test_cases:
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=20,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_only = generated_text[len(prompt):].strip()

        print(f"\n  Prompt: '{prompt}'")
        print(f"  Generated: '{generated_only}'")

        # Check coherence
        is_coherent = True

        # Check for repetition
        words = generated_only.split()
        if len(words) > 3:
            if len(set(words)) < len(words) / 2:
                print("  ⚠️ High repetition detected")
                is_coherent = False

        # Check for expected content
        if expected and len(generated_only) > 0:
            found = any(kw.lower() in generated_only.lower() for kw in expected)
            if found:
                print("  ✅ Contains expected content")
            else:
                print("  ⚠️ Missing expected keywords")
                is_coherent = False

        if is_coherent and len(generated_only.split()) >= 2:
            coherent_count += 1

    coherence_rate = (coherent_count / total_tests) * 100
    print(f"\n📊 Overall coherence rate: {coherence_rate:.1f}%")

    # Quick perplexity test
    print("\n📈 Perplexity Test:")
    test_text = "The quick brown fox jumps over the lazy dog."
    inputs = tokenizer(test_text, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs, labels=inputs["input_ids"])
        perplexity = torch.exp(outputs.loss).item()

    print(f"  Perplexity: {perplexity:.2f}")

    if perplexity > 100:
        print("  ⚠️ Very high perplexity")
    elif perplexity > 50:
        print("  ⚠️ Moderately high perplexity")
    else:
        print("  ✅ Reasonable perplexity")

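    # Note: this perplexity is exp(mean token cross-entropy) over one short
    # sentence, so treat it as a smoke test rather than a benchmark number.
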
    # Weight statistics check
    print("\n📊 Weight Statistics (checking for anomalies):")
    anomalies = 0

    for name, param in model.named_parameters():
        if torch.isnan(param).any():
            print(f"  ⚠️ {name}: Contains NaN!")
            anomalies += 1
        elif torch.isinf(param).any():
            print(f"  ⚠️ {name}: Contains Inf!")
            anomalies += 1
        elif param.std() < 1e-8:
            print(f"  ⚠️ {name}: Zero variance!")
            anomalies += 1

    if anomalies == 0:
        print("  ✅ No anomalies detected in weights")

    # Final summary
    success = coherence_rate >= 70 and perplexity < 100 and anomalies == 0

    print("\n" + "="*60)
    print("DIAGNOSTIC SUMMARY")
    print("="*60)

    if success:
        print("✅ Model passed all basic diagnostics!")
        print("  - Good coherence rate")
        print("  - Reasonable perplexity")
        print("  - No weight anomalies")
    else:
        print("⚠️ Some issues detected:")
        if coherence_rate < 70:
            print(f"  - Low coherence rate: {coherence_rate:.1f}%")
        if perplexity >= 100:
            print(f"  - High perplexity: {perplexity:.2f}")
        if anomalies > 0:
            print(f"  - Weight anomalies: {anomalies}")

    return success

def main():
    print("="*60)
    print("Stage 1 v2 SHARTED 💩: Multi-GPU Accelerated Interpolation")
    print("Qwen3-32B → 72B Dimensions")
    print(f"Using {NUM_GPUS} GPUs for parallel processing")
    print("FIXED: Correct o_proj dimensions")
    print("="*60)

    source_model_id = "Qwen/Qwen3-32B"

    # Set up multi-GPU environment
    if torch.cuda.is_available():
        torch.cuda.set_device(0)
        print(f"\n🚀 CUDA available: {torch.cuda.device_count()} devices")
        for i in range(min(NUM_GPUS, torch.cuda.device_count())):
            print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")

    # Load tokenizer
    print(f"\n📝 Loading tokenizer from: {source_model_id}")
    tokenizer = AutoTokenizer.from_pretrained(source_model_id, trust_remote_code=True)

    # Load weights directly (faster than instantiating the full model)
    print(f"\n⚡ Loading model weights using fast sharted loading...")
    source_weights = load_model_sharted(source_model_id)

    print(f"\n📊 Loaded {len(source_weights)} tensors from sharts")

    # Group tensors by layer for efficient GPU processing
    layer_groups = {}
    other_tensors = []

    for name, tensor in source_weights.items():
        layer_idx, _ = get_layer_info(name)
        if layer_idx is not None:
            if layer_idx not in layer_groups:
                layer_groups[layer_idx] = []
            layer_groups[layer_idx].append((name, tensor))
        else:
            other_tensors.append((name, tensor))

    print(f"\n🔧 Processing tensors across {NUM_GPUS} GPUs...")
    print("  - Round-robin layer distribution across GPUs")
    print("  - JIT-compiled operations")
    print("  - Efficient memory management")
    print("  - Sharted weight I/O 💩")

    new_state_dict = {}

    # Distribute layers round-robin across GPUs
    with tqdm(total=len(source_weights), desc="Upscaling tensors") as pbar:
        layer_indices = sorted(layer_groups.keys())

        for i in range(0, len(layer_indices), NUM_GPUS):
            # Assign each layer in this batch to a GPU
            for j, layer_idx in enumerate(layer_indices[i:i+NUM_GPUS]):
                gpu_id = j % NUM_GPUS
                device = f'cuda:{gpu_id}'

                # Process this layer on the assigned GPU
                layer_tensors = layer_groups[layer_idx]
                processed = process_layer_batch(layer_tensors, device)
                new_state_dict.update(processed)
                pbar.update(len(layer_tensors))

                # Clear GPU cache periodically
                if j % 4 == 0:
                    torch.cuda.empty_cache()

        # Process non-layer tensors (embeddings, final norm, lm_head)
        for name, tensor in other_tensors:
            new_tensor = upscale_tensor_gpu(tensor, name, device='cuda:0').cpu()
            new_state_dict[name] = new_tensor
            pbar.update(1)

    # Free source weights
    del source_weights
    gc.collect()
    torch.cuda.empty_cache()

    # Create config
    print("\n📝 Creating target model configuration...")
    config = AutoConfig.from_pretrained(source_model_id, trust_remote_code=True)
    config.hidden_size = TGT_HIDDEN_SIZE
    config.intermediate_size = TGT_INTERMEDIATE_SIZE
    config.num_attention_heads = TGT_NUM_HEADS
    config.torch_dtype = torch.bfloat16

    # Quick verification
    print("\n🔍 Quick verification of tensor dimensions BEFORE saving:")

    # Check critical dimensions
    critical_checks = [
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.0.self_attn.k_proj.weight",
        "model.layers.0.self_attn.v_proj.weight",
        "model.layers.0.self_attn.o_proj.weight",
        "model.layers.0.mlp.gate_proj.weight"
    ]

    for check_name in critical_checks:
        for name, tensor in new_state_dict.items():
            if check_name in name:
                print(f"  {name}: {tensor.shape}")
                break

    # Specifically verify o_proj dimensions
    print("\n🎯 Verifying ALL o_proj dimensions:")
    o_proj_issue = False
    for name, tensor in new_state_dict.items():
        if "o_proj.weight" in name:
            if tensor.shape != (TGT_HIDDEN_SIZE, TGT_HIDDEN_SIZE):
                print(f"  ❌ {name}: {tensor.shape} - INCORRECT!")
                o_proj_issue = True
            else:
                if "layers.0." in name or "layers.63." in name:  # Show first and last layer
                    print(f"  ✅ {name}: {tensor.shape}")

    if o_proj_issue:
        print("\n❌ ERROR: o_proj dimensions are incorrect! Not saving model.")
        return False

    # Save model and config
    print(f"\n💾 Saving model to: {OUTPUT_DIR}")
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Save config and tokenizer
    config.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)

    # Save weights in sharted format
    save_model_sharted(new_state_dict, OUTPUT_DIR)

    # Copy auxiliary configuration files; fetch from the Hub when the source id
    # is a repo name rather than a local directory
    for file in ['generation_config.json', 'tokenizer_config.json', 'special_tokens_map.json']:
        src = os.path.join(source_model_id, file)
        if not os.path.exists(src):
            try:
                from huggingface_hub import hf_hub_download
                src = hf_hub_download(source_model_id, file)
            except Exception:
                continue
        shutil.copy(src, os.path.join(OUTPUT_DIR, file))

    # Save metadata
    metadata = {
        "stage": "1-v2-sharted",
        "source_model": source_model_id,
        "method": "gpu_accelerated_structure_aware_interpolation_sharted",
        "num_gpus_used": NUM_GPUS,
        "fixes": [
            "Corrected o_proj dimensions to 8192x8192",
            "Proper handling of GQA architecture"
        ],
        "optimizations": [
            "Multi-GPU parallel processing",
            "JIT-compiled operations",
            "Sharted weight loading/saving 💩",
            "Efficient memory management"
        ],
        "sharting_info": {
            "format": "safetensors",
            "max_shart_size": "5GB",
            "poop_emoji": "💩"
        }
    }

    with open(os.path.join(OUTPUT_DIR, "stage1_v2_metadata.json"), "w") as f:
        json.dump(metadata, f, indent=2)

    print("\n✅ Stage 1 v2 SHARTED interpolation complete! 💩")
    print(f"📁 Model saved to: {OUTPUT_DIR}")

    # Run verifications
    arch_ok = verify_architecture(OUTPUT_DIR)
    diag_ok = run_diagnostics(OUTPUT_DIR)

    if arch_ok and diag_ok:
        print("\n🎉 SUCCESS! Enhanced sharted interpolation completed successfully. 💩")
        print(f"📁 Model saved to: {OUTPUT_DIR}")
        print("\n📋 Ready for Stage 2: Layer duplication (64→80 layers)")
    else:
        print("\n⚠️ Some issues detected. Review the diagnostics above.")

    return arch_ok and diag_ok

if __name__ == "__main__":
    success = main()
    raise SystemExit(0 if success else 1)