"""Load a MIDI model checkpoint, cast it to a chosen precision, and push it to the Hugging Face Hub."""
import argparse
from pathlib import Path

import torch
from safetensors.torch import load_file as safe_load_file

from midi_model import config_name_list, MIDIModelConfig, MIDIModel

# Supported --precision values and the torch dtype the model is cast to.
# Key order is also the order the choices are listed in --help.
PRECISION_DICT = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}


def _parse_args() -> argparse.Namespace:
    """Parse the command line, rejecting a missing --ckpt up front."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt", type=str, default="", help="load ckpt"
    )
    parser.add_argument(
        "--config", type=str, default="auto",
        help="model config name, file or automatically find config.json"
    )
    parser.add_argument(
        "--precision", type=str, default="bf16", choices=list(PRECISION_DICT),
        help="convert precision",
    )
    parser.add_argument(
        "--repo-id", type=str, default="midi-model-test", help="repo id"
    )
    parser.add_argument(
        "--private", action="store_true", default=False, help="private repo"
    )
    opt = parser.parse_args()
    if not opt.ckpt:
        # Without a checkpoint the script would otherwise fail later with a
        # misleading "config.json not found" or file-open error.
        parser.error("--ckpt is required")
    return opt


def _load_config(config_arg: str, ckpt_path: str) -> MIDIModelConfig:
    """Resolve the --config argument to a MIDIModelConfig.

    Args:
        config_arg: a name from ``config_name_list``, the literal ``"auto"``
            (read ``config.json`` from the checkpoint's directory), or a path
            to a JSON config file.
        ckpt_path: checkpoint path; only used when ``config_arg == "auto"``.

    Raises:
        ValueError: if ``"auto"`` is requested and no ``config.json`` exists
            next to the checkpoint.
    """
    if config_arg in config_name_list:
        return MIDIModelConfig.from_name(config_arg)
    if config_arg == "auto":
        config_path = Path(ckpt_path).parent / "config.json"
        if not config_path.exists():
            raise ValueError("can not find config.json, please specify config")
        return MIDIModelConfig.from_json_file(config_path)
    return MIDIModelConfig.from_json_file(config_arg)


def _load_state_dict(ckpt_path: str) -> dict:
    """Read model weights from a ``.safetensors`` file or a torch checkpoint.

    For torch checkpoints, a nested ``"state_dict"`` entry is used when present
    (presumably a training-framework checkpoint wrapper); otherwise the loaded
    object itself is treated as the state dict.
    """
    if ckpt_path.endswith(".safetensors"):
        return safe_load_file(ckpt_path)
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    return ckpt.get("state_dict", ckpt)


def main() -> None:
    """Entry point: build the model, load weights, cast, and push to the Hub."""
    opt = _parse_args()
    print(opt)

    config = _load_config(opt.config, opt.ckpt)
    model = MIDIModel(config=config)

    state_dict = _load_state_dict(opt.ckpt)
    # strict=False is kept from the original behavior (tolerates key
    # mismatches), but the mismatches are reported so a checkpoint that does
    # not fit the config is not pushed silently with uninitialized weights.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print(f"warning: {len(incompatible.missing_keys)} missing keys: "
              f"{incompatible.missing_keys}")
    if incompatible.unexpected_keys:
        print(f"warning: {len(incompatible.unexpected_keys)} unexpected keys: "
              f"{incompatible.unexpected_keys}")

    model.to(dtype=PRECISION_DICT[opt.precision]).eval()
    model.push_to_hub(repo_id=opt.repo_id, private=opt.private)


if __name__ == '__main__':
    main()