Spaces:
Running
on
T4
Running
on
T4
| # -*- coding: utf-8 -*- | |
| import sys | |
| import os | |
| import torch | |
# Put the current working directory on the module search path so the
# project-local packages imported below (architecture, train_code) resolve.
# NOTE(review): this assumes the script is launched from the repository root.
root_path = os.path.abspath(os.curdir)
sys.path.append(root_path)
| from architecture.cunet import UNet_Full | |
| from architecture.discriminator import UNetDiscriminatorSN | |
| from train_code.train_master import train_master | |
class train_cugan(train_master):
    """Trainer pairing a ``UNet_Full`` generator with a ``UNetDiscriminatorSN``.

    The training loop itself is provided by ``train_master`` (see ``run``);
    this subclass supplies the model pair, the loss setup, the per-step
    generator loss, and the TensorBoard scalars.
    """

    def __init__(self, options, args) -> None:
        # "cugan" is the unique model-name code handed to the base trainer.
        # NOTE(review): the meaning of the trailing positional ``True`` is
        # defined by train_master.__init__ -- confirm there.
        super().__init__(options, args, "cugan", True)

    def loss_init(self):
        """Prepare the loss criteria used by ``calculate_loss``."""
        # Pixel loss for the generator (presumably sets ``self.cri_pix``).
        self.pixel_loss_load()
        # Presumably prepares the remaining criteria used by calculate_loss
        # (``cri_danbooru_perceptual``, ``cri_vgg_perceptual``, ``cri_gan``)
        # -- TODO confirm in train_master.GAN_loss_load.
        self.GAN_loss_load()

    def call_model(self):
        """Build the generator and discriminator on CUDA in training mode."""
        self.generator = UNet_Full().cuda()
        # Optional torch.compile wrapping, currently disabled:
        # self.generator = torch.compile(self.generator).cuda()
        # NOTE(review): 3 is presumably the number of input channels (RGB).
        self.discriminator = UNetDiscriminatorSN(3).cuda()
        # self.discriminator = torch.compile(self.discriminator).cuda()
        self.generator.train()
        self.discriminator.train()

    def run(self):
        """Start training through the shared ``train_master`` loop."""
        self.master_run()

    def calculate_loss(self, gen_hr, imgs_hr):
        """Accumulate the generator's three loss terms for one step.

        Each term (pixel, perceptual, adversarial) is added to
        ``self.generator_loss`` and also stored in ``self.weight_store``
        under "pixel_loss", "perceptual_loss" and "gan_loss" for reporting.

        Args:
            gen_hr: the generator's output batch.
            imgs_hr: the ground-truth batch it is compared against.
        """
        # 1) Pixel loss: generated vs. ground truth.
        l_g_pix = self.cri_pix(gen_hr, imgs_hr)
        self.generator_loss += l_g_pix
        self.weight_store["pixel_loss"] = l_g_pix

        # 2) Perceptual loss: sum of the Danbooru and VGG perceptual criteria,
        #    both comparing generated vs. ground truth.
        l_g_percep_danbooru = self.cri_danbooru_perceptual(gen_hr, imgs_hr)
        l_g_percep_vgg = self.cri_vgg_perceptual(gen_hr, imgs_hr)
        l_g_percep = l_g_percep_danbooru + l_g_percep_vgg
        self.generator_loss += l_g_percep
        self.weight_store["perceptual_loss"] = l_g_percep

        # 3) Adversarial loss on the discriminator's view of the generated
        #    batch. is_disc=False selects the generator-side loss; the target
        #    True presumably means "judge as real". The GAN loss weight
        #    (self.gan_loss_weight, e.g. 0.2/1) is already applied inside
        #    cri_gan, so no extra scaling is done here.
        fake_g_preds = self.discriminator(gen_hr)
        l_g_gan = self.cri_gan(fake_g_preds, True, is_disc=False)
        self.generator_loss += l_g_gan
        self.weight_store["gan_loss"] = l_g_gan

    def tensorboard_report(self, iteration):
        """Write this step's loss scalars to TensorBoard at ``iteration``."""
        self.writer.add_scalar('Loss/train-Generator_Loss-Iteration', self.generator_loss, iteration)
        self.writer.add_scalar('Loss/train-Pixel_Loss-Iteration', self.weight_store["pixel_loss"], iteration)
        self.writer.add_scalar('Loss/train-Perceptual_Loss-Iteration', self.weight_store["perceptual_loss"], iteration)
        # NOTE(review): despite its tag, this scalar is the generator-side
        # adversarial loss computed in calculate_loss (is_disc=False), not the
        # discriminator's own loss. The tag is left unchanged so existing
        # TensorBoard runs remain comparable.
        self.writer.add_scalar('Loss/train-Discriminator_Loss-Iteration', self.weight_store["gan_loss"], iteration)