gaur3009's picture
Upload 44 files
de79343 verified
raw
history blame contribute delete
940 Bytes
import os
import math
import numpy as np
import random
import pickle
import torch
from torch import distributed as dist
from torch.utils.data.sampler import Sampler
def set_seed(seed):
    """Seed every random number generator used here for reproducible runs.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (the CPU generator and all CUDA devices), and configures cuDNN for
    deterministic behaviour.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Request deterministic cuDNN kernels.  The autotuner (`benchmark`) must
    # be off as well: when enabled it may select a different convolution
    # algorithm from run to run, which defeats the determinism requested on
    # the previous line.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # NOTE: hash randomisation of the *current* interpreter is fixed at
    # start-up; exporting the variable here only influences child processes
    # spawned afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
def synchronize():
    """Block until every process in the default group reaches this point.

    Does nothing when ``torch.distributed`` is unavailable, when no process
    group has been initialised, or when the group has exactly one member.
    """
    distributed_ready = dist.is_available() and dist.is_initialized()
    if not distributed_ready:
        return

    # A lone process has nobody to wait for, so skip the barrier.
    if dist.get_world_size() != 1:
        dist.barrier()
def cleanup(distributed):
    """Destroy the process group when running in distributed mode.

    Args:
        distributed: Truthy when a process group should be torn down;
            any falsy value makes this call a no-op.
    """
    if not distributed:
        return
    dist.destroy_process_group()
def get_world_size():
    """Return the number of processes taking part in the current job.

    Returns:
        The size reported by ``torch.distributed`` when it is available and
        a process group has been initialised; otherwise ``1``.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    # Without an initialised process group, report a single process.
    return 1