File size: 1,910 Bytes
6ae945d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import argparse
from pathlib import Path

import torch
from safetensors.torch import load_file as safe_load_file
from midi_model import config_name_list, MIDIModelConfig, MIDIModel

# Maps the --precision CLI choice to the torch dtype the model is cast to.
_PRECISION_DTYPES = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}


def _parse_args(argv=None):
    """Parse command-line options (``argv`` defaults to ``sys.argv[1:]``).

    Exits via ``parser.error`` (status 2) when ``--ckpt`` is missing, since
    nothing can be loaded or pushed without a checkpoint.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt", type=str, default="", help="load ckpt"
    )
    parser.add_argument(
        "--config", type=str, default="auto",
        help="model config name, file or automatically find config.json"
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=list(_PRECISION_DTYPES),
        help="convert precision",
    )
    parser.add_argument(
        "--repo-id", type=str, default="midi-model-test",
        help="repo id"
    )
    parser.add_argument(
        "--private", action="store_true", default=False, help="private repo"
    )
    opt = parser.parse_args(argv)
    # An empty --ckpt previously fell through to torch.load("") and failed
    # with an unrelated-looking error; reject it up front instead.
    if not opt.ckpt:
        parser.error("--ckpt is required")
    return opt


def _load_config(config_arg, ckpt_path):
    """Resolve ``--config`` into a ``MIDIModelConfig``.

    ``config_arg`` is a known config name, the literal ``"auto"`` (read
    ``config.json`` next to ``ckpt_path``), or a path to a JSON config file.

    Raises:
        ValueError: if ``"auto"`` is used and no ``config.json`` exists
            beside the checkpoint.
    """
    if config_arg in config_name_list:
        return MIDIModelConfig.from_name(config_arg)
    if config_arg == "auto":
        config_path = Path(ckpt_path).parent / "config.json"
        if not config_path.exists():
            raise ValueError("can not find config.json, please specify config")
        return MIDIModelConfig.from_json_file(config_path)
    return MIDIModelConfig.from_json_file(config_arg)


def _load_state_dict(ckpt_path):
    """Load a state dict from a ``.safetensors`` file or a torch checkpoint.

    For torch checkpoints, a top-level ``"state_dict"`` entry is unwrapped
    when present; otherwise the loaded object itself is returned.
    """
    if ckpt_path.endswith(".safetensors"):
        return safe_load_file(ckpt_path)
    # NOTE(review): torch.load unpickles the file, which can execute code;
    # only convert checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    return ckpt.get("state_dict", ckpt)


def main(argv=None):
    """Load a MIDIModel checkpoint, cast it, and upload it via ``push_to_hub``."""
    opt = _parse_args(argv)
    print(opt)

    config = _load_config(opt.config, opt.ckpt)
    model = MIDIModel(config=config)
    state_dict = _load_state_dict(opt.ckpt)

    # strict=False tolerates key mismatches, but a silent mismatch would push
    # partially untrained weights, so report whatever did not line up.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print(f"warning: {len(incompatible.missing_keys)} missing keys: "
              f"{incompatible.missing_keys}")
    if incompatible.unexpected_keys:
        print(f"warning: {len(incompatible.unexpected_keys)} unexpected keys: "
              f"{incompatible.unexpected_keys}")

    model.to(dtype=_PRECISION_DTYPES[opt.precision]).eval()
    model.push_to_hub(repo_id=opt.repo_id, private=opt.private)


if __name__ == '__main__':
    main()