# File size: 1,800 Bytes
# a96506d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import json
from safetensors import safe_open
from safetensors.torch import save_file
import os

# ======== Configuration: edit these paths before running ========
# Folder scanned for *.safetensors files whose tensor keys are renamed.
SAFETENSORS_DIR = "./"
# Index JSON whose "weight_map" keys receive the same rename.
# NOTE: resolved relative to the current working directory, not SAFETENSORS_DIR.
INDEX_PATH = "model.safetensors.index.json"
# =================================================================

def rename_safetensors_keys(file_path):
    """Prefix every ``speech_generator.*`` tensor key in one .safetensors file with ``model.``.

    The file is rewritten in place (a backup is still recommended). Keys that
    do not start with ``speech_generator.`` are kept unchanged, and the file's
    header metadata is carried over to the rewritten file. Running the function
    twice is harmless: already-renamed keys start with ``model.`` and are left
    alone.

    Args:
        file_path: Path to the ``.safetensors`` file to rewrite.

    Returns:
        The number of keys that were renamed. When this is 0 the file is not
        rewritten at all.

    Raises:
        ValueError: If renaming would produce a key that already exists in the
            file (e.g. both ``speech_generator.w`` and
            ``model.speech_generator.w`` are present). The file is left
            unmodified in that case instead of silently dropping a tensor.
    """
    with safe_open(file_path, framework="pt") as f:
        tensors = {k: f.get_tensor(k) for k in f.keys()}
        metadata = f.metadata()

    new_tensors = {}
    renamed = 0
    for key, tensor in tensors.items():
        if key.startswith("speech_generator."):
            new_key = f"model.{key}"
            renamed += 1
        else:
            new_key = key
        # Catches a clash in either iteration order (renamed key first or
        # pre-existing "model.speech_generator.*" key first).
        if new_key in new_tensors:
            raise ValueError(
                f"Duplicate key {new_key!r} after renaming in {file_path}; "
                "refusing to overwrite a tensor"
            )
        new_tensors[new_key] = tensor

    if renamed == 0:
        # Nothing to rename: skip rewriting a potentially very large file.
        return 0

    # Write to a sibling temp file and atomically swap it in, so an interrupted
    # run never leaves a truncated shard in place of the original.
    tmp_path = f"{file_path}.tmp"
    try:
        save_file(new_tensors, tmp_path, metadata=metadata)
        os.replace(tmp_path, file_path)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return renamed

def update_index_file(index_path=None):
    """Apply the ``speech_generator.`` -> ``model.speech_generator.`` rename to an index JSON.

    Rewrites the ``weight_map`` of a sharded-checkpoint index so its keys match
    the renamed tensors in the shards. The values in ``weight_map`` are left
    unchanged. Every other top-level entry (e.g. ``metadata``) is preserved
    as-is, and a missing ``metadata`` entry is tolerated.

    Args:
        index_path: Path to the index JSON file. Defaults to the module-level
            ``INDEX_PATH`` setting.

    Raises:
        FileNotFoundError: If the index file does not exist.
        KeyError: If the JSON has no ``weight_map`` entry.
        ValueError: If a renamed key would collide with a key already in the
            map. The file is left unmodified in that case.
    """
    if index_path is None:
        index_path = INDEX_PATH

    with open(index_path, "r", encoding="utf-8") as f:
        index = json.load(f)

    new_weight_map = {}
    for key, shard in index["weight_map"].items():
        new_key = f"model.{key}" if key.startswith("speech_generator.") else key
        # A clash would silently drop one of the two entries; fail loudly instead.
        if new_key in new_weight_map:
            raise ValueError(
                f"Duplicate key {new_key!r} after renaming in {index_path}; "
                "refusing to overwrite an entry"
            )
        new_weight_map[new_key] = shard

    # Shallow copy so every top-level key other than "weight_map" survives.
    new_index = dict(index)
    new_index["weight_map"] = new_weight_map

    # Write to a temp file and atomically replace, so a crash mid-write cannot
    # leave a truncated index behind.
    tmp_path = f"{index_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(new_index, f, indent=2)
    os.replace(tmp_path, index_path)

# Rename keys in every .safetensors file. Sorted so runs are reproducible
# (os.listdir order is arbitrary); only regular files are passed on, so a
# directory that happens to end in ".safetensors" is ignored.
for filename in sorted(os.listdir(SAFETENSORS_DIR)):
    file_path = os.path.join(SAFETENSORS_DIR, filename)
    if filename.endswith(".safetensors") and os.path.isfile(file_path):
        rename_safetensors_keys(file_path)

# Single-file checkpoints ship without an index JSON. Only update the index when
# it exists, rather than crashing with FileNotFoundError after the shards have
# already been rewritten.
if os.path.isfile(INDEX_PATH):
    update_index_file()
else:
    print(f"No index file found at {INDEX_PATH!r}; skipped index update.")

print("Keys renamed successfully!")