import json
import os

from safetensors import safe_open
from safetensors.torch import save_file

# ---- Update paths here ----
SAFETENSORS_DIR = "./"  # Directory containing .safetensors files
INDEX_PATH = "model.safetensors.index.json"  # Path to your index file
# ---------------------------

# Keys that start with OLD_PREFIX get NEW_PREFIX prepended to them.
OLD_PREFIX = "speech_generator."
NEW_PREFIX = "model."


def _rename_key(key):
    """Return *key* with NEW_PREFIX prepended if it starts with OLD_PREFIX."""
    return f"{NEW_PREFIX}{key}" if key.startswith(OLD_PREFIX) else key


def _check_no_collisions(old_keys, new_keys, source):
    """Raise ValueError if renaming would merge two distinct keys into one."""
    if len(set(new_keys)) != len(set(old_keys)):
        raise ValueError(
            f"Renaming keys in {source} would overwrite existing "
            f"'{NEW_PREFIX}{OLD_PREFIX}*' entries"
        )


def rename_safetensors_keys(file_path):
    """Rename keys in a single .safetensors file.

    Every tensor whose key starts with ``speech_generator.`` is stored under
    ``model.<key>`` instead; all other tensors and the file metadata are kept
    as they are. Files without any matching key are left untouched.

    Args:
        file_path: Path of the .safetensors file to rewrite in place.

    Raises:
        ValueError: If a renamed key would collide with a key that already
            exists in the file (a tensor would otherwise be lost silently).
    """
    with safe_open(file_path, framework="pt") as f:
        keys = list(f.keys())
        # Nothing to rename: skip loading and rewriting the whole shard.
        if not any(key.startswith(OLD_PREFIX) for key in keys):
            return

        new_keys = [_rename_key(key) for key in keys]
        _check_no_collisions(keys, new_keys, file_path)

        new_tensors = {
            new_key: f.get_tensor(key) for key, new_key in zip(keys, new_keys)
        }
        metadata = f.metadata()

    # Write to a sibling temp file and swap it in, so an interrupted save
    # cannot leave a truncated shard and we never write into the file the
    # tensors were just read from. (A backup is still recommended.)
    tmp_path = f"{file_path}.tmp"
    try:
        save_file(new_tensors, tmp_path, metadata=metadata)
        os.replace(tmp_path, file_path)
    finally:
        # Only present here if the save or the replace failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def update_index_file():
    """Update keys in the index file at INDEX_PATH.

    Applies the same ``speech_generator.`` -> ``model.speech_generator.``
    renaming to the keys of the index's ``weight_map``. Every other top-level
    entry of the index (e.g. ``metadata``) is carried over unchanged.

    Raises:
        KeyError: If the index has no ``weight_map`` entry.
        ValueError: If a renamed key would collide with an existing key.
    """
    with open(INDEX_PATH, "r", encoding="utf-8") as f:
        index = json.load(f)

    weight_map = index["weight_map"]
    new_keys = [_rename_key(key) for key in weight_map]
    _check_no_collisions(weight_map, new_keys, INDEX_PATH)

    # Keep all non-weight_map entries; do not assume "metadata" is present.
    new_index = {k: v for k, v in index.items() if k != "weight_map"}
    new_index["weight_map"] = dict(zip(new_keys, weight_map.values()))

    # Same temp-file-then-replace approach as for the tensor files, so the
    # original index is only replaced once the new one is fully written.
    tmp_path = f"{INDEX_PATH}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(new_index, f, indent=2)
        os.replace(tmp_path, INDEX_PATH)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# Process all .safetensors files
for filename in sorted(os.listdir(SAFETENSORS_DIR)):
    if filename.endswith(".safetensors"):
        rename_safetensors_keys(os.path.join(SAFETENSORS_DIR, filename))

# Update the index file
update_index_file()

print("Keys renamed successfully!")