"""
GPU wrappers
"""
import os

import torch
use_gpu = False      # whether CUDA is used at all
gpu_id = 0           # local GPU index on the current node
device = None        # torch.device selected by set_gpu_mode()
distributed = False  # True when running with more than one process
dist_rank = 0        # global rank of this process
world_size = 1       # total number of processes in the job
def set_gpu_mode(mode, pbs=False):
    """Configure the global device settings.

    Reads the local rank, global rank, and world size from the scheduler's
    environment variables: PBS/MPI variables when ``pbs`` is True, SLURM
    variables otherwise.
    """
    global use_gpu
    global device
    global gpu_id
    global distributed
    global dist_rank
    global world_size
    if pbs:
        gpu_id = int(os.environ.get("MPI_LOCALRANKID", 0))
        dist_rank = int(os.environ.get("PMI_RANK", 0))
        world_size = int(os.environ.get("PMI_SIZE", 1))
    else:
        gpu_id = int(os.environ.get("SLURM_LOCALID", 0))
        dist_rank = int(os.environ.get("SLURM_PROCID", 0))
        world_size = int(os.environ.get("SLURM_NTASKS", 1))
    distributed = world_size > 1
    use_gpu = mode
    print(f"gpu_id: {gpu_id}, dist_rank: {dist_rank}, "
          f"world_size: {world_size}, distributed: {distributed}")
    device = torch.device(f"cuda:{gpu_id}" if use_gpu else "cpu")
    # Let cuDNN pick the fastest convolution algorithms for fixed input sizes.
    torch.backends.cudnn.benchmark = True
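

# --- Example usage (a minimal sketch, not part of the original module). In a
# SLURM-launched training script one might call set_gpu_mode(True) and then
# initialize torch.distributed from the values populated above. MASTER_ADDR
# and MASTER_PORT are assumed to be exported by the job script; this module
# does not set them.
if __name__ == "__main__":
    set_gpu_mode(torch.cuda.is_available())
    if distributed:
        torch.distributed.init_process_group(
            backend="nccl" if use_gpu else "gloo",
            init_method="env://",  # reads MASTER_ADDR / MASTER_PORT
            rank=dist_rank,
            world_size=world_size,
        )
    # Place a toy tensor on the selected device to confirm the configuration.
    x = torch.zeros(1, device=device)
    print("tensor allocated on:", x.device)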