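# Pads the 29568-wide intermediate dimension of the multi_modal_projector and
# mlp.fc1 / mlp.fc2 weights to 29568 + pad_size (29696 with the defaults below)
# by interleaving a zero row or column after each of the first pad_size
# rows/columns; the remaining rows/columns are left untouched. Every safetensors
# shard of the checkpoint is rewritten in place. (The target size is presumably
# chosen so the padded dimension divides evenly downstream; treat that as an
# assumption.)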
import torch
from torch.nn import functional as F
from safetensors.torch import load_file, save_file

pad_size = 128      # rows/columns of zero-padding to interleave into the 29568-wide dimension
total_shards = 32   # number of safetensors shards in the checkpoint

for shard_idx in range(1, total_shards + 1):
    filename = f"model-{shard_idx:05d}-of-{total_shards:05d}.safetensors"
    state_dict = load_file(filename)
    modified = False

    for key in list(state_dict.keys()):
        tensor = state_dict[key]
        if 'multi_modal_projector.linear_1.weight' in key or 'multi_modal_projector.linear_3.weight' in key:
            # Interleave a zero row after each of the first pad_size output rows,
            # growing the 29568-row dimension to 29568 + pad_size.
            prev_tensor = F.pad(tensor.unsqueeze(1), (0, 0, 0, 1, 0, 0)).reshape(29568*2, -1)[:pad_size*2]
            new_tensor = torch.cat([prev_tensor, tensor[pad_size:]], dim=0)
            state_dict[key] = new_tensor
            modified = True
        elif 'multi_modal_projector.linear_2.weight' in key:
            # Same trick along the input dimension of the (8192, 29568) weight:
            # interleave a zero column after each of the first pad_size columns.
            prev_tensor = F.pad(tensor.unsqueeze(2), (0, 1)).reshape(8192, 29568*2)[:, :pad_size*2]
            new_tensor = torch.cat([prev_tensor, tensor[:, pad_size:]], dim=1)
            state_dict[key] = new_tensor
            modified = True
        elif 'mlp.fc1.weight' in key:
            # Fused gate/up projection: split it, pad each half along the output
            # dimension exactly like linear_1 / linear_3, then fuse it back together.
            print(f"{key}: fc1 shape before padding {tuple(tensor.shape)}")
            gate_proj, up_proj = torch.chunk(tensor, 2, dim=0)
            print(f"{key}: gate {tuple(gate_proj.shape)}, up {tuple(up_proj.shape)}")

            prev_tensor_gate = F.pad(gate_proj.unsqueeze(1), (0, 0, 0, 1, 0, 0)).reshape(29568*2, -1)[:pad_size*2]
            new_tensor_gate = torch.cat([prev_tensor_gate, gate_proj[pad_size:]], dim=0)

            prev_tensor_up = F.pad(up_proj.unsqueeze(1), (0, 0, 0, 1, 0, 0)).reshape(29568*2, -1)[:pad_size*2]
            new_tensor_up = torch.cat([prev_tensor_up, up_proj[pad_size:]], dim=0)

            new_tensor = torch.cat([new_tensor_gate, new_tensor_up], dim=0)
            print(f"{key}: fc1 shape after padding {tuple(new_tensor.shape)}")
            state_dict[key] = new_tensor
            modified = True
        elif 'mlp.fc2.weight' in key:
            # Pad the input dimension of the (8192, 29568) weight the same way as linear_2.
            prev_tensor = F.pad(tensor.unsqueeze(2), (0, 1)).reshape(8192, 29568*2)[:, :pad_size*2]
            new_tensor = torch.cat([prev_tensor, tensor[:, pad_size:]], dim=1)
            state_dict[key] = new_tensor
            modified = True
    if modified:
        # Overwrite the shard in place with the padded tensors.
        save_file(state_dict, filename, metadata={"format": "pt"})
        print(f"Processed and saved {filename}")
    else:
        print(f"No modifications needed for {filename}")
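
# Optional sanity check (a minimal sketch): re-read every shard and print the
# shapes of the keys touched above, so the padded dimension (29568 + pad_size
# = 29696; doubled for the fused fc1) can be verified by eye.
for shard_idx in range(1, total_shards + 1):
    filename = f"model-{shard_idx:05d}-of-{total_shards:05d}.safetensors"
    for key, tensor in load_file(filename).items():
        if 'multi_modal_projector' in key or 'mlp.fc1.weight' in key or 'mlp.fc2.weight' in key:
            print(key, tuple(tensor.shape))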