Spaces:
Running
Running
| from mmcv.utils import build_from_cfg | |
| from mmpose.datasets.builder import DATASETS | |
| from mmpose.datasets.dataset_wrappers import RepeatDataset | |
| from torch.utils.data.dataset import ConcatDataset | |
| def _concat_cfg(cfg): | |
| replace = ['ann_file', 'img_prefix'] | |
| channels = ['num_joints', 'dataset_channel'] | |
| concat_cfg = [] | |
| for i in range(len(cfg['type'])): | |
| cfg_tmp = cfg.deepcopy() | |
| cfg_tmp['type'] = cfg['type'][i] | |
| for item in replace: | |
| assert item in cfg_tmp | |
| assert len(cfg['type']) == len(cfg[item]), (cfg[item]) | |
| cfg_tmp[item] = cfg[item][i] | |
| for item in channels: | |
| assert item in cfg_tmp['data_cfg'] | |
| assert len(cfg['type']) == len(cfg['data_cfg'][item]) | |
| cfg_tmp['data_cfg'][item] = cfg['data_cfg'][item][i] | |
| concat_cfg.append(cfg_tmp) | |
| return concat_cfg | |
| def _check_vaild(cfg): | |
| replace = ['num_joints', 'dataset_channel'] | |
| if isinstance(cfg['data_cfg'][replace[0]], (list, tuple)): | |
| for item in replace: | |
| cfg['data_cfg'][item] = cfg['data_cfg'][item][0] | |
| return cfg | |
def build_dataset(cfg, default_args=None):
    """Construct a dataset object from a config dict.

    Three config shapes are handled:

    * ``cfg['type']`` is a list/tuple: the config is split into one config
      per type by ``_concat_cfg``. Each piece is built recursively, and the
      results are wrapped in ``torch.utils.data.ConcatDataset``.
    * ``cfg['type'] == 'RepeatDataset'``: the nested ``cfg['dataset']`` is
      built recursively and wrapped in ``RepeatDataset`` with
      ``cfg['times']``.
    * anything else: the config is passed through ``_check_vaild`` (which
      may rewrite ``cfg`` in place) and then built from the ``DATASETS``
      registry via ``build_from_cfg``.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        default_args (dict, optional): Default initialization arguments,
            forwarded to every recursive build. Default: None.

    Returns:
        Dataset: The constructed dataset.
    """
    dataset_type = cfg['type']

    if isinstance(dataset_type, (list, tuple)):
        # Several dataset types in one config: build each, then concatenate.
        sub_datasets = [
            build_dataset(sub_cfg, default_args)
            for sub_cfg in _concat_cfg(cfg)
        ]
        return ConcatDataset(sub_datasets)

    if dataset_type == 'RepeatDataset':
        inner_dataset = build_dataset(cfg['dataset'], default_args)
        return RepeatDataset(inner_dataset, cfg['times'])

    checked_cfg = _check_vaild(cfg)
    return build_from_cfg(checked_cfg, DATASETS, default_args)