# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import time
import pickle

import torch
import torch.nn as nn

from utilities.distributed import is_main_process

logger = logging.getLogger(__name__)

# Registry of normalization layer classes. It is a mutable module-level list so
# that `register_norm_module` can append custom classes to it.
NORM_MODULES = [
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
    # NaiveSyncBatchNorm inherits from BatchNorm2d
    torch.nn.GroupNorm,
    torch.nn.InstanceNorm1d,
    torch.nn.InstanceNorm2d,
    torch.nn.InstanceNorm3d,
    torch.nn.LayerNorm,
    torch.nn.LocalResponseNorm,
]


def register_norm_module(cls):
    """Append ``cls`` to ``NORM_MODULES`` and return it unchanged.

    Because the argument is returned as-is, this can be applied as a class
    decorator. No de-duplication is performed: registering the same class
    twice appends it twice.
    """
    NORM_MODULES.append(cls)
    return cls


def align_and_update_state_dicts(model_state_dict, ckpt_state_dict):
    """Select the checkpoint entries that can be loaded into the model.

    A checkpoint entry is selected when its key also exists in
    ``model_state_dict`` and both values report the same ``.shape``.

    Args:
        model_state_dict: Mapping from parameter name to a value exposing
            ``.shape`` (presumably tensors from ``Module.state_dict()`` --
            confirm against callers).
        ckpt_state_dict: Mapping with the same structure, holding the
            checkpoint values.

    Returns:
        A new dict containing only the selected ``ckpt_state_dict`` entries.
        Neither input mapping is modified.

    On the main process (per ``is_main_process()``) a report is logged in
    this order: loaded keys (info), model keys absent from the checkpoint,
    checkpoint keys that were not loaded, and shape mismatches (all
    warnings). A key whose shapes differ is reported both as unused and as
    unmatched, because it was never loaded.
    """
    model_keys = sorted(model_state_dict.keys())
    ckpt_keys = sorted(ckpt_state_dict.keys())
    # Sets give O(1) membership tests; the sorted list is kept only to fix the
    # order in which unused keys are reported.
    ckpt_key_set = set(ckpt_keys)
    matched_keys = set()

    result_dicts = {}
    matched_log = []
    unmatched_log = []
    unloaded_log = []

    for model_key in model_keys:
        model_weight = model_state_dict[model_key]
        if model_key not in ckpt_key_set:
            unloaded_log.append("*UNLOADED* {}, Model Shape: {}".format(model_key, model_weight.shape))
            continue

        ckpt_weight = ckpt_state_dict[model_key]
        if model_weight.shape == ckpt_weight.shape:
            result_dicts[model_key] = ckpt_weight
            matched_keys.add(model_key)
            matched_log.append("Loaded {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))
        else:
            # Deliberately not added to matched_keys: the entry stays "unused".
            unmatched_log.append("*UNMATCHED* {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))

    if is_main_process():
        for info in matched_log:
            logger.info(info)
        for info in unloaded_log:
            logger.warning(info)
        # Filtering the sorted list keeps unused keys in sorted order.
        for key in ckpt_keys:
            if key not in matched_keys:
                logger.warning("$UNUSED$ {}, Ckpt Shape: {}".format(key, ckpt_state_dict[key].shape))
        for info in unmatched_log:
            logger.warning(info)
    return result_dicts