# File size: 1,800 Bytes
# a96506d
import json
from safetensors import safe_open
from safetensors.torch import save_file
import os
# ---- Update paths here ----
# Directory that is scanned (non-recursively) for files ending in ".safetensors"
SAFETENSORS_DIR = "./"  # Directory containing .safetensors files
# Opened exactly as written (it is NOT joined with SAFETENSORS_DIR), so a
# relative path is resolved against the current working directory.
INDEX_PATH = "model.safetensors.index.json"  # Path to your index file
# ---------------------------
def rename_safetensors_keys(file_path):
    """Prefix every ``speech_generator.*`` tensor key in one file with ``model.``.

    Keys that start with ``"speech_generator."`` are renamed to
    ``"model.<key>"``; every other key and the file-level metadata are kept
    unchanged. The rewritten content replaces the file at *file_path*.

    A file without any matching key is left untouched, so shards that were
    already converted (their keys now start with ``"model."``) are not
    rewritten when the script runs again.

    Args:
        file_path: Path of the .safetensors file to rewrite in place.
    """
    prefix = "speech_generator."

    with safe_open(file_path, framework="pt") as f:
        keys = list(f.keys())
        if not any(key.startswith(prefix) for key in keys):
            # Nothing to rename: avoid loading and rewriting the whole shard.
            return
        metadata = f.metadata()
        new_tensors = {
            (f"model.{key}" if key.startswith(prefix) else key): f.get_tensor(key)
            for key in keys
        }

    # Write to a sibling temp file and swap it in afterwards, so an
    # interrupted write cannot leave a truncated file at file_path.
    # The ".tmp" suffix keeps it out of the ".safetensors" directory scan.
    tmp_path = f"{file_path}.tmp"
    try:
        save_file(new_tensors, tmp_path, metadata=metadata)
        os.replace(tmp_path, file_path)
    finally:
        # Only still present when saving or replacing failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def update_index_file(index_path=None):
    """Prefix every ``speech_generator.*`` key of the index's weight map with ``model.``.

    The JSON index is read, each ``"weight_map"`` key that starts with
    ``"speech_generator."`` is renamed to ``"model.<key>"`` (its value is
    kept), and the file is written back in place with two-space indentation.
    All other top-level entries of the index, such as ``"metadata"``, are
    carried over unchanged; a missing ``"metadata"`` entry is not an error.

    Args:
        index_path: Path of the index file to rewrite. Defaults to the
            module-level ``INDEX_PATH`` when omitted or ``None``.

    Raises:
        KeyError: If the index has no ``"weight_map"`` entry.
    """
    path = INDEX_PATH if index_path is None else index_path
    prefix = "speech_generator."

    with open(path, "r", encoding="utf-8") as f:
        index = json.load(f)

    renamed_map = {
        (f"model.{key}" if key.startswith(prefix) else key): value
        for key, value in index["weight_map"].items()
    }
    # Keep every top-level entry and swap in only the renamed weight map.
    new_index = {**index, "weight_map": renamed_map}

    with open(path, "w", encoding="utf-8") as f:
        json.dump(new_index, f, indent=2)
# Rewrite every .safetensors file found directly inside SAFETENSORS_DIR
for filename in os.listdir(SAFETENSORS_DIR):
    if not filename.endswith(".safetensors"):
        continue
    rename_safetensors_keys(os.path.join(SAFETENSORS_DIR, filename))

# Bring the index file in line with the renamed tensor keys
update_index_file()
print("Keys renamed successfully!")