# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import random
import unittest
from multiprocessing import Manager

import torch
import torch.nn as nn

from fairseq import distributed_utils, optim
from omegaconf import OmegaConf


class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        output = self.fc(input)
        return output


def setup_model_loss_criterion(cfg, args, rank, is_cuda):
    """Setup model, criterion and optimizer based on input args."""
    args.distributed_rank = rank
    cfg.distributed_training.distributed_rank = args.distributed_rank
    if cfg.distributed_training.distributed_world_size > 1:
        distributed_utils.distributed_init(cfg)
    torch.manual_seed(1)
    model = Model(args.input_size, args.nb_classes)
    loss_fn = nn.CrossEntropyLoss()
    if is_cuda:
        model = model.cuda()
        loss_fn = loss_fn.cuda()

    # Wrap the base SGD optimizer in BMUF, which handles cross-worker syncing.
    optimizer = optim.sgd.SGD(args, model.parameters())
    optimizer = optim.FairseqBMUF(cfg=cfg.bmuf, optimizer=optimizer)

    return model, loss_fn, optimizer


def train_step(input, target, model, loss_fn, optimizer, **unused):
    """Do forward, backward and parameter update."""
    model.train()
    output = model(input)
    loss = loss_fn(output, target)
    optimizer.backward(loss)
    optimizer.step()


def single_gpu_training(cfg, args, rank, iterations, shared_results):
    is_cuda = torch.cuda.is_available()
    if is_cuda:
        torch.cuda.set_device(rank)

    model, loss_fn, optimizer = setup_model_loss_criterion(cfg, args, rank, is_cuda)

    for _ in range(iterations):
        input = torch.randn(1, args.input_size)
        target = torch.empty(args.batch_size, dtype=torch.long).random_(args.nb_classes)
        if is_cuda:
            input = input.cuda()
            target = target.cuda()
        train_step(input, target, model, loss_fn, optimizer)

    # Flatten all parameters into a single 1-D tensor so that runs on
    # different ranks can be compared element-wise.
    results = []
    for param in model.parameters():
        if len(results) == 0:
            results = param.flatten().cpu().data
        else:
            results = torch.cat((results, param.flatten().cpu().data), 0)
    shared_results[rank] = results
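
# Background for the tests below: BMUF (Block-wise Model Update Filtering;
# Chen & Huo, 2016) lets each worker run plain SGD locally. Roughly, during
# the first `warmup_iterations` steps the workers are kept in sync by
# averaging/broadcasting parameters, and after warmup a block momentum
# update is applied every `global_sync_iter` steps. In both regimes the
# workers should hold identical parameters immediately after a sync point,
# which is what the assertAlmostEqual checks in TestBMUF verify.
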
def setup_args():
    args = argparse.Namespace()
    args.global_sync_iter = 1
    args.block_momentum = 0.875
    args.block_lr = 0.5
    args.input_size = 5
    args.nb_classes = 2
    args.batch_size = 1
    args.lr = [1e-3]
    args.momentum = 0
    args.weight_decay = 0
    args.warmup_iterations = 0
    args.use_nbm = True
    args.average_sync = True
    args.model_parallel_size = 1
    args.distributed_backend = "gloo"

    args.distributed_world_size = 2
    port = random.randint(10000, 20000)
    args.distributed_init_method = "tcp://localhost:{port}".format(port=port)
    args.distributed_init_host = "localhost"
    args.distributed_port = port + 1
    args.local_world_size = args.distributed_world_size

    cfg = OmegaConf.create()
    cfg.optimization = OmegaConf.create()
    cfg.common = OmegaConf.create()
    cfg.distributed_training = OmegaConf.create()
    cfg.dataset = OmegaConf.create()
    cfg.bmuf = OmegaConf.create()
    cfg.optimizer = OmegaConf.create()

    # Mirror the legacy argparse namespace into the OmegaConf config groups.
    cfg.bmuf.global_sync_iter = args.global_sync_iter
    cfg.bmuf.block_momentum = args.block_momentum
    cfg.bmuf.block_lr = args.block_lr
    cfg.dataset.batch_size = args.batch_size
    cfg.optimization.lr = args.lr
    cfg.optimizer.momentum = args.momentum
    cfg.optimizer.weight_decay = args.weight_decay
    cfg.bmuf.warmup_iterations = args.warmup_iterations
    cfg.bmuf.use_nbm = args.use_nbm
    cfg.bmuf.average_sync = args.average_sync
    cfg.common.model_parallel_size = args.model_parallel_size
    cfg.distributed_training.distributed_backend = args.distributed_backend
    cfg.distributed_training.distributed_world_size = args.distributed_world_size
    cfg.bmuf.distributed_world_size = args.distributed_world_size
    cfg.distributed_training.distributed_init_method = args.distributed_init_method
    cfg.distributed_training.distributed_port = args.distributed_port

    return cfg, args


@unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2 GPUs")
class TestBMUF(unittest.TestCase):
    def bmuf_process(self, cfg, args, iterations):
        processes = []
        results = Manager().dict()
        ctx = torch.multiprocessing.get_context("spawn")
        for rank in range(args.distributed_world_size):
            p = ctx.Process(
                target=single_gpu_training,
                args=(cfg, args, rank, iterations, results),
            )
            p.start()
            processes.append(p)
        for p in processes:
            p.join()
        return results

    def test_bmuf_sync(self):
        # Train for 1 iteration and do a bmuf sync without any warmup
        cfg, args = setup_args()
        iterations = 1
        results = self.bmuf_process(cfg, args, iterations)
        # Make sure the parameters on both workers match
        assert len(results) == 2
        self.assertAlmostEqual(results[0], results[1])

    def test_warmup_sync(self):
        # Train for 20 iterations and do a warmup sync without any bmuf sync
        cfg, args = setup_args()
        args.warmup_iterations = 20
        cfg.bmuf.warmup_iterations = args.warmup_iterations
        iterations = 20
        results = self.bmuf_process(cfg, args, iterations)
        # Make sure the parameters on both workers match
        assert len(results) == 2
        self.assertAlmostEqual(results[0], results[1])

    def test_warmup_sync_bmuf_sync(self):
        # Train for 25 iterations: warmup sync after 20 iterations, then a
        # bmuf sync after 25 iterations
        cfg, args = setup_args()
        args.warmup_iterations = 20
        args.global_sync_iter = 5
        cfg.bmuf.warmup_iterations = args.warmup_iterations
        cfg.bmuf.global_sync_iter = args.global_sync_iter
        iterations = 25
        results = self.bmuf_process(cfg, args, iterations)
        # Make sure the parameters on both workers match
        assert len(results) == 2
        self.assertAlmostEqual(results[0], results[1])

    def test_single_gpu_bmuf(self):
        # Train on a single GPU for 20 iterations with 5 warmup iterations
        cfg, args = setup_args()
        args.distributed_world_size = 1
        args.warmup_iterations = 5
        cfg.distributed_training.distributed_world_size = args.distributed_world_size
        cfg.bmuf.distributed_world_size = args.distributed_world_size
        cfg.bmuf.warmup_iterations = args.warmup_iterations
        iterations = 20
        results = self.bmuf_process(cfg, args, iterations)
        assert len(results) == 1

    def assertAlmostEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertLess((t1 - t2).abs().max(), 1e-4)


if __name__ == "__main__":
    unittest.main()
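
# Note: running this file directly (python test_bmuf.py) invokes the suite
# via unittest; the multi-GPU tests are skipped unless at least two CUDA
# devices are visible.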