Spaces: Running on T4
"""
Clean unnecessary information in the weight (*.pth).

Loads a training checkpoint and keeps only its ``model_state_dict`` entry,
dropping every other top-level entry. It also strips the ``_orig_mod.``
prefix from parameter names. That prefix is presumably left behind by saving
a ``torch.compile``-wrapped model. The result is saved as a slimmer weight
file.
"""
import torch

# Prefix removed from parameter names (presumably added by torch.compile's wrapper).
COMPILE_PREFIX = "_orig_mod."
# The only top-level checkpoint entry that is kept.
STATE_DICT_KEY = "model_state_dict"


def strip_prefix(state_dict, prefix=COMPILE_PREFIX):
    """Remove *prefix* from every key of *state_dict* that starts with it.

    The mapping is modified in place and also returned. Keys without the
    prefix are left untouched. Mutating in place keeps the mapping's own
    type and any attributes attached to it.

    Args:
        state_dict: Mutable mapping of parameter name -> value.
        prefix: Leading string to strip from matching key names.

    Returns:
        The same *state_dict* object, with renamed keys.
    """
    # Snapshot the keys first: the mapping is resized while renaming.
    for old_key in list(state_dict):
        if old_key.startswith(prefix):
            # pop-then-assign, so an empty prefix (new == old) cannot drop the key.
            state_dict[old_key[len(prefix):]] = state_dict.pop(old_key)
    return state_dict


def clean_checkpoint(checkpoint, keep_key=STATE_DICT_KEY, prefix=COMPILE_PREFIX):
    """Reduce *checkpoint* to ``{keep_key: state_dict}`` with *prefix* stripped.

    Args:
        checkpoint: Mapping loaded from a checkpoint file.
        keep_key: The single top-level entry to keep.
        prefix: Leading string to strip from parameter names.

    Returns:
        A new dict whose only key is *keep_key*.

    Raises:
        KeyError: If *checkpoint* has no *keep_key* entry.
    """
    if keep_key not in checkpoint:
        raise KeyError(
            f"checkpoint has no {keep_key!r} entry; found keys: {list(checkpoint)}"
        )
    return {keep_key: strip_prefix(checkpoint[keep_key], prefix)}


def main(weight_path="saved_models/esrgan_best_generator.pth",
         store_path="1x_APISR_RRDB_GAN_generator.pth"):
    """Load *weight_path*, clean it, and save the result to *store_path*."""
    # map_location="cpu": a checkpoint containing CUDA tensors then also loads on
    # CPU-only machines, and the saved weight file is not tied to a GPU device.
    checkpoint_g = torch.load(weight_path, map_location="cpu")
    # Show which top-level entries the checkpoint contained before cleaning.
    for key in checkpoint_g:
        print(key)
    torch.save(clean_checkpoint(checkpoint_g), store_path)


if __name__ == "__main__":
    main()