import sys
import os.path
import torch

# Absolute directory containing this file, with a trailing '/' so that file
# names can be appended to it directly (get_model prefixes its checkpoint
# file name with this value).
code_path = f"{os.path.dirname(os.path.abspath(__file__))}/"

# Add this file's directory to the module search path (appended, so it is
# searched after the existing entries).
sys.path.append(code_path)

import yaml  # noqa: E402  (kept after the sys.path change)
from ml_collections import ConfigDict  # noqa: E402

# Allow float32 matrix multiplications to use lower-precision internal math
# for speed. This is a process-wide torch setting applied at import time.
# NOTE(review): it affects every other torch user in the same process.
torch.set_float32_matmul_precision("medium")
def get_model(
    config_path,
    weights_path,
    device,
):
    """Build the bandit separation model from a YAML config and load its weights.

    Args:
        config_path: Path to a YAML file. Its ``model`` section is passed as
            keyword arguments to ``MultiMaskMultiSourceBandSplitRNNSimple``.
        weights_path: Path to a state-dict checkpoint readable by
            ``torch.load``. If ``None`` or empty, the checkpoint
            ``model_bandit_plus_dnr_sdr_11.47.chpt`` located next to this
            file is used instead.
        device: Target passed to ``model.to``.

    Returns:
        A ``(model, config)`` tuple, where ``config`` is the parsed YAML
        wrapped in an ``ml_collections.ConfigDict``.
    """
    from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple

    # `with` guarantees the file is closed even if YAML parsing raises.
    with open(config_path) as f:
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))

    model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)

    # Previously `weights_path` was ignored and the bundled checkpoint was
    # always loaded. Honour the caller's path, but keep the bundled file as
    # the fallback when no path is given.
    if not weights_path:
        weights_path = code_path + 'model_bandit_plus_dnr_sdr_11.47.chpt'

    # Map tensors to CPU first so a checkpoint saved from a GPU also loads on
    # CPU-only machines. The model is moved to `device` right after.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.to(device)
    return model, config