# Source: uploaded by SteveZerb ("Upload 3 files", commit 6ae945d, verified).
import argparse
from pathlib import Path
import torch
from safetensors.torch import load_file as safe_load_file
from midi_model import config_name_list, MIDIModelConfig, MIDIModel
# Maps each --precision choice to the torch dtype the weights are cast to
# before upload. The CLI choices are derived from this table so they cannot
# drift out of sync with the conversion.
PRECISION_DTYPES = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}


def _build_parser():
    """Return the argument parser for the checkpoint -> push_to_hub script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt", type=str, default="", help="load ckpt"
    )
    parser.add_argument(
        "--config", type=str, default="auto",
        help="model config name, file or automatically find config.json"
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=list(PRECISION_DTYPES),
        help="convert precision",
    )
    parser.add_argument(
        "--repo-id", type=str, default="midi-model-test",
        help="repo id"
    )
    parser.add_argument(
        "--private", action="store_true", default=False, help="private repo"
    )
    return parser


def _load_config(config, ckpt):
    """Resolve the ``--config`` argument into a ``MIDIModelConfig``.

    Args:
        config: a name from ``config_name_list``, the literal ``"auto"``
            (read ``config.json`` from the checkpoint's directory), or a
            path to a JSON config file.
        ckpt: checkpoint path; only used to locate ``config.json`` in
            ``"auto"`` mode.

    Raises:
        ValueError: ``"auto"`` was requested but there is no
            ``config.json`` next to the checkpoint.
    """
    if config in config_name_list:
        return MIDIModelConfig.from_name(config)
    if config == "auto":
        config_path = Path(ckpt).parent / "config.json"
        if not config_path.exists():
            raise ValueError("can not find config.json, please specify config")
        return MIDIModelConfig.from_json_file(config_path)
    return MIDIModelConfig.from_json_file(config)


def _load_state_dict(ckpt):
    """Read the model weights stored in ``ckpt``.

    ``.safetensors`` files are read with safetensors. Any other file goes
    through ``torch.load`` onto the CPU. If the loaded object has a
    ``"state_dict"`` entry, that entry is returned; otherwise the object
    itself is used as the state dict. (Presumably this unwraps training
    checkpoints that bundle extra state — confirm for new checkpoint formats.)
    """
    if ckpt.endswith(".safetensors"):
        return safe_load_file(ckpt)
    # SECURITY: torch.load is pickle-based; only load checkpoints from
    # trusted sources (prefer .safetensors for third-party weights).
    checkpoint = torch.load(ckpt, map_location="cpu")
    return checkpoint.get("state_dict", checkpoint)


def main(argv=None):
    """Load a checkpoint, cast it to ``--precision`` and push it to ``--repo-id``.

    Args:
        argv: argument list to parse; ``None`` parses ``sys.argv[1:]``.
    """
    parser = _build_parser()
    opt = parser.parse_args(argv)
    print(opt)
    # Fail fast with a usage error rather than an opaque FileNotFoundError
    # after the model is built. The --ckpt default "" is never a file, and
    # in "auto" mode it would otherwise make us look for ./config.json.
    if not Path(opt.ckpt).is_file():
        parser.error(f"--ckpt must be an existing checkpoint file, got {opt.ckpt!r}")

    config = _load_config(opt.config, opt.ckpt)
    model = MIDIModel(config=config)
    incompatible = model.load_state_dict(_load_state_dict(opt.ckpt), strict=False)
    # strict=False tolerates key mismatches. Report them so that a checkpoint
    # which did not actually load is not silently uploaded as random weights.
    if incompatible.missing_keys:
        print(f"warning: {len(incompatible.missing_keys)} missing keys, "
              f"e.g. {incompatible.missing_keys[:5]}")
    if incompatible.unexpected_keys:
        print(f"warning: {len(incompatible.unexpected_keys)} unexpected keys, "
              f"e.g. {incompatible.unexpected_keys[:5]}")

    model.to(dtype=PRECISION_DTYPES[opt.precision]).eval()
    model.push_to_hub(repo_id=opt.repo_id, private=opt.private)


if __name__ == "__main__":
    main()