Spaces:
Running
on
T4
Running
on
T4
# -*- coding: utf-8 -*-
import argparse
import os
import shutil
import sys
import time
import warnings

# Silence all warnings for the whole run. This is done before the local import
# below so that warnings raised while importing it are hidden as well.
warnings.filterwarnings("ignore")

# Make modules that live in the current working directory importable
# (the ``opt`` import below relies on this entry).
root_path = os.path.abspath(os.curdir)
sys.path.append(root_path)

from opt import opt  # noqa: E402  (must follow the sys.path setup above)
def storage_manage(src="runs/", dst_root="runs_last/"):
    """Archive a copy of the run-log folder before it gets wiped.

    The whole ``src`` tree is copied into a new subdirectory of ``dst_root``
    named after the current Unix time in whole seconds. ``dst_root`` is
    created if it is missing.

    Args:
        src: Directory to archive (defaults to ``runs/``).
        dst_root: Parent directory holding the archives
            (defaults to ``runs_last/``).

    Returns:
        The path of the newly created archive directory.

    Raises:
        OSError: propagated from ``shutil.copytree``, e.g. when ``src``
            does not exist.
    """
    os.makedirs(dst_root, exist_ok=True)

    # The name only has one-second resolution, so two archives taken within
    # the same second would collide and copytree would raise FileExistsError.
    # Append a numeric suffix until the name is unused.
    stamp = str(int(time.time()))
    new_address = os.path.join(dst_root, stamp)
    suffix = 1
    while os.path.exists(new_address):
        new_address = os.path.join(dst_root, "{}_{}".format(stamp, suffix))
        suffix += 1

    shutil.copytree(src, new_address)
    return new_address
def parse_args(argv=None):
    """Parse the command-line flags into the module-global ``args``.

    Flags:
        --auto_resume_closest / --auto_resume_best: mutually exclusive resume
            switches. Their exact meaning is defined by the trainer modules,
            which receive ``args``; this is not visible here.
        --pretrained_path: string path, empty by default.

    If neither resume flag is given, an existing ``./runs`` folder is
    archived with ``storage_manage()`` and then deleted, so logging starts
    from scratch.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, exactly as ``argparse`` does.

    Returns:
        The parsed ``argparse.Namespace``, which is also stored in ``args``.

    Raises:
        SystemExit: with status 2 when both resume flags are given, or when
            argparse rejects the arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--auto_resume_closest', action='store_true')
    parser.add_argument('--auto_resume_best', action='store_true')
    parser.add_argument('--pretrained_path', type=str, default="")

    # process() reads this module-level name, so it must stay a global.
    global args
    args = parser.parse_args(argv)

    if args.auto_resume_closest and args.auto_resume_best:
        # parser.error prints usage plus the message to stderr and exits with
        # status 2. os._exit(0) would report success and skip flushing stdio
        # buffers, so the message could be lost when output is redirected.
        parser.error("you could only resume either nearest or best, not both")

    if not (args.auto_resume_closest or args.auto_resume_best):
        # Fresh run: restart tensorboard logging by archiving, then deleting,
        # everything under ./runs.
        if os.path.exists("./runs"):
            storage_manage()
            shutil.rmtree("./runs")

    return args
def folder_prepare(make_folders=("saved_models/", "saved_models/checkpoints/", "datasets/"),
                   reset_folders=()):
    """Create the working directories used by the training run.

    Args:
        make_folders: Directories to create if missing. Existing directories
            and their contents are left alone. The defaults are
            ``saved_models/``, ``saved_models/checkpoints/`` and ``datasets/``.
        reset_folders: Directories to delete (if present) and recreate empty.
            This is empty by default, so nothing is deleted.

    Raises:
        FileExistsError: if a path in ``make_folders`` exists but is not a
            directory.
    """
    for folder_name in make_folders:
        # exist_ok avoids the exists()/makedirs() race. A regular file at the
        # path now raises instead of being silently treated as "present".
        os.makedirs(folder_name, exist_ok=True)

    for folder_name in reset_folders:
        if os.path.exists(folder_name):
            shutil.rmtree(folder_name)
        os.makedirs(folder_name)
# Supported architectures -> (module name, callable name) of the trainer.
# Modules are imported lazily, so only the selected trainer's module is loaded.
_TRAINERS = {
    "ESRNET": ("train_esrnet", "train_esrnet"),
    "ESRGAN": ("train_esrgan", "train_esrgan"),
    "GRL": ("train_grl", "train_grl"),
    "GRLGAN": ("train_grlgan", "train_grlgan"),
    "CUNET": ("train_cunet", "train_cunet"),
    "CUGAN": ("train_cugan", "train_cugan"),
}


def process(options):
    """Build the trainer selected by ``options['architecture']`` and run it.

    The trainer is constructed as ``factory(options, args)``, where ``args``
    is the module-global namespace set by ``parse_args``. Its ``run()`` method
    is then called, and the total wall-clock time is printed afterwards.

    Args:
        options: Configuration mapping. Only the ``'architecture'`` key is
            read here; the whole mapping is passed on to the trainer.

    Raises:
        KeyError: if ``options`` has no ``'architecture'`` key.
        NotImplementedError: if the architecture is not in ``_TRAINERS``.
    """
    import importlib  # only needed for the lazy trainer import below

    print(args)
    start = time.time()

    architecture = options['architecture']
    try:
        module_name, factory_name = _TRAINERS[architecture]
    except KeyError:
        raise NotImplementedError("This is not a supported model architecture") from None
    trainer_factory = getattr(importlib.import_module(module_name), factory_name)

    obj = trainer_factory(options, args)
    obj.run()

    total_time = time.time() - start
    # The seconds field is the remainder after removing whole minutes (mod 60).
    # Taking it mod 3600 would repeat the minutes inside the seconds.
    hours, remainder = divmod(total_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print("All programs spent {} hour {} min {} s".format(int(hours), int(minutes), round(seconds, 2)))
def main():
    """Script entry point: parse CLI flags, prepare folders, then train."""
    for setup_step in (parse_args, folder_prepare):
        setup_step()
    process(opt)


if __name__ == "__main__":
    main()