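# Training script for U-2-Net multi-class segmentation: optional
# DistributedDataParallel training, a class-weighted multi-scale cross-entropy
# objective, TensorBoard logging, and periodic checkpointing.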
import os
import sys
import time
import yaml
import cv2
import pprint
import traceback
import numpy as np
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from torchvision import models

from data.custom_dataset_data_loader import CustomDatasetDataLoader, sample_data
from options.base_options import parser
from utils.tensorboard_utils import board_add_images
from utils.saving_utils import save_checkpoints
from utils.saving_utils import load_checkpoint, load_checkpoint_mgpu
from utils.distributed import get_world_size, set_seed, synchronize, cleanup
from networks import U2NET

def options_printing_saving(opt):
    os.makedirs(opt.logs_dir, exist_ok=True)
    os.makedirs(opt.save_dir, exist_ok=True)
    os.makedirs(os.path.join(opt.save_dir, "images"), exist_ok=True)
    os.makedirs(os.path.join(opt.save_dir, "checkpoints"), exist_ok=True)

    # Saving options in yml file
    option_dict = vars(opt)
    with open(os.path.join(opt.save_dir, "training_options.yml"), "w") as outfile:
        yaml.dump(option_dict, outfile)

    for key, value in option_dict.items():
        print(key, value)

def training_loop(opt):
    if opt.distributed:
        local_rank = int(os.environ.get("LOCAL_RANK"))
        # Unique only on individual node.
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cuda:0")
        local_rank = 0

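    # U2NET takes 3-channel RGB images and predicts 4 output channels, one logit
    # map per segmentation class (matching the 4-element class-weight vector used
    # in the loss below).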
    u_net = U2NET(in_ch=3, out_ch=4)
    if opt.continue_train:
        u_net = load_checkpoint(u_net, opt.unet_checkpoint)
    u_net = u_net.to(device)
    u_net.train()

    if local_rank == 0:
        with open(os.path.join(opt.save_dir, "networks.txt"), "w") as outfile:
            print("<----U-2-Net---->", file=outfile)
            print(u_net, file=outfile)

    if opt.distributed:
        u_net = nn.parallel.DistributedDataParallel(
            u_net,
            device_ids=[local_rank],
            output_device=local_rank,
            broadcast_buffers=False,
        )
        print("Going super fast with DistributedDataParallel")

    # initialize optimizer
    optimizer = optim.Adam(
        u_net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0
    )

    custom_dataloader = CustomDatasetDataLoader()
    custom_dataloader.initialize(opt)
    loader = custom_dataloader.get_loader()

    if local_rank == 0:
        dataset_size = len(custom_dataloader)
        print("Total number of images available for training: %d" % dataset_size)
        writer = SummaryWriter(opt.logs_dir)
        print("Entering training loop!")

    # loss function
    weights = np.array([1, 1.5, 1.5, 1.5], dtype=np.float32)
    weights = torch.from_numpy(weights).to(device)
    loss_CE = nn.CrossEntropyLoss(weight=weights).to(device)
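    # Class 0 (presumably the background) is weighted 1.0 and the three remaining
    # classes 1.5, so the cross-entropy loss pays more attention to foreground pixels.
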
    pbar = range(opt.iter)
    get_data = sample_data(loader)
    start_time = time.time()

    # Main training loop
    for itr in pbar:
        data_batch = next(get_data)
        image, label = data_batch
        image = Variable(image.to(device))
        label = label.type(torch.long)
        label = Variable(label.to(device))

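        # U-2-Net returns seven logit maps: d0 is the fused prediction, d1-d6 are
        # the side outputs from progressively deeper decoder stages (deep supervision).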
        d0, d1, d2, d3, d4, d5, d6 = u_net(image)

        loss0 = loss_CE(d0, label)
        loss1 = loss_CE(d1, label)
        loss2 = loss_CE(d2, label)
        loss3 = loss_CE(d3, label)
        loss4 = loss_CE(d4, label)
        loss5 = loss_CE(d5, label)
        loss6 = loss_CE(d6, label)
        del d1, d2, d3, d4, d5, d6

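        # Deep-supervision objective: the fused-output loss is weighted 1.5x, each of
        # the six side-output losses contributes with weight 1.0.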
        total_loss = loss0 * 1.5 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6

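        # Reset gradients by setting them to None (cheaper than zeroing the gradient
        # buffers with optimizer.zero_grad()).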
        for param in u_net.parameters():
            param.grad = None

        total_loss.backward()
        if opt.clip_grad != 0:
            nn.utils.clip_grad_norm_(u_net.parameters(), opt.clip_grad)
        optimizer.step()

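        # Console printing, TensorBoard logging, and checkpointing happen on rank 0 only.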
        if local_rank == 0:
            # printing and saving work
            if itr % opt.print_freq == 0:
                pprint.pprint(
                    "[step-{:08d}] [time-{:.3f}] [total_loss-{:.6f}] [loss0-{:.6f}]".format(
                        itr, time.time() - start_time, total_loss, loss0
                    )
                )

            if itr % opt.image_log_freq == 0:
                d0 = F.log_softmax(d0, dim=1)
                d0 = torch.max(d0, dim=1, keepdim=True)[1]
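                # Multiply class indices (0-3) by 85 so the masks span 0-255 and show
                # up as distinct gray levels in the TensorBoard image grid.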
                visuals = [[image, torch.unsqueeze(label, dim=1) * 85, d0 * 85]]
                board_add_images(writer, "grid", visuals, itr)

            writer.add_scalar("total_loss", total_loss, itr)
            writer.add_scalar("loss0", loss0, itr)

            if itr % opt.save_freq == 0:
                save_checkpoints(opt, itr, u_net)

| print("Training done!") | |
| if local_rank == 0: | |
| itr += 1 | |
| save_checkpoints(opt, itr, u_net) | |
if __name__ == "__main__":
    opt = parser()

    if opt.distributed:
        if int(os.environ.get("LOCAL_RANK")) == 0:
            options_printing_saving(opt)
    else:
        options_printing_saving(opt)

    try:
        if opt.distributed:
            print("Initialize Process Group...")
            torch.distributed.init_process_group(backend="nccl", init_method="env://")
            synchronize()

        set_seed(1000)
        training_loop(opt)
        cleanup(opt.distributed)
        print("Exiting..............")

    except KeyboardInterrupt:
        cleanup(opt.distributed)

    except Exception:
        traceback.print_exc(file=sys.stdout)
        cleanup(opt.distributed)
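# Launch sketch (assumed; the actual script filename and option flags are defined by
# options.base_options): in distributed mode, LOCAL_RANK is expected to be set per
# process by an external launcher such as torchrun, e.g.
#   torchrun --nproc_per_node=<num_gpus> train.py
# Single-GPU runs fall back to cuda:0 with local_rank = 0.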