import os
import sys
import time
import yaml
import cv2
import pprint
import traceback
import numpy as np
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from torchvision import models

from data.custom_dataset_data_loader import CustomDatasetDataLoader, sample_data
from options.base_options import parser
from utils.tensorboard_utils import board_add_images
from utils.saving_utils import save_checkpoints
from utils.saving_utils import load_checkpoint, load_checkpoint_mgpu
from utils.distributed import get_world_size, set_seed, synchronize, cleanup
from networks import U2NET
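

# Training script for U2NET (3-channel input, 4-class segmentation output).
# Supports single-GPU runs and multi-GPU DistributedDataParallel runs.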


def options_printing_saving(opt):
    os.makedirs(opt.logs_dir, exist_ok=True)
    os.makedirs(opt.save_dir, exist_ok=True)
    os.makedirs(os.path.join(opt.save_dir, "images"), exist_ok=True)
    os.makedirs(os.path.join(opt.save_dir, "checkpoints"), exist_ok=True)

    # Saving options in yml file
    option_dict = vars(opt)
    with open(os.path.join(opt.save_dir, "training_options.yml"), "w") as outfile:
        yaml.dump(option_dict, outfile)

    for key, value in option_dict.items():
        print(key, value)
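

# Per-process training entry: builds the model, optimizer, dataloader and loss,
# then runs the iteration-based training loop.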
def training_loop(opt):
    if opt.distributed:
        local_rank = int(os.environ.get("LOCAL_RANK"))
        # Unique only on individual node.
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cuda:0")
        local_rank = 0

    u_net = U2NET(in_ch=3, out_ch=4)
    if opt.continue_train:
        u_net = load_checkpoint(u_net, opt.unet_checkpoint)
    u_net = u_net.to(device)
    u_net.train()

    if local_rank == 0:
        with open(os.path.join(opt.save_dir, "networks.txt"), "w") as outfile:
            print("<----U-2-Net---->", file=outfile)
            print(u_net, file=outfile)

    if opt.distributed:
        u_net = nn.parallel.DistributedDataParallel(
            u_net,
            device_ids=[local_rank],
            output_device=local_rank,
            broadcast_buffers=False,
        )
        print("Going super fast with DistributedDataParallel")

    # initialize optimizer
    optimizer = optim.Adam(
        u_net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0
    )
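
    # Data loading: the custom dataloader yields (image, label) batches; sample_data
    # wraps the loader so batches can be drawn with next() for opt.iter steps.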
    custom_dataloader = CustomDatasetDataLoader()
    custom_dataloader.initialize(opt)
    loader = custom_dataloader.get_loader()

    if local_rank == 0:
        dataset_size = len(custom_dataloader)
        print("Total number of images available for training: %d" % dataset_size)
        writer = SummaryWriter(opt.logs_dir)
        print("Entering training loop!")

    # loss function
    weights = np.array([1, 1.5, 1.5, 1.5], dtype=np.float32)
    weights = torch.from_numpy(weights).to(device)
    loss_CE = nn.CrossEntropyLoss(weight=weights).to(device)

    pbar = range(opt.iter)
    get_data = sample_data(loader)
    start_time = time.time()
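
    # U2NET returns seven predictions per forward pass: d0 is the fused output and
    # d1-d6 are side outputs; all are supervised with the weighted cross-entropy loss.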
    # Main training loop
    for itr in pbar:
        data_batch = next(get_data)
        image, label = data_batch
        image = Variable(image.to(device))
        label = label.type(torch.long)
        label = Variable(label.to(device))

        d0, d1, d2, d3, d4, d5, d6 = u_net(image)

        loss0 = loss_CE(d0, label)
        loss1 = loss_CE(d1, label)
        loss2 = loss_CE(d2, label)
        loss3 = loss_CE(d3, label)
        loss4 = loss_CE(d4, label)
        loss5 = loss_CE(d5, label)
        loss6 = loss_CE(d6, label)
        del d1, d2, d3, d4, d5, d6

        total_loss = loss0 * 1.5 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6

        for param in u_net.parameters():
            param.grad = None

        total_loss.backward()
        if opt.clip_grad != 0:
            nn.utils.clip_grad_norm_(u_net.parameters(), opt.clip_grad)
        optimizer.step()

        if local_rank == 0:
            # printing and saving work
            if itr % opt.print_freq == 0:
                pprint.pprint(
                    "[step-{:08d}] [time-{:.3f}] [total_loss-{:.6f}] [loss0-{:.6f}]".format(
                        itr, time.time() - start_time, total_loss, loss0
                    )
                )

            if itr % opt.image_log_freq == 0:
                d0 = F.log_softmax(d0, dim=1)
                d0 = torch.max(d0, dim=1, keepdim=True)[1]
                visuals = [[image, torch.unsqueeze(label, dim=1) * 85, d0 * 85]]
                board_add_images(writer, "grid", visuals, itr)

            writer.add_scalar("total_loss", total_loss, itr)
            writer.add_scalar("loss0", loss0, itr)

            if itr % opt.save_freq == 0:
                save_checkpoints(opt, itr, u_net)

    print("Training done!")
    if local_rank == 0:
        itr += 1
        save_checkpoints(opt, itr, u_net)
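

# Entry point: parse options, optionally initialize the NCCL process group for
# distributed training (LOCAL_RANK and the env:// rendezvous variables are expected
# from a launcher such as torchrun), then run training and clean up.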
if __name__ == "__main__":
    opt = parser()

    if opt.distributed:
        if int(os.environ.get("LOCAL_RANK")) == 0:
            options_printing_saving(opt)
    else:
        options_printing_saving(opt)

    try:
        if opt.distributed:
            print("Initialize Process Group...")
            torch.distributed.init_process_group(backend="nccl", init_method="env://")
            synchronize()

        set_seed(1000)
        training_loop(opt)
        cleanup(opt.distributed)
        print("Exiting..............")
    except KeyboardInterrupt:
        cleanup(opt.distributed)
    except Exception:
        traceback.print_exc(file=sys.stdout)
        cleanup(opt.distributed)